Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 62 additions & 5 deletions internal/authutil/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"strings"
"sync"

"go.datum.net/datumctl/internal/keyring"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -86,10 +87,59 @@ func GetStoredCredentials(userKey string) (*StoredCredentials, error) {
return &creds, nil
}

// persistingTokenSource wraps an oauth2.TokenSource and persists token updates to the keyring.
type persistingTokenSource struct {
ctx context.Context
source oauth2.TokenSource
userKey string
creds *StoredCredentials
mu sync.Mutex
}

// Token implements oauth2.TokenSource.
// It retrieves a token from the underlying source and persists it to the keyring if refreshed.
func (p *persistingTokenSource) Token() (*oauth2.Token, error) {
p.mu.Lock()
defer p.mu.Unlock()

currentAccessToken := ""
if p.creds.Token != nil {
currentAccessToken = p.creds.Token.AccessToken
}

// Get token from the underlying source (may trigger refresh)
newToken, err := p.source.Token()
if err != nil {
var retrieveErr *oauth2.RetrieveError
if errors.As(err, &retrieveErr) {
if retrieveErr.ErrorCode == "invalid_grant" || retrieveErr.ErrorCode == "invalid_request" {
return nil, fmt.Errorf("Authentication session has expired or refresh token is no longer valid. Please re-authenticate using: `datumctl auth login`")
}
}
return nil, err
}

// Persist the token if it was refreshed
if newToken.AccessToken != currentAccessToken {
p.creds.Token = newToken

credsJSON, marshalErr := json.Marshal(p.creds)
if marshalErr != nil {
return newToken, fmt.Errorf("failed to marshal updated credentials: %w", marshalErr)
}

if setErr := keyring.Set(ServiceName, p.userKey, string(credsJSON)); setErr != nil {
return newToken, fmt.Errorf("failed to persist refreshed token to keyring: %w", setErr)
}
}

return newToken, nil
}

// GetTokenSource creates an oauth2.TokenSource for the active user.
// This source will automatically refresh the token if it's expired.
// This source will automatically refresh the token if it's expired and persist updates to the keyring.
func GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
creds, _, err := GetActiveCredentials()
creds, userKey, err := GetActiveCredentials()
if err != nil {
return nil, err
}
Expand All @@ -105,9 +155,16 @@ func GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
// RedirectURL not needed for token refresh
}

// Create a TokenSource with the stored token
// The oauth2 library handles refresh using the context, config, and refresh token.
return conf.TokenSource(ctx, creds.Token), nil
// Create the base TokenSource with the stored token
baseSource := conf.TokenSource(ctx, creds.Token)

// Wrap it with our persisting source
return &persistingTokenSource{
ctx: ctx,
source: baseSource,
userKey: userKey,
creds: creds,
}, nil
}

// GetUserIDFromToken extracts the user ID (sub claim) from the stored credentials.
Expand Down
58 changes: 6 additions & 52 deletions internal/cmd/auth/get_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (

"github.com/spf13/cobra"
"go.datum.net/datumctl/internal/authutil"
"go.datum.net/datumctl/internal/keyring"
"golang.org/x/oauth2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
clientauthv1 "k8s.io/client-go/pkg/apis/clientauthentication/v1"
)
Expand Down Expand Up @@ -49,63 +47,19 @@ func runGetToken(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid --output format %q. Must be %s or %s", outputFormat, outputFormatToken, outputFormatK8sV1Creds)
}

// Get Active User Credential
activeUserKey, err := keyring.Get(authutil.ServiceName, authutil.ActiveUserKey)
// Get the token source (which handles refresh and persistence automatically)
tokenSource, err := authutil.GetTokenSource(ctx)
if err != nil {
if errors.Is(err, keyring.ErrNotFound) {
if errors.Is(err, authutil.ErrNoActiveUser) {
return errors.New("no active user found in keyring. Please login first using 'datumctl auth login'")
}
return fmt.Errorf("failed to get active user key from keyring: %w", err)
return fmt.Errorf("failed to get token source: %w", err)
}

credsJSON, err := keyring.Get(authutil.ServiceName, activeUserKey)
if err != nil {
return fmt.Errorf("failed to get credentials for active user '%s' from keyring", activeUserKey)
}

var foundCreds authutil.StoredCredentials
if err := json.Unmarshal([]byte(credsJSON), &foundCreds); err != nil {
return fmt.Errorf("failed to parse stored credential JSON for active user '%s'", activeUserKey)
}
foundUserKey := activeUserKey

// Check if Token pointer is nil
if foundCreds.Token == nil {
return fmt.Errorf("internal error: stored token for active user '%s' is nil", foundUserKey)
}

// Create oauth2.Config
conf := &oauth2.Config{
ClientID: foundCreds.ClientID,
Scopes: foundCreds.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: foundCreds.EndpointAuthURL,
TokenURL: foundCreds.EndpointTokenURL,
},
}

// Create TokenSource
currentToken := *foundCreds.Token
tokenSource := conf.TokenSource(ctx, &currentToken)

// Get fresh token
// Get fresh token (will refresh if needed and persist automatically)
newToken, err := tokenSource.Token()
if err != nil {
return fmt.Errorf("failed to refresh token for active user '%s': %w", foundUserKey, err)
}

// Update keyring if refreshed
if newToken.AccessToken != currentToken.AccessToken {
updatedCreds := foundCreds
updatedCreds.Token = newToken
credsJSONBytes, err := json.Marshal(updatedCreds)
if err == nil {
err = keyring.Set(authutil.ServiceName, foundUserKey, string(credsJSONBytes))
if err != nil {
// Print a warning instead of silently ignoring.
fmt.Fprintf(os.Stderr, "Warning: failed to update refreshed token in keyring for user '%s': %v\n", foundUserKey, err)
}
} // If marshalling failed, we can't save anyway, maybe log this too? (Optional)
return fmt.Errorf("failed to get token: %w", err)
}

// --- Output based on requested format ---
Expand Down