Skip to content

Commit 6dfd1ff

Browse files
committed
refactor: refactor OAuth callback server for safe, complete token exchange
- Add .authgate-tokens.json to .gitignore for sensitive file exclusion - Refactor callback server to handle the token exchange internally and return TokenStorage, not just the code - Ensure the browser receives the true exchange result by holding the HTTP response open until exchange completion - Update callback server timeout handling for write deadline with a dedicated constant - Make callback server concurrency-safe and idempotent for browser retries by using sync.Once - Update tests to validate TokenStorage and error handling, and add test coverage for exchange failures - Change performAuthCodeFlow to delegate code exchange to the callback handler, streamlining error handling and response Signed-off-by: appleboy <appleboy.tw@gmail.com>
1 parent dc6f663 commit 6dfd1ff

4 files changed

Lines changed: 151 additions & 52 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ go.work.sum
3131
# .idea/
3232
# .vscode/
3333
bin
34+
.authgate-tokens.json

callback.go

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,32 @@ import (
1313
const (
1414
// callbackTimeout is how long we wait for the browser to deliver the code.
1515
callbackTimeout = 5 * time.Minute
16+
17+
// callbackWriteTimeout is the HTTP write deadline for the callback handler.
18+
// It must exceed tokenExchangeTimeout to ensure the exchange result can be
19+
// written back to the browser before the connection times out.
20+
callbackWriteTimeout = 30 * time.Second
1621
)
1722

1823
// callbackResult holds the outcome of the local callback round-trip.
1924
type callbackResult struct {
20-
Code string
21-
Error string
22-
Desc string
25+
Storage *TokenStorage
26+
Error string
27+
Desc string
2328
}
2429

2530
// startCallbackServer starts a local HTTP server on the given port and waits
26-
// for the OAuth callback. It validates the returned state against expectedState
27-
// and returns the authorization code (or an error).
31+
// for the OAuth callback. It validates the returned state against expectedState,
32+
// then calls exchangeFn with the received authorization code. The HTTP response
33+
// is held open until exchangeFn returns so the browser reflects the true outcome.
2834
//
2935
// The server shuts itself down after the first request or when ctx is cancelled.
30-
func startCallbackServer(ctx context.Context, port int, expectedState string) (string, error) {
36+
func startCallbackServer(
37+
ctx context.Context,
38+
port int,
39+
expectedState string,
40+
exchangeFn func(ctx context.Context, code string) (*TokenStorage, error),
41+
) (*TokenStorage, error) {
3142
resultCh := make(chan callbackResult, 1)
3243

3344
// sendResult delivers the result exactly once. Any concurrent or subsequent
@@ -38,6 +49,14 @@ func startCallbackServer(ctx context.Context, port int, expectedState string) (s
3849
once.Do(func() { resultCh <- r })
3950
}
4051

52+
// exchangeOnce ensures the token exchange runs at most once even when the
53+
// browser retries the callback request.
54+
var (
55+
exchangeOnce sync.Once
56+
exchangeStorage *TokenStorage
57+
exchangeErr error
58+
)
59+
4160
mux := http.NewServeMux()
4261
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
4362
q := r.URL.Query()
@@ -69,21 +88,32 @@ func startCallbackServer(ctx context.Context, port int, expectedState string) (s
6988
return
7089
}
7190

91+
// Hold the HTTP response open while exchanging the code for tokens so
92+
// the browser reflects the true outcome (success or failure).
93+
exchangeOnce.Do(func() {
94+
exchangeStorage, exchangeErr = exchangeFn(r.Context(), code)
95+
})
96+
if exchangeErr != nil {
97+
writeCallbackPage(w, false, "token_exchange_failed", exchangeErr.Error())
98+
sendResult(callbackResult{Error: "token_exchange_failed", Desc: exchangeErr.Error()})
99+
return
100+
}
101+
72102
writeCallbackPage(w, true, "", "")
73-
sendResult(callbackResult{Code: code})
103+
sendResult(callbackResult{Storage: exchangeStorage})
74104
})
75105

76106
srv := &http.Server{
77107
Addr: fmt.Sprintf("127.0.0.1:%d", port),
78108
Handler: mux,
79109
ReadTimeout: 10 * time.Second,
80-
WriteTimeout: 10 * time.Second,
110+
WriteTimeout: callbackWriteTimeout,
81111
}
82112

83113
// Use a listener so we can report the actual bound port.
84114
ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", srv.Addr)
85115
if err != nil {
86-
return "", fmt.Errorf("failed to start callback server on port %d: %w", port, err)
116+
return nil, fmt.Errorf("failed to start callback server on port %d: %w", port, err)
87117
}
88118

89119
// Serve in background; shut down after receiving the result.
@@ -102,17 +132,17 @@ func startCallbackServer(ctx context.Context, port int, expectedState string) (s
102132
case result := <-resultCh:
103133
if result.Error != "" {
104134
if result.Desc != "" {
105-
return "", fmt.Errorf("%s: %s", result.Error, result.Desc)
135+
return nil, fmt.Errorf("%s: %s", result.Error, result.Desc)
106136
}
107-
return "", fmt.Errorf("%s", result.Error)
137+
return nil, fmt.Errorf("%s", result.Error)
108138
}
109-
return result.Code, nil
139+
return result.Storage, nil
110140

111141
case <-ctx.Done():
112-
return "", ctx.Err()
142+
return nil, ctx.Err()
113143

114144
case <-time.After(callbackTimeout):
115-
return "", fmt.Errorf("timed out waiting for browser authorization (%s)", callbackTimeout)
145+
return nil, fmt.Errorf("timed out waiting for browser authorization (%s)", callbackTimeout)
116146
}
117147
}
118148

