Skip to content

Commit 14e2653

Browse files
appleboyclaude
andauthored
refactor(oauth): deduplicate token parsing and fix struct copy bug (#12)
* refactor(oauth): deduplicate token parsing and fix struct copy bug - Extract shared tokenResponse struct and parseOAuthError helper to eliminate duplicate error parsing - Simplify PKCE branching to always set code_verifier unconditionally - Fix incomplete struct copy in makeAPICallWithAutoRefresh using full struct assignment - Drain response body on 401 to allow HTTP connection reuse - Replace time.After with time.NewTimer to prevent timer leak - Extract quitInterrupted method to deduplicate interrupt-handling blocks Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(oauth): address PR review feedback for error handling and connection reuse - Add errResp.Error != "" guard in refreshAccessToken to avoid empty error messages when JSON unmarshals successfully but contains no OAuth error - Close response body immediately after draining on 401 so the HTTP transport can reuse the connection for the retry request Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 97a37a8 commit 14e2653

3 files changed

Lines changed: 50 additions & 54 deletions

File tree

callback.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ func startCallbackServer(
130130
}()
131131

132132
// Wait for callback, timeout, or context cancellation.
133+
timer := time.NewTimer(callbackTimeout)
134+
defer timer.Stop()
135+
133136
select {
134137
case result := <-resultCh:
135138
if result.Error != "" {
@@ -143,7 +146,7 @@ func startCallbackServer(
143146
case <-ctx.Done():
144147
return nil, ctx.Err()
145148

146-
case <-time.After(callbackTimeout):
149+
case <-timer.C:
147150
return nil, fmt.Errorf("timed out waiting for browser authorization (%s)", callbackTimeout)
148151
}
149152
}

main.go

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,25 @@ type ErrorResponse struct {
204204
ErrorDescription string `json:"error_description"`
205205
}
206206

207+
// tokenResponse is the JSON structure returned by /oauth/token.
208+
type tokenResponse struct {
209+
AccessToken string `json:"access_token"`
210+
RefreshToken string `json:"refresh_token"`
211+
TokenType string `json:"token_type"`
212+
ExpiresIn int `json:"expires_in"`
213+
Scope string `json:"scope"`
214+
}
215+
216+
// parseOAuthError attempts to extract a structured OAuth error from a non-200
217+
// response body. Falls back to including the raw body in the error message.
218+
func parseOAuthError(statusCode int, body []byte, action string) error {
219+
var errResp ErrorResponse
220+
if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.Error != "" {
221+
return fmt.Errorf("%s: %s", errResp.Error, errResp.ErrorDescription)
222+
}
223+
return fmt.Errorf("%s failed with status %d: %s", action, statusCode, string(body))
224+
}
225+
207226
func loadTokens() (*tui.TokenStorage, error) {
208227
data, err := os.ReadFile(tokenFile)
209228
if err != nil {
@@ -317,13 +336,10 @@ func exchangeCode(ctx context.Context, code, codeVerifier string) (*tui.TokenSto
317336
data.Set("redirect_uri", redirectURI)
318337
data.Set("client_id", clientID)
319338

320-
if isPublicClient() {
321-
// Public client: send code_verifier for PKCE verification.
322-
data.Set("code_verifier", codeVerifier)
323-
} else {
324-
// Confidential client: send client_secret (and also verifier for PKCE).
339+
// PKCE is always enabled (defense in depth).
340+
data.Set("code_verifier", codeVerifier)
341+
if !isPublicClient() {
325342
data.Set("client_secret", clientSecret)
326-
data.Set("code_verifier", codeVerifier)
327343
}
328344

329345
req, err := http.NewRequestWithContext(
@@ -349,24 +365,10 @@ func exchangeCode(ctx context.Context, code, codeVerifier string) (*tui.TokenSto
349365
}
350366

351367
if resp.StatusCode != http.StatusOK {
352-
var errResp ErrorResponse
353-
if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.Error != "" {
354-
return nil, fmt.Errorf("%s: %s", errResp.Error, errResp.ErrorDescription)
355-
}
356-
return nil, fmt.Errorf(
357-
"token exchange failed with status %d: %s",
358-
resp.StatusCode,
359-
string(body),
360-
)
368+
return nil, parseOAuthError(resp.StatusCode, body, "token exchange")
361369
}
362370

363-
var tokenResp struct {
364-
AccessToken string `json:"access_token"`
365-
RefreshToken string `json:"refresh_token"`
366-
TokenType string `json:"token_type"`
367-
ExpiresIn int `json:"expires_in"`
368-
Scope string `json:"scope"`
369-
}
371+
var tokenResp tokenResponse
370372
if err := json.Unmarshal(body, &tokenResp); err != nil {
371373
return nil, fmt.Errorf("failed to parse token response: %w", err)
372374
}
@@ -427,22 +429,18 @@ func refreshAccessToken(ctx context.Context, refreshToken string) (*tui.TokenSto
427429
}
428430

429431
if resp.StatusCode != http.StatusOK {
432+
// Check for expired/invalid refresh token before general error handling.
430433
var errResp ErrorResponse
431-
if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil {
434+
if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.Error != "" {
432435
if errResp.Error == "invalid_grant" || errResp.Error == "invalid_token" {
433436
return nil, tui.ErrRefreshTokenExpired
434437
}
435438
return nil, fmt.Errorf("%s: %s", errResp.Error, errResp.ErrorDescription)
436439
}
437-
return nil, fmt.Errorf("refresh failed with status %d: %s", resp.StatusCode, string(body))
440+
return nil, parseOAuthError(resp.StatusCode, body, "refresh")
438441
}
439442

440-
var tokenResp struct {
441-
AccessToken string `json:"access_token"`
442-
RefreshToken string `json:"refresh_token"`
443-
TokenType string `json:"token_type"`
444-
ExpiresIn int `json:"expires_in"`
445-
}
443+
var tokenResp tokenResponse
446444
if err := json.Unmarshal(body, &tokenResp); err != nil {
447445
return nil, fmt.Errorf("failed to parse token response: %w", err)
448446
}
@@ -498,11 +496,7 @@ func verifyToken(ctx context.Context, accessToken string) (string, error) {
498496
}
499497

500498
if resp.StatusCode != http.StatusOK {
501-
var errResp ErrorResponse
502-
if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil {
503-
return "", fmt.Errorf("%s: %s", errResp.Error, errResp.ErrorDescription)
504-
}
505-
return "", fmt.Errorf("server returned status %d: %s", resp.StatusCode, string(body))
499+
return "", parseOAuthError(resp.StatusCode, body, "token verification")
506500
}
507501

508502
return string(body), nil
@@ -523,6 +517,10 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage)
523517
defer resp.Body.Close()
524518

525519
if resp.StatusCode == http.StatusUnauthorized {
520+
// Drain and close body immediately so the HTTP transport can reuse the connection.
521+
_, _ = io.Copy(io.Discard, resp.Body)
522+
resp.Body.Close()
523+
526524
newStorage, err := refreshAccessToken(ctx, storage.RefreshToken)
527525
if err != nil {
528526
if err == tui.ErrRefreshTokenExpired {
@@ -531,9 +529,7 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage)
531529
return fmt.Errorf("refresh failed: %w", err)
532530
}
533531

534-
storage.AccessToken = newStorage.AccessToken
535-
storage.RefreshToken = newStorage.RefreshToken
536-
storage.ExpiresAt = newStorage.ExpiresAt
532+
*storage = *newStorage
537533

538534
req, err = http.NewRequestWithContext(
539535
ctx,

tui/model.go

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,7 @@ func (m OAuthModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
196196
case msgTokenRefreshed:
197197
if msg.err != nil {
198198
if isContextCanceled(msg.err) {
199-
m.ExitCode = 130
200-
m.interrupted = true
201-
return m, tea.Quit
199+
return m.quitInterrupted()
202200
}
203201
m.stepStatuses[stepRefreshToken] = statusFailed
204202
m.stepMessages[stepRefreshToken] = msg.err.Error()
@@ -216,9 +214,7 @@ func (m OAuthModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
216214
case msgAuthFlowReady:
217215
if msg.err != nil {
218216
if isContextCanceled(msg.err) {
219-
m.ExitCode = 130
220-
m.interrupted = true
221-
return m, tea.Quit
217+
return m.quitInterrupted()
222218
}
223219
m.stepStatuses[stepAuthFlow] = statusFailed
224220
m.stepMessages[stepAuthFlow] = msg.err.Error()
@@ -246,9 +242,7 @@ func (m OAuthModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
246242
case msgCallbackReceived:
247243
if msg.err != nil {
248244
if isContextCanceled(msg.err) {
249-
m.ExitCode = 130
250-
m.interrupted = true
251-
return m, tea.Quit
245+
return m.quitInterrupted()
252246
}
253247
m.stepStatuses[stepWaitCallback] = statusFailed
254248
m.stepMessages[stepWaitCallback] = msg.err.Error()
@@ -267,9 +261,7 @@ func (m OAuthModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
267261
case msgTokenVerified:
268262
if msg.err != nil {
269263
if isContextCanceled(msg.err) {
270-
m.ExitCode = 130
271-
m.interrupted = true
272-
return m, tea.Quit
264+
return m.quitInterrupted()
273265
}
274266
// Verification failure is non-fatal — still proceed to API call.
275267
m.stepStatuses[stepVerifyToken] = statusFailed
@@ -283,9 +275,7 @@ func (m OAuthModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
283275
case msgAPICallDone:
284276
if msg.err != nil {
285277
if isContextCanceled(msg.err) {
286-
m.ExitCode = 130
287-
m.interrupted = true
288-
return m, tea.Quit
278+
return m.quitInterrupted()
289279
}
290280
if errors.Is(msg.err, ErrRefreshTokenExpired) {
291281
// Refresh token expired during API call — restart auth sub-steps.
@@ -324,3 +314,10 @@ func (m OAuthModel) startStep(s step, cmd tea.Cmd) (tea.Model, tea.Cmd) {
324314
func isContextCanceled(err error) bool {
325315
return errors.Is(err, context.Canceled)
326316
}
317+
318+
// quitInterrupted marks the model as interrupted (exit code 130) and returns tea.Quit.
319+
func (m OAuthModel) quitInterrupted() (tea.Model, tea.Cmd) {
320+
m.ExitCode = 130
321+
m.interrupted = true
322+
return m, tea.Quit
323+
}

0 commit comments

Comments
 (0)