Skip to content

Commit 833e3ba

Browse files
authored
Add Config.LocalServerCallbackPath (#235)
1 parent a6300b0 commit 833e3ba

File tree

6 files changed

+121
-18
lines changed

6 files changed

+121
-18
lines changed

e2e_test/e2e_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestHappyPath(t *testing.T) {
3636
t.Errorf("scope wants %s but %s", want, req.Scope)
3737
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
3838
}
39-
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") {
39+
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
4040
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
4141
}
4242
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
@@ -106,7 +106,7 @@ func TestRedirectURLHostname(t *testing.T) {
106106
t.Errorf("scope wants %s but %s", want, req.Scope)
107107
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
108108
}
109-
if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1") {
109+
if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1", "/") {
110110
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
111111
}
112112
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
@@ -177,7 +177,7 @@ func TestSuccessRedirect(t *testing.T) {
177177
t.Errorf("scope wants %s but %s", want, req.Scope)
178178
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
179179
}
180-
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") {
180+
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
181181
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
182182
}
183183
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
@@ -242,7 +242,7 @@ func TestSuccessRedirect(t *testing.T) {
242242
wg.Wait()
243243
}
244244

245-
func assertRedirectURI(t *testing.T, actualURI, scheme, hostname string) bool {
245+
func assertRedirectURI(t *testing.T, actualURI, scheme, hostname, path string) bool {
246246
redirect, err := url.Parse(actualURI)
247247
if err != nil {
248248
t.Errorf("could not parse redirect_uri: %s", err)
@@ -256,8 +256,8 @@ func assertRedirectURI(t *testing.T, actualURI, scheme, hostname string) bool {
256256
t.Errorf("redirect_uri wants hostname %s but was %s", hostname, actualHostname)
257257
return false
258258
}
259-
if redirect.Path != "" {
260-
t.Errorf("redirect_uri wants path `` but was %s", redirect.Path)
259+
if actualPath := redirect.Path; actualPath != path {
260+
t.Errorf("redirect_uri wants path %s but was %s", path, actualPath)
261261
return false
262262
}
263263
return true

e2e_test/localserveropts_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package e2e_test
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http/httptest"
7+
"sync"
8+
"testing"
9+
"time"
10+
11+
"github.com/int128/oauth2cli"
12+
"github.com/int128/oauth2cli/e2e_test/authserver"
13+
"github.com/int128/oauth2cli/e2e_test/client"
14+
"golang.org/x/oauth2"
15+
)
16+
17+
func TestLocalServerCallbackPath(t *testing.T) {
18+
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
19+
defer cancel()
20+
openBrowserCh := make(chan string)
21+
var wg sync.WaitGroup
22+
wg.Add(1)
23+
go func() {
24+
defer wg.Done()
25+
defer close(openBrowserCh)
26+
// Start a local server and get a token.
27+
testServer := httptest.NewServer(&authserver.Handler{
28+
TestingT: t,
29+
NewAuthorizationResponse: func(req authserver.AuthorizationRequest) string {
30+
if want := "email profile"; req.Scope != want {
31+
t.Errorf("scope wants %s but %s", want, req.Scope)
32+
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
33+
}
34+
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/callback") {
35+
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
36+
}
37+
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
38+
},
39+
NewTokenResponse: func(req authserver.TokenRequest) (int, string) {
40+
if want := "AUTH_CODE"; req.Code != want {
41+
t.Errorf("code wants %s but %s", want, req.Code)
42+
return 400, invalidGrantResponse
43+
}
44+
return 200, validTokenResponse
45+
},
46+
})
47+
defer testServer.Close()
48+
cfg := oauth2cli.Config{
49+
OAuth2Config: oauth2.Config{
50+
ClientID: "YOUR_CLIENT_ID",
51+
ClientSecret: "YOUR_CLIENT_SECRET",
52+
Scopes: []string{"email", "profile"},
53+
Endpoint: oauth2.Endpoint{
54+
AuthURL: testServer.URL + "/auth",
55+
TokenURL: testServer.URL + "/token",
56+
},
57+
},
58+
LocalServerCallbackPath: "/callback",
59+
LocalServerReadyChan: openBrowserCh,
60+
LocalServerMiddleware: loggingMiddleware(t),
61+
Logf: t.Logf,
62+
}
63+
token, err := oauth2cli.GetToken(ctx, cfg)
64+
if err != nil {
65+
t.Errorf("could not get a token: %s", err)
66+
return
67+
}
68+
if token.AccessToken != "ACCESS_TOKEN" {
69+
t.Errorf("AccessToken wants %s but %s", "ACCESS_TOKEN", token.AccessToken)
70+
}
71+
if token.RefreshToken != "REFRESH_TOKEN" {
72+
t.Errorf("RefreshToken wants %s but %s", "REFRESH_TOKEN", token.RefreshToken)
73+
}
74+
}()
75+
wg.Add(1)
76+
go func() {
77+
defer wg.Done()
78+
toURL, ok := <-openBrowserCh
79+
if !ok {
80+
t.Errorf("server already closed")
81+
return
82+
}
83+
client.GetAndVerify(t, toURL, 200, oauth2cli.DefaultLocalServerSuccessHTML)
84+
}()
85+
wg.Wait()
86+
}

e2e_test/pkce_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestPKCE(t *testing.T) {
4040
t.Errorf("scope wants %s but %s", want, req.Scope)
4141
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
4242
}
43-
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") {
43+
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
4444
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
4545
}
4646
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")

e2e_test/tls_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func TestTLS(t *testing.T) {
3131
t.Errorf("scope wants %s but %s", want, req.Scope)
3232
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
3333
}
34-
if !assertRedirectURI(t, req.RedirectURI, "https", "localhost") {
34+
if !assertRedirectURI(t, req.RedirectURI, "https", "localhost", "/") {
3535
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
3636
}
3737
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")

oauth2cli.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ type Config struct {
8383
// This is required when LocalServerCertFile is set.
8484
LocalServerKeyFile string
8585

86+
// Callback path of the local server.
87+
// If your provider requires a specific path of the redirect URL, set it here.
88+
// Default to "/".
89+
LocalServerCallbackPath string
90+
8691
// Response HTML body on authorization completed.
8792
// Default to DefaultLocalServerSuccessHTML.
8893
LocalServerSuccessHTML string
@@ -119,6 +124,9 @@ func (cfg *Config) validateAndSetDefaults() error {
119124
}
120125
cfg.State = state
121126
}
127+
if cfg.LocalServerCallbackPath == "" {
128+
cfg.LocalServerCallbackPath = "/"
129+
}
122130
if cfg.LocalServerMiddleware == nil {
123131
cfg.LocalServerMiddleware = noopMiddleware
124132
}

server.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,16 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error)
2222
defer localServerListener.Close()
2323

2424
localServerPort := localServerListener.Addr().(*net.TCPAddr).Port
25-
cfg.OAuth2Config.RedirectURL = constructRedirectURL(cfg, localServerPort)
25+
localServerURL := constructLocalServerURL(cfg, localServerPort)
26+
localServerIndexURL, err := localServerURL.Parse("/")
27+
if err != nil {
28+
return "", fmt.Errorf("construct the index URL: %w", err)
29+
}
30+
localServerCallbackURL, err := localServerURL.Parse(cfg.LocalServerCallbackPath)
31+
if err != nil {
32+
return "", fmt.Errorf("construct the callback URL: %w", err)
33+
}
34+
cfg.OAuth2Config.RedirectURL = localServerCallbackURL.String()
2635

2736
respCh := make(chan *authorizationResponse)
2837
server := http.Server{
@@ -84,7 +93,7 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error)
8493
return nil
8594
}
8695
select {
87-
case cfg.LocalServerReadyChan <- cfg.OAuth2Config.RedirectURL:
96+
case cfg.LocalServerReadyChan <- localServerIndexURL.String():
8897
return nil
8998
case <-ctx.Done():
9099
return ctx.Err()
@@ -99,14 +108,14 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error)
99108
return resp.code, resp.err
100109
}
101110

102-
func constructRedirectURL(cfg *Config, port int) string {
103-
var redirect url.URL
104-
redirect.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port)
105-
redirect.Scheme = "http"
111+
func constructLocalServerURL(cfg *Config, port int) url.URL {
112+
var localServer url.URL
113+
localServer.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port)
114+
localServer.Scheme = "http"
106115
if cfg.isLocalServerHTTPS() {
107-
redirect.Scheme = "https"
116+
localServer.Scheme = "https"
108117
}
109-
return redirect.String()
118+
return localServer
110119
}
111120

112121
type authorizationResponse struct {
@@ -123,11 +132,11 @@ type localServerHandler struct {
123132
func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
124133
q := r.URL.Query()
125134
switch {
126-
case r.Method == "GET" && r.URL.Path == "/" && q.Get("error") != "":
135+
case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("error") != "":
127136
h.onceRespCh.Do(func() {
128137
h.respCh <- h.handleErrorResponse(w, r)
129138
})
130-
case r.Method == "GET" && r.URL.Path == "/" && q.Get("code") != "":
139+
case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("code") != "":
131140
h.onceRespCh.Do(func() {
132141
h.respCh <- h.handleCodeResponse(w, r)
133142
})

0 commit comments

Comments
 (0)