Skip to content

Commit 5571483

Browse files
committed
Merge branch 'main' into token-provider-federation
2 parents 1fb8c1a + 0926d14 commit 5571483

File tree

12 files changed

+3963
-235
lines changed

12 files changed

+3963
-235
lines changed

DESIGN_MULTI_STATEMENT_TRANSACTIONS.md

Lines changed: 1564 additions & 0 deletions
Large diffs are not rendered by default.

auth/oauth/m2m/m2m.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (c *authClient) Authenticate(r *http.Request) error {
5757

5858
c.tokenSource = GetTokenSource(config)
5959
token, err := c.tokenSource.Token()
60-
log.Info().Msgf("token fetched successfully")
60+
log.Debug().Msgf("databricks OAuth token fetched successfully")
6161
if err != nil {
6262
log.Err(err).Msg("failed to get token")
6363

auth/tokenprovider/authenticator.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,16 @@ import (
99
"github.com/rs/zerolog/log"
1010
)
1111

12-
// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider
12+
// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider.
13+
//
14+
// Authentication Flow:
15+
// 1. On each Authenticate() call, retrieves a token from the configured TokenProvider
16+
// 2. The provider may implement its own caching and refresh logic
17+
// 3. Validates the returned token is non-empty
18+
// 4. Sets the Authorization header with the token type and value
19+
//
20+
// The authenticator delegates all token management (caching, refresh, expiry)
21+
// to the underlying TokenProvider implementation.
1322
type TokenProviderAuthenticator struct {
1423
provider TokenProvider
1524
}

auth/tokenprovider/authenticator_test.go

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -104,32 +104,4 @@ func TestTokenProviderAuthenticator(t *testing.T) {
104104
require.NoError(t, err)
105105
assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization"))
106106
})
107-
108-
t.Run("cached_provider_integration", func(t *testing.T) {
109-
callCount := 0
110-
baseProvider := &mockProvider{
111-
tokenFunc: func() (*Token, error) {
112-
callCount++
113-
return &Token{
114-
AccessToken: "cached-token",
115-
TokenType: "Bearer",
116-
}, nil
117-
},
118-
name: "test",
119-
}
120-
121-
cachedProvider := NewCachedTokenProvider(baseProvider)
122-
authenticator := NewAuthenticator(cachedProvider)
123-
124-
// Multiple authentication attempts
125-
for i := 0; i < 3; i++ {
126-
req, _ := http.NewRequest("GET", "http://example.com", nil)
127-
err := authenticator.Authenticate(req)
128-
require.NoError(t, err)
129-
assert.Equal(t, "Bearer cached-token", req.Header.Get("Authorization"))
130-
}
131-
132-
// Should only call base provider once due to caching
133-
assert.Equal(t, 1, callCount)
134-
})
135107
}

auth/tokenprovider/external.go

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,43 @@ import (
66
"time"
77
)
88

9-
// ExternalTokenProvider provides tokens from an external source (passthrough)
9+
// ExternalTokenProvider provides tokens from an external source (passthrough).
10+
// This provider calls a user-supplied function to retrieve tokens on-demand.
1011
type ExternalTokenProvider struct {
11-
tokenFunc func() (string, error)
12-
tokenType string
12+
tokenSource func() (string, error)
13+
tokenType string
1314
}
1415

