Skip to content

Commit aacf85b

Browse files
committed
fix: persist tokens after refresh
We were encountering an issue with users getting error messages after they've refreshed their access token. I was able to trace this down to the new access token and refresh token not being persisted in the keyring after a refresh occurred. This was resulting in an invalid refresh token error because the previous refresh token that existed in the keychain was revoked after a new token was issued. I also adjusted the error messaging so the user will be given a friendly error message if the refresh token is not valid and they need to re-authenticate.
1 parent 4a87f37 commit aacf85b

File tree

2 files changed

+76
-57
lines changed

2 files changed

+76
-57
lines changed

internal/authutil/credentials.go

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"errors"
99
"fmt"
1010
"strings"
11+
"sync"
1112

1213
"go.datum.net/datumctl/internal/keyring"
1314
"golang.org/x/oauth2"
@@ -86,10 +87,67 @@ func GetStoredCredentials(userKey string) (*StoredCredentials, error) {
8687
return &creds, nil
8788
}
8889

90+
// persistingTokenSource wraps an oauth2.TokenSource and persists token updates to the keyring.
91+
type persistingTokenSource struct {
92+
ctx context.Context
93+
source oauth2.TokenSource
94+
userKey string
95+
creds *StoredCredentials
96+
mu sync.Mutex
97+
}
98+
99+
// Token implements oauth2.TokenSource.
100+
// It retrieves a token from the underlying source and persists it to the keyring if refreshed.
101+
func (p *persistingTokenSource) Token() (*oauth2.Token, error) {
102+
p.mu.Lock()
103+
defer p.mu.Unlock()
104+
105+
// Get the current access token from credentials
106+
currentAccessToken := ""
107+
if p.creds.Token != nil {
108+
currentAccessToken = p.creds.Token.AccessToken
109+
}
110+
111+
// Get token from the underlying source (may trigger refresh)
112+
newToken, err := p.source.Token()
113+
if err != nil {
114+
// Check if this is an OAuth2 refresh error
115+
var retrieveErr *oauth2.RetrieveError
116+
if errors.As(err, &retrieveErr) {
117+
// Check for common refresh token errors
118+
if retrieveErr.ErrorCode == "invalid_grant" || retrieveErr.ErrorCode == "invalid_request" {
119+
return nil, fmt.Errorf("authentication session has expired or refresh token is no longer valid.\nPlease re-authenticate using: datumctl auth login")
120+
}
121+
}
122+
return nil, err
123+
}
124+
125+
// Check if the token was refreshed (access token changed)
126+
if newToken.AccessToken != currentAccessToken {
127+
// Update the stored credentials with the new token
128+
p.creds.Token = newToken
129+
130+
// Persist to keyring
131+
credsJSON, marshalErr := json.Marshal(p.creds)
132+
if marshalErr != nil {
133+
// Log error but don't fail the token retrieval
134+
// The token is still valid in memory for this command
135+
return newToken, fmt.Errorf("failed to marshal updated credentials: %w", marshalErr)
136+
}
137+
138+
if setErr := keyring.Set(ServiceName, p.userKey, string(credsJSON)); setErr != nil {
139+
// Log error but don't fail the token retrieval
140+
return newToken, fmt.Errorf("failed to persist refreshed token to keyring: %w", setErr)
141+
}
142+
}
143+
144+
return newToken, nil
145+
}
146+
89147
// GetTokenSource creates an oauth2.TokenSource for the active user.
90-
// This source will automatically refresh the token if it's expired.
148+
// This source will automatically refresh the token if it's expired and persist updates to the keyring.
91149
func GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
92-
creds, _, err := GetActiveCredentials()
150+
creds, userKey, err := GetActiveCredentials()
93151
if err != nil {
94152
return nil, err
95153
}
@@ -105,9 +163,16 @@ func GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
105163
// RedirectURL not needed for token refresh
106164
}
107165

108-
// Create a TokenSource with the stored token
109-
// The oauth2 library handles refresh using the context, config, and refresh token.
110-
return conf.TokenSource(ctx, creds.Token), nil
166+
// Create the base TokenSource with the stored token
167+
baseSource := conf.TokenSource(ctx, creds.Token)
168+
169+
// Wrap it with our persisting source
170+
return &persistingTokenSource{
171+
ctx: ctx,
172+
source: baseSource,
173+
userKey: userKey,
174+
creds: creds,
175+
}, nil
111176
}
112177

