Skip to content

Commit 2fb3183

Browse files
appleboyclaude
andcommitted
refactor(oauth): fix double body close, harden security, and clean up code
- Fix double resp.Body.Close() in makeAPICallWithAutoRefresh by moving defer after the 401 retry branch - Use crypto/subtle.ConstantTimeCompare for CSRF state validation - Replace panic with os.Exit(1) for retry client init failure - Extract isRefreshTokenError helper to deduplicate error parsing - Remove redundant tokenStoreMode global variable - Eliminate err2 naming by reusing err after consumption - Replace custom containsSubstring with strings.Contains in tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 84b3489 commit 2fb3183

3 files changed

Lines changed: 26 additions & 34 deletions

File tree

callback.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"crypto/subtle"
56
"fmt"
67
"html"
78
"net"
@@ -71,9 +72,9 @@ func startCallbackServer(
7172
return
7273
}
7374

74-
// Validate state (CSRF protection).
75+
// Validate state (CSRF protection) using constant-time comparison.
7576
state := q.Get("state")
76-
if state != expectedState {
77+
if subtle.ConstantTimeCompare([]byte(state), []byte(expectedState)) != 1 {
7778
writeCallbackPage(w, false, "state_mismatch",
7879
"State parameter does not match. Possible CSRF attack.")
7980
sendResult(callbackResult{

main.go

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ var (
3636
callbackPort int
3737
scope string
3838
tokenFile string
39-
tokenStoreMode string
4039
tokenStore credstore.Store[credstore.Token]
4140
configOnce sync.Once
4241
retryClient *retry.Client
@@ -170,16 +169,16 @@ func doInitConfig() {
170169
var err error
171170
retryClient, err = retry.NewBackgroundClient(retry.WithHTTPClient(baseHTTPClient))
172171
if err != nil {
173-
panic(fmt.Sprintf("failed to create retry client: %v", err))
172+
fmt.Fprintf(os.Stderr, "Error: failed to create retry client: %v\n", err)
173+
os.Exit(1)
174174
}
175175

176176
const defaultKeyringService = "authgate-oauth-cli"
177-
tokenStoreMode = getConfig(*flagTokenStore, "TOKEN_STORE", "auto")
177+
tokenStoreMode := getConfig(*flagTokenStore, "TOKEN_STORE", "auto")
178178
var warnings []string
179-
var err2 error
180-
tokenStore, warnings, err2 = initTokenStore(tokenStoreMode, tokenFile, defaultKeyringService)
181-
if err2 != nil {
182-
fmt.Fprintf(os.Stderr, "Error: %v\n", err2)
179+
tokenStore, warnings, err = initTokenStore(tokenStoreMode, tokenFile, defaultKeyringService)
180+
if err != nil {
181+
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
183182
os.Exit(1)
184183
}
185184
configWarnings = append(configWarnings, warnings...)
@@ -298,6 +297,16 @@ func parseOAuthError(statusCode int, body []byte, action string) error {
298297
return fmt.Errorf("%s failed with status %d: %s", action, statusCode, string(body))
299298
}
300299

300+
// isRefreshTokenError checks whether the response body indicates an expired
301+
// or invalid refresh token (invalid_grant / invalid_token).
302+
func isRefreshTokenError(body []byte) bool {
303+
var errResp ErrorResponse
304+
if err := json.Unmarshal(body, &errResp); err == nil {
305+
return errResp.Error == "invalid_grant" || errResp.Error == "invalid_token"
306+
}
307+
return false
308+
}
309+
301310
// validateTokenResponse performs basic sanity checks on a token response.
302311
func validateTokenResponse(accessToken, tokenType string, expiresIn int) error {
303312
if accessToken == "" {
@@ -438,12 +447,8 @@ func refreshAccessToken(ctx context.Context, refreshToken string) (*tui.TokenSto
438447

439448
if resp.StatusCode != http.StatusOK {
440449
// Check for expired/invalid refresh token before general error handling.
441-
var errResp ErrorResponse
442-
if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.Error != "" {
443-
if errResp.Error == "invalid_grant" || errResp.Error == "invalid_token" {
444-
return nil, tui.ErrRefreshTokenExpired
445-
}
446-
return nil, fmt.Errorf("%s: %s", errResp.Error, errResp.ErrorDescription)
450+
if isRefreshTokenError(body) {
451+
return nil, tui.ErrRefreshTokenExpired
447452
}
448453
return nil, parseOAuthError(resp.StatusCode, body, "refresh")
449454
}
@@ -522,10 +527,9 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage)
522527
if err != nil {
523528
return fmt.Errorf("API request failed: %w", err)
524529
}
525-
defer resp.Body.Close()
526530

527531
if resp.StatusCode == http.StatusUnauthorized {
528-
// Drain and close body immediately so the HTTP transport can reuse the connection.
532+
// Drain and close body so the HTTP transport can reuse the connection.
529533
_, _ = io.Copy(io.Discard, resp.Body)
530534
resp.Body.Close()
531535

@@ -554,8 +558,8 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage)
554558
if err != nil {
555559
return fmt.Errorf("retry failed: %w", err)
556560
}
557-
defer resp.Body.Close()
558561
}
562+
defer resp.Body.Close()
559563

560564
body, err := readResponseBody(resp.Body)
561565
if err != nil {

main_test.go

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"errors"
66
"path/filepath"
7+
"strings"
78
"testing"
89
"time"
910

@@ -179,7 +180,7 @@ func TestBuildAuthURL_ContainsRequiredParams(t *testing.T) {
179180
"code_challenge=test-challenge",
180181
"code_challenge_method=S256",
181182
} {
182-
if !containsSubstring(u, want) {
183+
if !strings.Contains(u, want) {
183184
t.Errorf("auth URL missing %q\nURL: %s", want, u)
184185
}
185186
}
@@ -273,7 +274,7 @@ func TestInitTokenStore_Invalid(t *testing.T) {
273274
if store != nil {
274275
t.Errorf("expected nil store on error, got %T", store)
275276
}
276-
if !containsSubstring(err.Error(), "invalid token-store value") {
277+
if !strings.Contains(err.Error(), "invalid token-store value") {
277278
t.Errorf("unexpected error message: %v", err)
278279
}
279280
}
@@ -309,17 +310,3 @@ func TestReadResponseBody(t *testing.T) {
309310
}
310311
})
311312
}
312-
313-
// containsSubstring is a helper to avoid importing strings in tests.
314-
func containsSubstring(s, sub string) bool {
315-
return len(s) >= len(sub) && findSubstring(s, sub)
316-
}
317-
318-
func findSubstring(s, sub string) bool {
319-
for i := 0; i <= len(s)-len(sub); i++ {
320-
if s[i:i+len(sub)] == sub {
321-
return true
322-
}
323-
}
324-
return false
325-
}

0 commit comments

Comments
 (0)