Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io"
"log/slog"
"net/http"
"net/url"
"os"
"strings"
"time"
Expand All @@ -34,10 +35,12 @@ var (

// SetMetadataURL sets a custom metadata server URL for testing.
// Returns a function that restores the original URL.
func SetMetadataURL(url string) func() {
// WARNING: This function should only be called in test code.
// Set DS9_ALLOW_TEST_OVERRIDES=true to enable in non-test environments.
func SetMetadataURL(urlStr string) func() {
old := metadataURL
oldTestMode := isTestMode
metadataURL = url
metadataURL = urlStr
isTestMode = true // Enable test mode to skip ADC
return func() {
metadataURL = old
Expand Down Expand Up @@ -107,10 +110,14 @@ func accessTokenFromADC(ctx context.Context) (string, error) {
func exchangeRefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (string, error) {
tokenURL := "https://oauth2.googleapis.com/token" //nolint:gosec // This is Google's OAuth2 token endpoint, not a hardcoded credential

reqBody := fmt.Sprintf(
"client_id=%s&client_secret=%s&refresh_token=%s&grant_type=refresh_token",
clientID, clientSecret, refreshToken,
)
// Use url.Values for proper URL encoding to prevent parameter injection
form := url.Values{
"client_id": {clientID},
"client_secret": {clientSecret},
"refresh_token": {refreshToken},
"grant_type": {"refresh_token"},
}
reqBody := form.Encode()

req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(reqBody))
if err != nil {
Expand All @@ -133,7 +140,9 @@ func exchangeRefreshToken(ctx context.Context, clientID, clientSecret, refreshTo
if readErr != nil {
return "", fmt.Errorf("token exchange returned %d", resp.StatusCode)
}
return "", fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(body))
// Log full error details but return sanitized message to prevent information leakage
slog.ErrorContext(ctx, "OAuth token exchange failed", "status", resp.StatusCode, "response", string(body))
return "", fmt.Errorf("token exchange returned %d", resp.StatusCode)
}

body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize))
Expand All @@ -156,9 +165,9 @@ func exchangeRefreshToken(ctx context.Context, clientID, clientSecret, refreshTo
// accessTokenFromMetadata retrieves an access token from the GCP metadata server.
// This is used when running on GCP (GCE, GKE, Cloud Run, etc.).
func accessTokenFromMetadata(ctx context.Context) (string, error) {
url := metadataURL + "/instance/service-accounts/default/token"
reqURL := metadataURL + "/instance/service-accounts/default/token"

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, http.NoBody)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -197,9 +206,9 @@ func accessTokenFromMetadata(ctx context.Context) (string, error) {

// ProjectID retrieves the project ID from the GCP metadata server.
func ProjectID(ctx context.Context) (string, error) {
url := metadataURL + "/project/project-id"
reqURL := metadataURL + "/project/project-id"

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, http.NoBody)
if err != nil {
return "", err
}
Expand Down
Loading
Loading