Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"crypto/subtle"
"fmt"
"html"
"net"
Expand Down Expand Up @@ -71,9 +72,9 @@ func startCallbackServer(
return
}

// Validate state (CSRF protection).
// Validate state (CSRF protection) using constant-time comparison.
state := q.Get("state")
if state != expectedState {
if subtle.ConstantTimeCompare([]byte(state), []byte(expectedState)) != 1 {
Comment thread
appleboy marked this conversation as resolved.
Outdated
writeCallbackPage(w, false, "state_mismatch",
"State parameter does not match. Possible CSRF attack.")
sendResult(callbackResult{
Expand Down
36 changes: 20 additions & 16 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ var (
callbackPort int
scope string
tokenFile string
tokenStoreMode string
tokenStore credstore.Store[credstore.Token]
configOnce sync.Once
retryClient *retry.Client
Expand Down Expand Up @@ -170,16 +169,16 @@ func doInitConfig() {
var err error
retryClient, err = retry.NewBackgroundClient(retry.WithHTTPClient(baseHTTPClient))
if err != nil {
panic(fmt.Sprintf("failed to create retry client: %v", err))
fmt.Fprintf(os.Stderr, "Error: failed to create retry client: %v\n", err)
os.Exit(1)
}

const defaultKeyringService = "authgate-oauth-cli"
tokenStoreMode = getConfig(*flagTokenStore, "TOKEN_STORE", "auto")
tokenStoreMode := getConfig(*flagTokenStore, "TOKEN_STORE", "auto")
var warnings []string
var err2 error
tokenStore, warnings, err2 = initTokenStore(tokenStoreMode, tokenFile, defaultKeyringService)
if err2 != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err2)
tokenStore, warnings, err = initTokenStore(tokenStoreMode, tokenFile, defaultKeyringService)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
configWarnings = append(configWarnings, warnings...)
Expand Down Expand Up @@ -298,6 +297,16 @@ func parseOAuthError(statusCode int, body []byte, action string) error {
return fmt.Errorf("%s failed with status %d: %s", action, statusCode, string(body))
}

// isRefreshTokenError checks whether the response body indicates an expired
// or invalid refresh token (invalid_grant / invalid_token).
func isRefreshTokenError(body []byte) bool {
var errResp ErrorResponse
if err := json.Unmarshal(body, &errResp); err == nil {
return errResp.Error == "invalid_grant" || errResp.Error == "invalid_token"
}
return false
}

// validateTokenResponse performs basic sanity checks on a token response.
func validateTokenResponse(accessToken, tokenType string, expiresIn int) error {
if accessToken == "" {
Expand Down Expand Up @@ -438,12 +447,8 @@ func refreshAccessToken(ctx context.Context, refreshToken string) (*tui.TokenSto

if resp.StatusCode != http.StatusOK {
// Check for expired/invalid refresh token before general error handling.
var errResp ErrorResponse
if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.Error != "" {
if errResp.Error == "invalid_grant" || errResp.Error == "invalid_token" {
return nil, tui.ErrRefreshTokenExpired
}
return nil, fmt.Errorf("%s: %s", errResp.Error, errResp.ErrorDescription)
if isRefreshTokenError(body) {
return nil, tui.ErrRefreshTokenExpired
}
return nil, parseOAuthError(resp.StatusCode, body, "refresh")
}
Expand Down Expand Up @@ -522,10 +527,9 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage)
if err != nil {
return fmt.Errorf("API request failed: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode == http.StatusUnauthorized {
// Drain and close body immediately so the HTTP transport can reuse the connection.
// Drain and close body so the HTTP transport can reuse the connection.
_, _ = io.Copy(io.Discard, resp.Body)
resp.Body.Close()

Expand Down Expand Up @@ -554,8 +558,8 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage)
if err != nil {
return fmt.Errorf("retry failed: %w", err)
}
defer resp.Body.Close()
}
defer resp.Body.Close()

body, err := readResponseBody(resp.Body)
if err != nil {
Expand Down
19 changes: 3 additions & 16 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"path/filepath"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -179,7 +180,7 @@ func TestBuildAuthURL_ContainsRequiredParams(t *testing.T) {
"code_challenge=test-challenge",
"code_challenge_method=S256",
} {
if !containsSubstring(u, want) {
if !strings.Contains(u, want) {
t.Errorf("auth URL missing %q\nURL: %s", want, u)
}
}
Expand Down Expand Up @@ -273,7 +274,7 @@ func TestInitTokenStore_Invalid(t *testing.T) {
if store != nil {
t.Errorf("expected nil store on error, got %T", store)
}
if !containsSubstring(err.Error(), "invalid token-store value") {
if !strings.Contains(err.Error(), "invalid token-store value") {
t.Errorf("unexpected error message: %v", err)
}
}
Expand Down Expand Up @@ -309,17 +310,3 @@ func TestReadResponseBody(t *testing.T) {
}
})
}

// containsSubstring is a helper to avoid importing strings in tests.
func containsSubstring(s, sub string) bool {
return len(s) >= len(sub) && findSubstring(s, sub)
}

func findSubstring(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}
Loading