diff --git a/internal/authutil/credentials.go b/internal/authutil/credentials.go index e0b3ee6..af77f54 100644 --- a/internal/authutil/credentials.go +++ b/internal/authutil/credentials.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "strings" + "sync" "go.datum.net/datumctl/internal/keyring" "golang.org/x/oauth2" @@ -86,10 +87,59 @@ func GetStoredCredentials(userKey string) (*StoredCredentials, error) { return &creds, nil } +// persistingTokenSource wraps an oauth2.TokenSource and persists token updates to the keyring. +type persistingTokenSource struct { + ctx context.Context + source oauth2.TokenSource + userKey string + creds *StoredCredentials + mu sync.Mutex +} + +// Token implements oauth2.TokenSource. +// It retrieves a token from the underlying source and persists it to the keyring if refreshed. +func (p *persistingTokenSource) Token() (*oauth2.Token, error) { + p.mu.Lock() + defer p.mu.Unlock() + + currentAccessToken := "" + if p.creds.Token != nil { + currentAccessToken = p.creds.Token.AccessToken + } + + // Get token from the underlying source (may trigger refresh) + newToken, err := p.source.Token() + if err != nil { + var retrieveErr *oauth2.RetrieveError + if errors.As(err, &retrieveErr) { + if retrieveErr.ErrorCode == "invalid_grant" || retrieveErr.ErrorCode == "invalid_request" { + return nil, fmt.Errorf("Authentication session has expired or refresh token is no longer valid. Please re-authenticate using: `datumctl auth login`") + } + } + return nil, err + } + + // Persist the token if it was refreshed + if newToken.AccessToken != currentAccessToken { + p.creds.Token = newToken + + credsJSON, marshalErr := json.Marshal(p.creds) + if marshalErr != nil { + return newToken, fmt.Errorf("failed to marshal updated credentials: %w", marshalErr) + } + + if setErr := keyring.Set(ServiceName, p.userKey, string(credsJSON)); setErr != nil { + return newToken, fmt.Errorf("failed to persist refreshed token to keyring: %w", setErr) + } + } + + return newToken, nil +} + // GetTokenSource creates an oauth2.TokenSource for the active user. -// This source will automatically refresh the token if it's expired. +// This source will automatically refresh the token if it's expired and persist updates to the keyring. func GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) { - creds, _, err := GetActiveCredentials() + creds, userKey, err := GetActiveCredentials() if err != nil { return nil, err } @@ -105,9 +155,16 @@ func GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) { // RedirectURL not needed for token refresh } - // Create a TokenSource with the stored token - // The oauth2 library handles refresh using the context, config, and refresh token. - return conf.TokenSource(ctx, creds.Token), nil + // Create the base TokenSource with the stored token + baseSource := conf.TokenSource(ctx, creds.Token) + + // Wrap it with our persisting source + return &persistingTokenSource{ + ctx: ctx, + source: baseSource, + userKey: userKey, + creds: creds, + }, nil } // GetUserIDFromToken extracts the user ID (sub claim) from the stored credentials. diff --git a/internal/cmd/auth/get_token.go b/internal/cmd/auth/get_token.go index 3ddb8d2..e2d3c21 100644 --- a/internal/cmd/auth/get_token.go +++ b/internal/cmd/auth/get_token.go @@ -9,8 +9,6 @@ import ( "github.com/spf13/cobra" "go.datum.net/datumctl/internal/authutil" - "go.datum.net/datumctl/internal/keyring" - "golang.org/x/oauth2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" clientauthv1 "k8s.io/client-go/pkg/apis/clientauthentication/v1" ) @@ -49,63 +47,19 @@ func runGetToken(cmd *cobra.Command, args []string) error { return fmt.Errorf("invalid --output format %q. Must be %s or %s", outputFormat, outputFormatToken, outputFormatK8sV1Creds) } - // Get Active User Credential - activeUserKey, err := keyring.Get(authutil.ServiceName, authutil.ActiveUserKey) + // Get the token source (which handles refresh and persistence automatically) + tokenSource, err := authutil.GetTokenSource(ctx) if err != nil { - if errors.Is(err, keyring.ErrNotFound) { + if errors.Is(err, authutil.ErrNoActiveUser) { return errors.New("no active user found in keyring. Please login first using 'datumctl auth login'") } - return fmt.Errorf("failed to get active user key from keyring: %w", err) + return fmt.Errorf("failed to get token source: %w", err) } - credsJSON, err := keyring.Get(authutil.ServiceName, activeUserKey) - if err != nil { - return fmt.Errorf("failed to get credentials for active user '%s' from keyring", activeUserKey) - } - - var foundCreds authutil.StoredCredentials - if err := json.Unmarshal([]byte(credsJSON), &foundCreds); err != nil { - return fmt.Errorf("failed to parse stored credential JSON for active user '%s'", activeUserKey) - } - foundUserKey := activeUserKey - - // Check if Token pointer is nil - if foundCreds.Token == nil { - return fmt.Errorf("internal error: stored token for active user '%s' is nil", foundUserKey) - } - - // Create oauth2.Config - conf := &oauth2.Config{ - ClientID: foundCreds.ClientID, - Scopes: foundCreds.Scopes, - Endpoint: oauth2.Endpoint{ - AuthURL: foundCreds.EndpointAuthURL, - TokenURL: foundCreds.EndpointTokenURL, - }, - } - - // Create TokenSource - currentToken := *foundCreds.Token - tokenSource := conf.TokenSource(ctx, ¤tToken) - - // Get fresh token + // Get fresh token (will refresh if needed and persist automatically) newToken, err := tokenSource.Token() if err != nil { - return fmt.Errorf("failed to refresh token for active user '%s': %w", foundUserKey, err) - } - - // Update keyring if refreshed - if newToken.AccessToken != currentToken.AccessToken { - updatedCreds := foundCreds - updatedCreds.Token = newToken - credsJSONBytes, err := json.Marshal(updatedCreds) - if err == nil { - err = keyring.Set(authutil.ServiceName, foundUserKey, string(credsJSONBytes)) - if err != nil { - // Print a warning instead of silently ignoring. - fmt.Fprintf(os.Stderr, "Warning: failed to update refreshed token in keyring for user '%s': %v\n", foundUserKey, err) - } - } // If marshalling failed, we can't save anyway, maybe log this too? (Optional) + return fmt.Errorf("failed to get token: %w", err) } // --- Output based on requested format ---