Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions e2e_test/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestHappyPath(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") {
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down Expand Up @@ -106,7 +106,7 @@ func TestRedirectURLHostname(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1") {
if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down Expand Up @@ -177,7 +177,7 @@ func TestSuccessRedirect(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") {
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down Expand Up @@ -242,7 +242,7 @@ func TestSuccessRedirect(t *testing.T) {
wg.Wait()
}

func assertRedirectURI(t *testing.T, actualURI, scheme, hostname string) bool {
func assertRedirectURI(t *testing.T, actualURI, scheme, hostname, path string) bool {
redirect, err := url.Parse(actualURI)
if err != nil {
t.Errorf("could not parse redirect_uri: %s", err)
Expand All @@ -256,8 +256,8 @@ func assertRedirectURI(t *testing.T, actualURI, scheme, hostname string) bool {
t.Errorf("redirect_uri wants hostname %s but was %s", hostname, actualHostname)
return false
}
if redirect.Path != "" {
t.Errorf("redirect_uri wants path `` but was %s", redirect.Path)
if actualPath := redirect.Path; actualPath != path {
t.Errorf("redirect_uri wants path %s but was %s", path, actualPath)
return false
}
return true
Expand Down
86 changes: 86 additions & 0 deletions e2e_test/localserveropts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package e2e_test

import (
"context"
"fmt"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/int128/oauth2cli"
"github.com/int128/oauth2cli/e2e_test/authserver"
"github.com/int128/oauth2cli/e2e_test/client"
"golang.org/x/oauth2"
)

func TestLocalServerCallbackPath(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
defer cancel()
openBrowserCh := make(chan string)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
defer close(openBrowserCh)
// Start a local server and get a token.
testServer := httptest.NewServer(&authserver.Handler{
TestingT: t,
NewAuthorizationResponse: func(req authserver.AuthorizationRequest) string {
if want := "email profile"; req.Scope != want {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/callback") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
},
NewTokenResponse: func(req authserver.TokenRequest) (int, string) {
if want := "AUTH_CODE"; req.Code != want {
t.Errorf("code wants %s but %s", want, req.Code)
return 400, invalidGrantResponse
}
return 200, validTokenResponse
},
})
defer testServer.Close()
cfg := oauth2cli.Config{
OAuth2Config: oauth2.Config{
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
Scopes: []string{"email", "profile"},
Endpoint: oauth2.Endpoint{
AuthURL: testServer.URL + "/auth",
TokenURL: testServer.URL + "/token",
},
},
LocalServerCallbackPath: "/callback",
LocalServerReadyChan: openBrowserCh,
LocalServerMiddleware: loggingMiddleware(t),
Logf: t.Logf,
}
token, err := oauth2cli.GetToken(ctx, cfg)
if err != nil {
t.Errorf("could not get a token: %s", err)
return
}
if token.AccessToken != "ACCESS_TOKEN" {
t.Errorf("AccessToken wants %s but %s", "ACCESS_TOKEN", token.AccessToken)
}
if token.RefreshToken != "REFRESH_TOKEN" {
t.Errorf("RefreshToken wants %s but %s", "REFRESH_TOKEN", token.RefreshToken)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
toURL, ok := <-openBrowserCh
if !ok {
t.Errorf("server already closed")
return
}
client.GetAndVerify(t, toURL, 200, oauth2cli.DefaultLocalServerSuccessHTML)
}()
wg.Wait()
}
2 changes: 1 addition & 1 deletion e2e_test/pkce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestPKCE(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") {
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down
2 changes: 1 addition & 1 deletion e2e_test/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestTLS(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "https", "localhost") {
if !assertRedirectURI(t, req.RedirectURI, "https", "localhost", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down
8 changes: 8 additions & 0 deletions oauth2cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ type Config struct {
// This is required when LocalServerCertFile is set.
LocalServerKeyFile string

// Callback path of the local server.
// If your provider requires a specific path of the redirect URL, set it here.
// Default to "/".
LocalServerCallbackPath string

// Response HTML body on authorization completed.
// Default to DefaultLocalServerSuccessHTML.
LocalServerSuccessHTML string
Expand Down Expand Up @@ -119,6 +124,9 @@ func (cfg *Config) validateAndSetDefaults() error {
}
cfg.State = state
}
if cfg.LocalServerCallbackPath == "" {
cfg.LocalServerCallbackPath = "/"
}
if cfg.LocalServerMiddleware == nil {
cfg.LocalServerMiddleware = noopMiddleware
}
Expand Down
29 changes: 19 additions & 10 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error)
defer localServerListener.Close()

localServerPort := localServerListener.Addr().(*net.TCPAddr).Port
cfg.OAuth2Config.RedirectURL = constructRedirectURL(cfg, localServerPort)
localServerURL := constructLocalServerURL(cfg, localServerPort)
localServerIndexURL, err := localServerURL.Parse("/")
if err != nil {
return "", fmt.Errorf("construct the index URL: %w", err)
}
localServerCallbackURL, err := localServerURL.Parse(cfg.LocalServerCallbackPath)
if err != nil {
return "", fmt.Errorf("construct the callback URL: %w", err)
}
cfg.OAuth2Config.RedirectURL = localServerCallbackURL.String()

respCh := make(chan *authorizationResponse)
server := http.Server{
Expand Down Expand Up @@ -84,7 +93,7 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error)
return nil
}
select {
case cfg.LocalServerReadyChan <- cfg.OAuth2Config.RedirectURL:
case cfg.LocalServerReadyChan <- localServerIndexURL.String():
return nil
case <-ctx.Done():
return ctx.Err()
Expand All @@ -99,14 +108,14 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error)
return resp.code, resp.err
}

func constructRedirectURL(cfg *Config, port int) string {
var redirect url.URL
redirect.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port)
redirect.Scheme = "http"
func constructLocalServerURL(cfg *Config, port int) url.URL {
var localServer url.URL
localServer.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port)
localServer.Scheme = "http"
if cfg.isLocalServerHTTPS() {
redirect.Scheme = "https"
localServer.Scheme = "https"
}
return redirect.String()
return localServer
}

type authorizationResponse struct {
Expand All @@ -123,11 +132,11 @@ type localServerHandler struct {
func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
switch {
case r.Method == "GET" && r.URL.Path == "/" && q.Get("error") != "":
case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("error") != "":
h.onceRespCh.Do(func() {
h.respCh <- h.handleErrorResponse(w, r)
})
case r.Method == "GET" && r.URL.Path == "/" && q.Get("code") != "":
case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("code") != "":
h.onceRespCh.Do(func() {
h.respCh <- h.handleCodeResponse(w, r)
})
Expand Down
Loading