Skip to content

Commit 08d163d

Browse files
madhav-dbclaude
andcommitted
Address PR review comments
- Reduce token expiry buffer from 5 minutes to 30 seconds (matches SDK standard) - Add detailed documentation to TokenProviderAuthenticator explaining flow - Add ctx.Err() check in ExternalTokenProvider for cancellation support - Rename tokenFunc to tokenSource for better clarity - Remove duplicate empty token validation from ExternalTokenProvider - Update tests to reflect changes 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent c7bd4b2 commit 08d163d

File tree

4 files changed

+40
-27
lines changed

4 files changed

+40
-27
lines changed

auth/tokenprovider/authenticator.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,16 @@ import (
99
"github.com/rs/zerolog/log"
1010
)
1111

12-
// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider
12+
// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider.
13+
//
14+
// Authentication Flow:
15+
// 1. On each Authenticate() call, retrieves a token from the configured TokenProvider
16+
// 2. The provider may implement its own caching and refresh logic
17+
// 3. Validates the returned token is non-empty
18+
// 4. Sets the Authorization header with the token type and value
19+
//
20+
// The authenticator delegates all token management (caching, refresh, expiry)
21+
// to the underlying TokenProvider implementation.
1322
type TokenProviderAuthenticator struct {
1423
provider TokenProvider
1524
}

auth/tokenprovider/external.go

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,43 @@ import (
66
"time"
77
)
88

9-
// ExternalTokenProvider provides tokens from an external source (passthrough)
9+
// ExternalTokenProvider provides tokens from an external source (passthrough).
10+
// This provider calls a user-supplied function to retrieve tokens on-demand.
1011
type ExternalTokenProvider struct {
11-
tokenFunc func() (string, error)
12-
tokenType string
12+
tokenSource func() (string, error)
13+
tokenType string
1314
}
1415

1516
// NewExternalTokenProvider creates a provider that gets tokens from an external function
16-
func NewExternalTokenProvider(tokenFunc func() (string, error)) *ExternalTokenProvider {
17+
func NewExternalTokenProvider(tokenSource func() (string, error)) *ExternalTokenProvider {
1718
return &ExternalTokenProvider{
18-
tokenFunc: tokenFunc,
19-
tokenType: "Bearer",
19+
tokenSource: tokenSource,
20+
tokenType: "Bearer",
2021
}
2122
}
2223

2324
// NewExternalTokenProviderWithType creates a provider with a custom token type
24-
func NewExternalTokenProviderWithType(tokenFunc func() (string, error), tokenType string) *ExternalTokenProvider {
25+
func NewExternalTokenProviderWithType(tokenSource func() (string, error), tokenType string) *ExternalTokenProvider {
2526
return &ExternalTokenProvider{
26-
tokenFunc: tokenFunc,
27-
tokenType: tokenType,
27+
tokenSource: tokenSource,
28+
tokenType: tokenType,
2829
}
2930
}
3031

3132
// GetToken retrieves the token from the external source
3233
func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) {
33-
if p.tokenFunc == nil {
34-
return nil, fmt.Errorf("external token provider: token function is nil")
34+
// Check for cancellation first
35+
if err := ctx.Err(); err != nil {
36+
return nil, fmt.Errorf("external token provider: context cancelled: %w", err)
3537
}
3638

37-
accessToken, err := p.tokenFunc()
38-
if err != nil {
39-
return nil, fmt.Errorf("external token provider: failed to get token: %w", err)
39+
if p.tokenSource == nil {
40+
return nil, fmt.Errorf("external token provider: token source is nil")
4041
}
4142

42-
if accessToken == "" {
43-
return nil, fmt.Errorf("external token provider: empty token returned")
43+
accessToken, err := p.tokenSource()
44+
if err != nil {
45+
return nil, fmt.Errorf("external token provider: failed to get token: %w", err)
4446
}
4547

4648
return &Token{

auth/tokenprovider/provider.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ func (t *Token) IsExpired() bool {
2929
if t.ExpiresAt.IsZero() {
3030
return false // No expiry means token doesn't expire
3131
}
32-
// Consider token expired 5 minutes before actual expiry for safety
33-
return time.Now().Add(5 * time.Minute).After(t.ExpiresAt)
32+
// Consider token expired 30 seconds before actual expiry for safety
33+
// This matches the standard buffer used by other Databricks SDKs
34+
return time.Now().Add(30 * time.Second).After(t.ExpiresAt)
3435
}
3536

3637
// SetAuthHeader sets the Authorization header on an HTTP request

auth/tokenprovider/provider_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ func TestToken_IsExpired(t *testing.T) {
4242
expected: false,
4343
},
4444
{
45-
name: "token_expires_within_5_minutes",
45+
name: "token_expires_within_30_seconds",
4646
token: &Token{
4747
AccessToken: "test-token",
48-
ExpiresAt: time.Now().Add(3 * time.Minute),
48+
ExpiresAt: time.Now().Add(15 * time.Second),
4949
},
50-
expected: true, // Should be considered expired due to 5-minute buffer
50+
expected: true, // Should be considered expired due to 30-second buffer
5151
},
5252
}
5353

@@ -171,17 +171,18 @@ func TestExternalTokenProvider(t *testing.T) {
171171
assert.Contains(t, err.Error(), "failed to get token")
172172
})
173173

174-
t.Run("empty_token_error", func(t *testing.T) {
174+
t.Run("empty_token_allowed", func(t *testing.T) {
175175
tokenFunc := func() (string, error) {
176176
return "", nil
177177
}
178178

179179
provider := NewExternalTokenProvider(tokenFunc)
180180
token, err := provider.GetToken(context.Background())
181181

182-
assert.Error(t, err)
183-
assert.Nil(t, token)
184-
assert.Contains(t, err.Error(), "empty token returned")
182+
assert.NoError(t, err)
183+
assert.NotNil(t, token)
184+
assert.Empty(t, token.AccessToken)
185+
// Empty tokens are validated at the authenticator level, not provider level
185186
})
186187

187188
t.Run("nil_function_error", func(t *testing.T) {
@@ -190,7 +191,7 @@ func TestExternalTokenProvider(t *testing.T) {
190191

191192
assert.Error(t, err)
192193
assert.Nil(t, token)
193-
assert.Contains(t, err.Error(), "token function is nil")
194+
assert.Contains(t, err.Error(), "token source is nil")
194195
})
195196

196197
t.Run("custom_token_type", func(t *testing.T) {

0 commit comments

Comments
 (0)