1516
// NewExternalTokenProvider creates a provider that gets tokens from an external function
16-
func NewExternalTokenProvider(tokenFunc func() (string, error)) *ExternalTokenProvider {
17+
func NewExternalTokenProvider(tokenSource func() (string, error)) *ExternalTokenProvider {
1718
return &ExternalTokenProvider{
18-
tokenFunc: tokenFunc,
19-
tokenType: "Bearer",
19+
tokenSource: tokenSource,
20+
tokenType: "Bearer",
2021
}
2122
}
2223

2324
// NewExternalTokenProviderWithType creates a provider with a custom token type
24-
func NewExternalTokenProviderWithType(tokenFunc func() (string, error), tokenType string) *ExternalTokenProvider {
25+
func NewExternalTokenProviderWithType(tokenSource func() (string, error), tokenType string) *ExternalTokenProvider {
2526
return &ExternalTokenProvider{
26-
tokenFunc: tokenFunc,
27-
tokenType: tokenType,
27+
tokenSource: tokenSource,
28+
tokenType: tokenType,
2829
}
2930
}
3031

3132
// GetToken retrieves the token from the external source
3233
func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) {
33-
if p.tokenFunc == nil {
34-
return nil, fmt.Errorf("external token provider: token function is nil")
34+
// Check for cancellation first
35+
if err := ctx.Err(); err != nil {
36+
return nil, fmt.Errorf("external token provider: context cancelled: %w", err)
3537
}
3638

37-
accessToken, err := p.tokenFunc()
38-
if err != nil {
39-
return nil, fmt.Errorf("external token provider: failed to get token: %w", err)
39+
if p.tokenSource == nil {
40+
return nil, fmt.Errorf("external token provider: token source is nil")
4041
}
4142

42-
if accessToken == "" {
43-
return nil, fmt.Errorf("external token provider: empty token returned")
43+
accessToken, err := p.tokenSource()
44+
if err != nil {
45+
return nil, fmt.Errorf("external token provider: failed to get token: %w", err)
4446
}
4547

4648
return &Token{

auth/tokenprovider/provider.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ func (t *Token) IsExpired() bool {
2929
if t.ExpiresAt.IsZero() {
3030
return false // No expiry means token doesn't expire
3131
}
32-
// Consider token expired 5 minutes before actual expiry for safety
33-
return time.Now().Add(5 * time.Minute).After(t.ExpiresAt)
32+
// Consider token expired 30 seconds before actual expiry for safety
33+
// This matches the standard buffer used by other Databricks SDKs
34+
return time.Now().Add(30 * time.Second).After(t.ExpiresAt)
3435
}
3536

3637
// SetAuthHeader sets the Authorization header on an HTTP request

auth/tokenprovider/provider_test.go

Lines changed: 9 additions & 187 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import (
44
"context"
55
"errors"
66
"net/http"
7-
"sync"
8-
"sync/atomic"
97
"testing"
108
"time"
119

@@ -44,12 +42,12 @@ func TestToken_IsExpired(t *testing.T) {
4442
expected: false,
4543
},
4644
{
47-
name: "token_expires_within_5_minutes",
45+
name: "token_expires_within_30_seconds",
4846
token: &Token{
4947
AccessToken: "test-token",
50-
ExpiresAt: time.Now().Add(3 * time.Minute),
48+
ExpiresAt: time.Now().Add(15 * time.Second),
5149
},
52-
expected: true, // Should be considered expired due to 5-minute buffer
50+
expected: true, // Should be considered expired due to 30-second buffer
5351
},
5452
}
5553

@@ -173,17 +171,18 @@ func TestExternalTokenProvider(t *testing.T) {
173171
assert.Contains(t, err.Error(), "failed to get token")
174172
})
175173

176-
t.Run("empty_token_error", func(t *testing.T) {
174+
t.Run("empty_token_allowed", func(t *testing.T) {
177175
tokenFunc := func() (string, error) {
178176
return "", nil
179177
}
180178

181179
provider := NewExternalTokenProvider(tokenFunc)
182180
token, err := provider.GetToken(context.Background())
183181

184-
assert.Error(t, err)
185-
assert.Nil(t, token)
186-
assert.Contains(t, err.Error(), "empty token returned")
182+
assert.NoError(t, err)
183+
assert.NotNil(t, token)
184+
assert.Empty(t, token.AccessToken)
185+
// Empty tokens are validated at the authenticator level, not provider level
187186
})
188187

189188
t.Run("nil_function_error", func(t *testing.T) {
@@ -192,7 +191,7 @@ func TestExternalTokenProvider(t *testing.T) {
192191

193192
assert.Error(t, err)
194193
assert.Nil(t, token)
195-
assert.Contains(t, err.Error(), "token function is nil")
194+
assert.Contains(t, err.Error(), "token source is nil")
196195
})
197196

198197
t.Run("custom_token_type", func(t *testing.T) {
@@ -228,183 +227,6 @@ func TestExternalTokenProvider(t *testing.T) {
228227
})
229228
}
230229

231-
func TestCachedTokenProvider(t *testing.T) {
232-
t.Run("caches_valid_token", func(t *testing.T) {
233-
callCount := 0
234-
baseProvider := &mockProvider{
235-
tokenFunc: func() (*Token, error) {
236-
callCount++
237-
return &Token{
238-
AccessToken: "cached-token",
239-
TokenType: "Bearer",
240-
ExpiresAt: time.Now().Add(1 * time.Hour),
241-
}, nil
242-
},
243-
name: "mock",
244-
}
245-
246-
cachedProvider := NewCachedTokenProvider(baseProvider)
247-
248-
// First call - should fetch from base provider
249-
token1, err1 := cachedProvider.GetToken(context.Background())
250-
require.NoError(t, err1)
251-
assert.Equal(t, "cached-token", token1.AccessToken)
252-
assert.Equal(t, 1, callCount)
253-
254-
// Second call - should use cache
255-
token2, err2 := cachedProvider.GetToken(context.Background())
256-
require.NoError(t, err2)
257-
assert.Equal(t, "cached-token", token2.AccessToken)
258-
assert.Equal(t, 1, callCount) // Should still be 1
259-
})
260-
261-
t.Run("refreshes_expired_token", func(t *testing.T) {
262-
callCount := 0
263-
baseProvider := &mockProvider{
264-
tokenFunc: func() (*Token, error) {
265-
callCount++
266-
// Return token that expires soon
267-
return &Token{
268-
AccessToken: "token-" + string(rune(callCount)),
269-
TokenType: "Bearer",
270-
ExpiresAt: time.Now().Add(2 * time.Minute), // Within refresh threshold
271-
}, nil
272-
},
273-
name: "mock",
274-
}
275-
276-
cachedProvider := NewCachedTokenProvider(baseProvider)
277-
cachedProvider.RefreshThreshold = 5 * time.Minute
278-
279-
// First call
280-
token1, err1 := cachedProvider.GetToken(context.Background())
281-
require.NoError(t, err1)
282-
assert.Equal(t, "token-\x01", token1.AccessToken)
283-
assert.Equal(t, 1, callCount)
284-
285-
// Second call - should refresh because token expires within threshold
286-
token2, err2 := cachedProvider.GetToken(context.Background())
287-
require.NoError(t, err2)
288-
assert.Equal(t, "token-\x02", token2.AccessToken)
289-
assert.Equal(t, 2, callCount)
290-
})
291-
292-
t.Run("handles_provider_error", func(t *testing.T) {
293-
baseProvider := &mockProvider{
294-
tokenFunc: func() (*Token, error) {
295-
return nil, errors.New("provider error")
296-
},
297-
name: "mock",
298-
}
299-
300-
cachedProvider := NewCachedTokenProvider(baseProvider)
301-
token, err := cachedProvider.GetToken(context.Background())
302-
303-
assert.Error(t, err)
304-
assert.Nil(t, token)
305-
assert.Contains(t, err.Error(), "provider error")
306-
})
307-
308-
t.Run("no_expiry_token_not_refreshed", func(t *testing.T) {
309-
callCount := 0
310-
baseProvider := &mockProvider{
311-
tokenFunc: func() (*Token, error) {
312-
callCount++
313-
return &Token{
314-
AccessToken: "permanent-token",
315-
TokenType: "Bearer",
316-
ExpiresAt: time.Time{}, // No expiry
317-
}, nil
318-
},
319-
name: "mock",
320-
}
321-
322-
cachedProvider := NewCachedTokenProvider(baseProvider)
323-
324-
// Multiple calls should all use cache
325-
for i := 0; i < 5; i++ {
326-
token, err := cachedProvider.GetToken(context.Background())
327-
require.NoError(t, err)
328-
assert.Equal(t, "permanent-token", token.AccessToken)
329-
}
330-
331-
assert.Equal(t, 1, callCount) // Should only be called once
332-
})
333-
334-
t.Run("clear_cache", func(t *testing.T) {
335-
callCount := 0
336-
baseProvider := &mockProvider{
337-
tokenFunc: func() (*Token, error) {
338-
callCount++
339-
return &Token{
340-
AccessToken: "token-" + string(rune(callCount)),
341-
TokenType: "Bearer",
342-
ExpiresAt: time.Now().Add(1 * time.Hour),
343-
}, nil
344-
},
345-
name: "mock",
346-
}
347-
348-
cachedProvider := NewCachedTokenProvider(baseProvider)
349-
350-
// First call
351-
token1, _ := cachedProvider.GetToken(context.Background())
352-
assert.Equal(t, "token-\x01", token1.AccessToken)
353-
assert.Equal(t, 1, callCount)
354-
355-
// Clear cache
356-
cachedProvider.ClearCache()
357-
358-
// Next call should fetch new token
359-
token2, _ := cachedProvider.GetToken(context.Background())
360-
assert.Equal(t, "token-\x02", token2.AccessToken)
361-
assert.Equal(t, 2, callCount)
362-
})
363-
364-
t.Run("concurrent_access", func(t *testing.T) {
365-
var callCount atomic.Int32
366-
baseProvider := &mockProvider{
367-
tokenFunc: func() (*Token, error) {
368-
// Simulate slow token fetch
369-
time.Sleep(100 * time.Millisecond)
370-
callCount.Add(1)
371-
return &Token{
372-
AccessToken: "concurrent-token",
373-
TokenType: "Bearer",
374-
ExpiresAt: time.Now().Add(1 * time.Hour),
375-
}, nil
376-
},
377-
name: "mock",
378-
}
379-
380-
cachedProvider := NewCachedTokenProvider(baseProvider)
381-
382-
// Launch multiple goroutines
383-
var wg sync.WaitGroup
384-
for i := 0; i < 10; i++ {
385-
wg.Add(1)
386-
go func() {
387-
defer wg.Done()
388-
token, err := cachedProvider.GetToken(context.Background())
389-
assert.NoError(t, err)
390-
assert.Equal(t, "concurrent-token", token.AccessToken)
391-
}()
392-
}
393-
394-
wg.Wait()
395-
396-
// Should only fetch token once despite concurrent access
397-
assert.Equal(t, int32(1), callCount.Load())
398-
})
399-
400-
t.Run("provider_name", func(t *testing.T) {
401-
baseProvider := &mockProvider{name: "test-provider"}
402-
cachedProvider := NewCachedTokenProvider(baseProvider)
403-
404-
assert.Equal(t, "cached[test-provider]", cachedProvider.Name())
405-
})
406-
}
407-
408230
// Mock provider for testing
409231
type mockProvider struct {
410232
tokenFunc func() (*Token, error)

0 commit comments

Comments
 (0)