113178
// GetUserIDFromToken extracts the user ID (sub claim) from the stored credentials.

internal/cmd/auth/get_token.go

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import (
99

1010
"github.com/spf13/cobra"
1111
"go.datum.net/datumctl/internal/authutil"
12-
"go.datum.net/datumctl/internal/keyring"
13-
"golang.org/x/oauth2"
1412
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1513
clientauthv1 "k8s.io/client-go/pkg/apis/clientauthentication/v1"
1614
)
@@ -49,63 +47,19 @@ func runGetToken(cmd *cobra.Command, args []string) error {
4947
return fmt.Errorf("invalid --output format %q. Must be %s or %s", outputFormat, outputFormatToken, outputFormatK8sV1Creds)
5048
}
5149

52-
// Get Active User Credential
53-
activeUserKey, err := keyring.Get(authutil.ServiceName, authutil.ActiveUserKey)
50+
// Get the token source (which handles refresh and persistence automatically)
51+
tokenSource, err := authutil.GetTokenSource(ctx)
5452
if err != nil {
55-
if errors.Is(err, keyring.ErrNotFound) {
53+
if errors.Is(err, authutil.ErrNoActiveUser) {
5654
return errors.New("no active user found in keyring. Please login first using 'datumctl auth login'")
5755
}
58-
return fmt.Errorf("failed to get active user key from keyring: %w", err)
56+
return fmt.Errorf("failed to get token source: %w", err)
5957
}
6058

61-
credsJSON, err := keyring.Get(authutil.ServiceName, activeUserKey)
62-
if err != nil {
63-
return fmt.Errorf("failed to get credentials for active user '%s' from keyring", activeUserKey)
64-
}
65-
66-
var foundCreds authutil.StoredCredentials
67-
if err := json.Unmarshal([]byte(credsJSON), &foundCreds); err != nil {
68-
return fmt.Errorf("failed to parse stored credential JSON for active user '%s'", activeUserKey)
69-
}
70-
foundUserKey := activeUserKey
71-
72-
// Check if Token pointer is nil
73-
if foundCreds.Token == nil {
74-
return fmt.Errorf("internal error: stored token for active user '%s' is nil", foundUserKey)
75-
}
76-
77-
// Create oauth2.Config
78-
conf := &oauth2.Config{
79-
ClientID: foundCreds.ClientID,
80-
Scopes: foundCreds.Scopes,
81-
Endpoint: oauth2.Endpoint{
82-
AuthURL: foundCreds.EndpointAuthURL,
83-
TokenURL: foundCreds.EndpointTokenURL,
84-
},
85-
}
86-
87-
// Create TokenSource
88-
currentToken := *foundCreds.Token
89-
tokenSource := conf.TokenSource(ctx, &currentToken)
90-
91-
// Get fresh token
59+
// Get fresh token (will refresh if needed and persist automatically)
9260
newToken, err := tokenSource.Token()
9361
if err != nil {
94-
return fmt.Errorf("failed to refresh token for active user '%s': %w", foundUserKey, err)
95-
}
96-
97-
// Update keyring if refreshed
98-
if newToken.AccessToken != currentToken.AccessToken {
99-
updatedCreds := foundCreds
100-
updatedCreds.Token = newToken
101-
credsJSONBytes, err := json.Marshal(updatedCreds)
102-
if err == nil {
103-
err = keyring.Set(authutil.ServiceName, foundUserKey, string(credsJSONBytes))
104-
if err != nil {
105-
// Print a warning instead of silently ignoring.
106-
fmt.Fprintf(os.Stderr, "Warning: failed to update refreshed token in keyring for user '%s': %v\n", foundUserKey, err)
107-
}
108-
} // If marshalling failed, we can't save anyway, maybe log this too? (Optional)
62+
return fmt.Errorf("failed to get token: %w", err)
10963
}
11064

11165
// --- Output based on requested format ---

0 commit comments

Comments
 (0)