From 3d0c2b7333e0abba8ecab7f0fe84a73952359135 Mon Sep 17 00:00:00 2001 From: Hidetake Iwata Date: Sat, 25 Jan 2025 16:06:52 +0900 Subject: [PATCH] Add `Config.LocalServerCallbackPath` --- e2e_test/e2e_test.go | 12 ++--- e2e_test/localserveropts_test.go | 86 ++++++++++++++++++++++++++++++++ e2e_test/pkce_test.go | 2 +- e2e_test/tls_test.go | 2 +- oauth2cli.go | 8 +++ server.go | 29 +++++++---- 6 files changed, 121 insertions(+), 18 deletions(-) create mode 100644 e2e_test/localserveropts_test.go diff --git a/e2e_test/e2e_test.go b/e2e_test/e2e_test.go index 764b7a8..7cb4e4c 100644 --- a/e2e_test/e2e_test.go +++ b/e2e_test/e2e_test.go @@ -36,7 +36,7 @@ func TestHappyPath(t *testing.T) { t.Errorf("scope wants %s but %s", want, req.Scope) return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI) } - if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") { + if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") { return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI) } return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE") @@ -106,7 +106,7 @@ func TestRedirectURLHostname(t *testing.T) { t.Errorf("scope wants %s but %s", want, req.Scope) return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI) } - if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1") { + if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1", "/") { return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI) } return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE") @@ -177,7 +177,7 @@ func TestSuccessRedirect(t *testing.T) { t.Errorf("scope wants %s but %s", want, req.Scope) return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI) } - if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") { + if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") { return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI) } return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE") @@ -242,7 +242,7 @@ func TestSuccessRedirect(t *testing.T) { wg.Wait() } -func assertRedirectURI(t *testing.T, actualURI, scheme, hostname string) bool { +func assertRedirectURI(t *testing.T, actualURI, scheme, hostname, path string) bool { redirect, err := url.Parse(actualURI) if err != nil { t.Errorf("could not parse redirect_uri: %s", err) @@ -256,8 +256,8 @@ func assertRedirectURI(t *testing.T, actualURI, scheme, hostname string) bool { t.Errorf("redirect_uri wants hostname %s but was %s", hostname, actualHostname) return false } - if redirect.Path != "" { - t.Errorf("redirect_uri wants path `` but was %s", redirect.Path) + if actualPath := redirect.Path; actualPath != path { + t.Errorf("redirect_uri wants path %s but was %s", path, actualPath) return false } return true diff --git a/e2e_test/localserveropts_test.go b/e2e_test/localserveropts_test.go new file mode 100644 index 0000000..5afe7a6 --- /dev/null +++ b/e2e_test/localserveropts_test.go @@ -0,0 +1,86 @@ +package e2e_test + +import ( + "context" + "fmt" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/int128/oauth2cli" + "github.com/int128/oauth2cli/e2e_test/authserver" + "github.com/int128/oauth2cli/e2e_test/client" + "golang.org/x/oauth2" +) + +func TestLocalServerCallbackPath(t *testing.T) { + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + openBrowserCh := make(chan string) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + defer close(openBrowserCh) + // Start a local server and get a token. + testServer := httptest.NewServer(&authserver.Handler{ + TestingT: t, + NewAuthorizationResponse: func(req authserver.AuthorizationRequest) string { + if want := "email profile"; req.Scope != want { + t.Errorf("scope wants %s but %s", want, req.Scope) + return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI) + } + if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/callback") { + return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI) + } + return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE") + }, + NewTokenResponse: func(req authserver.TokenRequest) (int, string) { + if want := "AUTH_CODE"; req.Code != want { + t.Errorf("code wants %s but %s", want, req.Code) + return 400, invalidGrantResponse + } + return 200, validTokenResponse + }, + }) + defer testServer.Close() + cfg := oauth2cli.Config{ + OAuth2Config: oauth2.Config{ + ClientID: "YOUR_CLIENT_ID", + ClientSecret: "YOUR_CLIENT_SECRET", + Scopes: []string{"email", "profile"}, + Endpoint: oauth2.Endpoint{ + AuthURL: testServer.URL + "/auth", + TokenURL: testServer.URL + "/token", + }, + }, + LocalServerCallbackPath: "/callback", + LocalServerReadyChan: openBrowserCh, + LocalServerMiddleware: loggingMiddleware(t), + Logf: t.Logf, + } + token, err := oauth2cli.GetToken(ctx, cfg) + if err != nil { + t.Errorf("could not get a token: %s", err) + return + } + if token.AccessToken != "ACCESS_TOKEN" { + t.Errorf("AccessToken wants %s but %s", "ACCESS_TOKEN", token.AccessToken) + } + if token.RefreshToken != "REFRESH_TOKEN" { + t.Errorf("RefreshToken wants %s but %s", "REFRESH_TOKEN", token.RefreshToken) + } + }() + wg.Add(1) + go func() { + defer wg.Done() + toURL, ok := <-openBrowserCh + if !ok { + t.Errorf("server already closed") + return + } + client.GetAndVerify(t, toURL, 200, oauth2cli.DefaultLocalServerSuccessHTML) + }() + wg.Wait() +} diff --git a/e2e_test/pkce_test.go b/e2e_test/pkce_test.go index 7fc70a9..76f5d43 100644 --- a/e2e_test/pkce_test.go +++ b/e2e_test/pkce_test.go @@ -40,7 +40,7 @@ func TestPKCE(t *testing.T) { t.Errorf("scope wants %s but %s", want, req.Scope) return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI) } - if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") { + if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") { return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI) } return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE") diff --git a/e2e_test/tls_test.go b/e2e_test/tls_test.go index e45d89d..4dbf8d0 100644 --- a/e2e_test/tls_test.go +++ b/e2e_test/tls_test.go @@ -31,7 +31,7 @@ func TestTLS(t *testing.T) { t.Errorf("scope wants %s but %s", want, req.Scope) return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI) } - if !assertRedirectURI(t, req.RedirectURI, "https", "localhost") { + if !assertRedirectURI(t, req.RedirectURI, "https", "localhost", "/") { return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI) } return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE") diff --git a/oauth2cli.go b/oauth2cli.go index aa7430f..0fa3811 100644 --- a/oauth2cli.go +++ b/oauth2cli.go @@ -83,6 +83,11 @@ type Config struct { // This is required when LocalServerCertFile is set. LocalServerKeyFile string + // Callback path of the local server. + // If your provider requires a specific path of the redirect URL, set it here. + // Default to "/". + LocalServerCallbackPath string + // Response HTML body on authorization completed. // Default to DefaultLocalServerSuccessHTML. LocalServerSuccessHTML string @@ -119,6 +124,9 @@ func (cfg *Config) validateAndSetDefaults() error { } cfg.State = state } + if cfg.LocalServerCallbackPath == "" { + cfg.LocalServerCallbackPath = "/" + } if cfg.LocalServerMiddleware == nil { cfg.LocalServerMiddleware = noopMiddleware } diff --git a/server.go b/server.go index afe7fa8..618308c 100644 --- a/server.go +++ b/server.go @@ -22,7 +22,16 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error) defer localServerListener.Close() localServerPort := localServerListener.Addr().(*net.TCPAddr).Port - cfg.OAuth2Config.RedirectURL = constructRedirectURL(cfg, localServerPort) + localServerURL := constructLocalServerURL(cfg, localServerPort) + localServerIndexURL, err := localServerURL.Parse("/") + if err != nil { + return "", fmt.Errorf("construct the index URL: %w", err) + } + localServerCallbackURL, err := localServerURL.Parse(cfg.LocalServerCallbackPath) + if err != nil { + return "", fmt.Errorf("construct the callback URL: %w", err) + } + cfg.OAuth2Config.RedirectURL = localServerCallbackURL.String() respCh := make(chan *authorizationResponse) server := http.Server{ @@ -84,7 +93,7 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error) return nil } select { - case cfg.LocalServerReadyChan <- cfg.OAuth2Config.RedirectURL: + case cfg.LocalServerReadyChan <- localServerIndexURL.String(): return nil case <-ctx.Done(): return ctx.Err() @@ -99,14 +108,14 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error) return resp.code, resp.err } -func constructRedirectURL(cfg *Config, port int) string { - var redirect url.URL - redirect.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port) - redirect.Scheme = "http" +func constructLocalServerURL(cfg *Config, port int) url.URL { + var localServer url.URL + localServer.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port) + localServer.Scheme = "http" if cfg.isLocalServerHTTPS() { - redirect.Scheme = "https" + localServer.Scheme = "https" } - return redirect.String() + return localServer } type authorizationResponse struct { @@ -123,11 +132,11 @@ type localServerHandler struct { func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() switch { - case r.Method == "GET" && r.URL.Path == "/" && q.Get("error") != "": + case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("error") != "": h.onceRespCh.Do(func() { h.respCh <- h.handleErrorResponse(w, r) }) - case r.Method == "GET" && r.URL.Path == "/" && q.Get("code") != "": + case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("code") != "": h.onceRespCh.Do(func() { h.respCh <- h.handleCodeResponse(w, r) })