callback_test.go

Lines changed: 96 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,48 @@ import (
1010
"time"
1111
)
1212

13+
type serverResult struct {
14+
storage *TokenStorage
15+
err error
16+
}
17+
1318
// startCallbackServerAsync starts the callback server in a goroutine and
14-
// returns a channel that will receive the authorization code (or error string).
15-
func startCallbackServerAsync(t *testing.T, port int, state string) chan string {
19+
// returns a channel that will receive the final result (storage or error).
20+
func startCallbackServerAsync(
21+
t *testing.T,
22+
port int,
23+
state string,
24+
exchangeFn func(ctx context.Context, code string) (*TokenStorage, error),
25+
) chan serverResult {
1626
t.Helper()
17-
ch := make(chan string, 1)
27+
ch := make(chan serverResult, 1)
1828
go func() {
19-
code, err := startCallbackServer(context.Background(), port, state)
20-
if err != nil {
21-
ch <- "ERROR:" + err.Error()
22-
} else {
23-
ch <- code
24-
}
29+
storage, err := startCallbackServer(context.Background(), port, state, exchangeFn)
30+
ch <- serverResult{storage: storage, err: err}
2531
}()
2632
// Give the server a moment to bind.
2733
time.Sleep(50 * time.Millisecond)
2834
return ch
2935
}
3036

37+
// mockExchangeFn returns an exchangeFn that succeeds with a stub TokenStorage.
38+
func mockExchangeFn(t *testing.T) func(ctx context.Context, code string) (*TokenStorage, error) {
39+
t.Helper()
40+
return func(_ context.Context, _ string) (*TokenStorage, error) {
41+
return &TokenStorage{
42+
AccessToken: "mock-access-token",
43+
RefreshToken: "mock-refresh-token",
44+
TokenType: "Bearer",
45+
ExpiresAt: time.Now().Add(time.Hour),
46+
}, nil
47+
}
48+
}
49+
3150
func TestCallbackServer_Success(t *testing.T) {
3251
const port = 19001
3352
state := "test-state-success"
3453

35-
ch := startCallbackServerAsync(t, port, state)
54+
ch := startCallbackServerAsync(t, port, state, mockExchangeFn(t))
3655

3756
// Simulate the browser redirect.
3857
callbackURL := fmt.Sprintf(
@@ -53,11 +72,57 @@ func TestCallbackServer_Success(t *testing.T) {
5372
t.Errorf("expected success page, got: %s", string(body))
5473
}
5574

56-
// Check code returned to CLI.
75+
// Check that storage is returned to the CLI.
76+
select {
77+
case result := <-ch:
78+
if result.err != nil {
79+
t.Errorf("expected no error, got: %v", result.err)
80+
}
81+
if result.storage == nil || result.storage.AccessToken != "mock-access-token" {
82+
t.Errorf("unexpected storage: %+v", result.storage)
83+
}
84+
case <-time.After(3 * time.Second):
85+
t.Fatal("timed out waiting for callback result")
86+
}
87+
}
88+
89+
func TestCallbackServer_ExchangeFailure(t *testing.T) {
90+
const port = 19006
91+
state := "test-state-exchange-fail"
92+
93+
failFn := func(_ context.Context, _ string) (*TokenStorage, error) {
94+
return nil, fmt.Errorf("server returned status 400: invalid_grant")
95+
}
96+
ch := startCallbackServerAsync(t, port, state, failFn)
97+
98+
callbackURL := fmt.Sprintf(
99+
"http://127.0.0.1:%d/callback?code=badcode&state=%s",
100+
port, state,
101+
)
102+
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
103+
if err != nil {
104+
t.Fatalf("GET callback failed: %v", err)
105+
}
106+
defer resp.Body.Close()
107+
108+
body, _ := io.ReadAll(resp.Body)
109+
if resp.StatusCode != http.StatusOK {
110+
t.Errorf("unexpected status %d", resp.StatusCode)
111+
}
112+
if !strings.Contains(string(body), "Authorization Failed") {
113+
t.Errorf("expected failure page, got: %s", string(body))
114+
}
115+
if !strings.Contains(string(body), "invalid_grant") {
116+
t.Errorf("expected error detail in page, got: %s", string(body))
117+
}
118+
57119
select {
58120
case result := <-ch:
59-
if result != "mycode123" {
60-
t.Errorf("expected code mycode123, got: %s", result)
121+
if result.err == nil {
122+
t.Error("expected an error, got nil")
123+
}
124+
if result.storage != nil {
125+
t.Errorf("expected nil storage, got: %+v", result.storage)
61126
}
62127
case <-time.After(3 * time.Second):
63128
t.Fatal("timed out waiting for callback result")
@@ -68,7 +133,7 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
68133
const port = 19002
69134
state := "expected-state"
70135

71-
ch := startCallbackServerAsync(t, port, state)
136+
ch := startCallbackServerAsync(t, port, state, nil)
72137

73138
callbackURL := fmt.Sprintf(
74139
"http://127.0.0.1:%d/callback?code=mycode&state=wrong-state",
@@ -87,8 +152,8 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
87152

88153
select {
89154
case result := <-ch:
90-
if !strings.HasPrefix(result, "ERROR:") {
91-
t.Errorf("expected error for state mismatch, got: %s", result)
155+
if result.err == nil {
156+
t.Errorf("expected error for state mismatch, got nil")
92157
}
93158
case <-time.After(3 * time.Second):
94159
t.Fatal("timed out waiting for callback result")
@@ -99,7 +164,7 @@ func TestCallbackServer_OAuthError(t *testing.T) {
99164
const port = 19003
100165
state := "state-for-error"
101166

102-
ch := startCallbackServerAsync(t, port, state)
167+
ch := startCallbackServerAsync(t, port, state, nil)
103168

104169
callbackURL := fmt.Sprintf(
105170
"http://127.0.0.1:%d/callback?error=access_denied&error_description=User+denied&state=%s",
@@ -118,11 +183,11 @@ func TestCallbackServer_OAuthError(t *testing.T) {
118183

119184
select {
120185
case result := <-ch:
121-
if !strings.HasPrefix(result, "ERROR:") {
122-
t.Errorf("expected error for access_denied, got: %s", result)
186+
if result.err == nil {
187+
t.Errorf("expected error for access_denied, got nil")
123188
}
124-
if !strings.Contains(result, "access_denied") {
125-
t.Errorf("expected error to mention access_denied, got: %s", result)
189+
if !strings.Contains(result.err.Error(), "access_denied") {
190+
t.Errorf("expected error to mention access_denied, got: %v", result.err)
126191
}
127192
case <-time.After(3 * time.Second):
128193
t.Fatal("timed out waiting for callback result")
@@ -137,7 +202,7 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
137202
const port = 19005
138203
state := "test-state-double"
139204

140-
ch := startCallbackServerAsync(t, port, state)
205+
ch := startCallbackServerAsync(t, port, state, mockExchangeFn(t))
141206

142207
url := fmt.Sprintf("http://127.0.0.1:%d/callback?code=mycode&state=%s", port, state)
143208

@@ -163,11 +228,14 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
163228
}
164229
}
165230

166-
// startCallbackServer must also return promptly.
231+
// startCallbackServer must also return promptly with a valid storage.
167232
select {
168233
case result := <-ch:
169-
if result != "mycode" {
170-
t.Errorf("expected mycode, got: %s", result)
234+
if result.err != nil {
235+
t.Errorf("expected no error, got: %v", result.err)
236+
}
237+
if result.storage == nil {
238+
t.Error("expected non-nil storage")
171239
}
172240
case <-time.After(3 * time.Second):
173241
t.Fatal("timed out waiting for callback result")
@@ -178,7 +246,7 @@ func TestCallbackServer_MissingCode(t *testing.T) {
178246
const port = 19004
179247
state := "state-for-missing-code"
180248

181-
ch := startCallbackServerAsync(t, port, state)
249+
ch := startCallbackServerAsync(t, port, state, nil)
182250

183251
// Correct state but no code parameter.
184252
callbackURL := fmt.Sprintf(
@@ -193,8 +261,8 @@ func TestCallbackServer_MissingCode(t *testing.T) {
193261

194262
select {
195263
case result := <-ch:
196-
if !strings.HasPrefix(result, "ERROR:") {
197-
t.Errorf("expected error for missing code, got: %s", result)
264+
if result.err == nil {
265+
t.Errorf("expected error for missing code, got nil")
198266
}
199267
case <-time.After(3 * time.Second):
200268
t.Fatal("timed out waiting for callback result")

main.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -339,20 +339,20 @@ func performAuthCodeFlow(ctx context.Context) (*TokenStorage, error) {
339339
fmt.Println("Browser opened. Please complete authorization in your browser.")
340340
}
341341

342-
// Start local callback server and wait for the code.
342+
// Start local callback server. The exchange runs inside the handler so the
343+
// browser sees the true outcome (success or failure) rather than a premature
344+
// success page.
343345
fmt.Printf("Step 2: Waiting for callback on http://localhost:%d/callback ...\n", callbackPort)
344-
code, err := startCallbackServer(ctx, callbackPort, state)
346+
storage, err := startCallbackServer(ctx, callbackPort, state,
347+
func(cbCtx context.Context, code string) (*TokenStorage, error) {
348+
fmt.Println("Authorization code received!")
349+
fmt.Println("Step 3: Exchanging authorization code for tokens...")
350+
return exchangeCode(cbCtx, code, pkce.Verifier)
351+
},
352+
)
345353
if err != nil {
346354
return nil, fmt.Errorf("authorization failed: %w", err)
347355
}
348-
fmt.Println("Authorization code received!")
349-
350-
// Exchange code for tokens.
351-
fmt.Println("Step 3: Exchanging authorization code for tokens...")
352-
storage, err := exchangeCode(ctx, code, pkce.Verifier)
353-
if err != nil {
354-
return nil, fmt.Errorf("token exchange failed: %w", err)
355-
}
356356

357357
if err := saveTokens(storage); err != nil {
358358
fmt.Printf("Warning: Failed to save tokens: %v\n", err)

0 commit comments

Comments
 (0)