Skip to content

Commit 0926d14

Browse files
authored
Token provider support for Go driver (1/3) (#290)
Implements token provider support for the go driver
2 parents 00d41fe + 08d163d commit 0926d14

File tree

9 files changed

+587
-0
lines changed

9 files changed

+587
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
8+
"github.com/databricks/databricks-sql-go/auth"
9+
"github.com/rs/zerolog/log"
10+
)
11+
12+
// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider.
13+
//
14+
// Authentication Flow:
15+
// 1. On each Authenticate() call, retrieves a token from the configured TokenProvider
16+
// 2. The provider may implement its own caching and refresh logic
17+
// 3. Validates the returned token is non-empty
18+
// 4. Sets the Authorization header with the token type and value
19+
//
20+
// The authenticator delegates all token management (caching, refresh, expiry)
21+
// to the underlying TokenProvider implementation.
22+
type TokenProviderAuthenticator struct {
23+
provider TokenProvider
24+
}
25+
26+
// NewAuthenticator creates an authenticator from a token provider
27+
func NewAuthenticator(provider TokenProvider) auth.Authenticator {
28+
return &TokenProviderAuthenticator{
29+
provider: provider,
30+
}
31+
}
32+
33+
// Authenticate implements auth.Authenticator
34+
func (a *TokenProviderAuthenticator) Authenticate(r *http.Request) error {
35+
ctx := r.Context()
36+
if ctx == nil {
37+
ctx = context.Background()
38+
}
39+
40+
token, err := a.provider.GetToken(ctx)
41+
if err != nil {
42+
return fmt.Errorf("token provider authenticator: failed to get token: %w", err)
43+
}
44+
45+
if token.AccessToken == "" {
46+
return fmt.Errorf("token provider authenticator: empty access token")
47+
}
48+
49+
token.SetAuthHeader(r)
50+
log.Debug().Msgf("token provider authenticator: authenticated using provider %s", a.provider.Name())
51+
52+
return nil
53+
}
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net/http"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestTokenProviderAuthenticator(t *testing.T) {
14+
t.Run("successful_authentication", func(t *testing.T) {
15+
provider := NewStaticTokenProvider("test-token-123")
16+
authenticator := NewAuthenticator(provider)
17+
18+
req, _ := http.NewRequest("GET", "http://example.com", nil)
19+
err := authenticator.Authenticate(req)
20+
21+
require.NoError(t, err)
22+
assert.Equal(t, "Bearer test-token-123", req.Header.Get("Authorization"))
23+
})
24+
25+
t.Run("authentication_with_custom_token_type", func(t *testing.T) {
26+
provider := NewStaticTokenProviderWithType("test-token", "MAC")
27+
authenticator := NewAuthenticator(provider)
28+
29+
req, _ := http.NewRequest("GET", "http://example.com", nil)
30+
err := authenticator.Authenticate(req)
31+
32+
require.NoError(t, err)
33+
assert.Equal(t, "MAC test-token", req.Header.Get("Authorization"))
34+
})
35+
36+
t.Run("authentication_error_propagation", func(t *testing.T) {
37+
provider := &mockProvider{
38+
tokenFunc: func() (*Token, error) {
39+
return nil, errors.New("provider failed")
40+
},
41+
}
42+
authenticator := NewAuthenticator(provider)
43+
44+
req, _ := http.NewRequest("GET", "http://example.com", nil)
45+
err := authenticator.Authenticate(req)
46+
47+
assert.Error(t, err)
48+
assert.Contains(t, err.Error(), "provider failed")
49+
assert.Empty(t, req.Header.Get("Authorization"))
50+
})
51+
52+
t.Run("empty_token_error", func(t *testing.T) {
53+
provider := &mockProvider{
54+
tokenFunc: func() (*Token, error) {
55+
return &Token{
56+
AccessToken: "",
57+
TokenType: "Bearer",
58+
}, nil
59+
},
60+
}
61+
authenticator := NewAuthenticator(provider)
62+
63+
req, _ := http.NewRequest("GET", "http://example.com", nil)
64+
err := authenticator.Authenticate(req)
65+
66+
assert.Error(t, err)
67+
assert.Contains(t, err.Error(), "empty access token")
68+
assert.Empty(t, req.Header.Get("Authorization"))
69+
})
70+
71+
t.Run("uses_request_context", func(t *testing.T) {
72+
ctx, cancel := context.WithCancel(context.Background())
73+
cancel() // Cancel immediately
74+
75+
provider := &mockProvider{
76+
tokenFunc: func() (*Token, error) {
77+
// This would normally check context cancellation
78+
return &Token{
79+
AccessToken: "test-token",
80+
TokenType: "Bearer",
81+
}, nil
82+
},
83+
}
84+
authenticator := NewAuthenticator(provider)
85+
86+
req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil)
87+
err := authenticator.Authenticate(req)
88+
89+
// Even with cancelled context, this should work as our mock doesn't check it
90+
require.NoError(t, err)
91+
assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization"))
92+
})
93+
94+
t.Run("external_token_integration", func(t *testing.T) {
95+
tokenFunc := func() (string, error) {
96+
return "external-token-456", nil
97+
}
98+
provider := NewExternalTokenProvider(tokenFunc)
99+
authenticator := NewAuthenticator(provider)
100+
101+
req, _ := http.NewRequest("POST", "http://example.com/api", nil)
102+
err := authenticator.Authenticate(req)
103+
104+
require.NoError(t, err)
105+
assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization"))
106+
})
107+
}

auth/tokenprovider/external.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"time"
7+
)
8+
9+
// ExternalTokenProvider provides tokens from an external source (passthrough).
10+
// This provider calls a user-supplied function to retrieve tokens on-demand.
11+
type ExternalTokenProvider struct {
12+
tokenSource func() (string, error)
13+
tokenType string
14+
}
15+
16+
// NewExternalTokenProvider creates a provider that gets tokens from an external function
17+
func NewExternalTokenProvider(tokenSource func() (string, error)) *ExternalTokenProvider {
18+
return &ExternalTokenProvider{
19+
tokenSource: tokenSource,
20+
tokenType: "Bearer",
21+
}
22+
}
23+
24+
// NewExternalTokenProviderWithType creates a provider with a custom token type
25+
func NewExternalTokenProviderWithType(tokenSource func() (string, error), tokenType string) *ExternalTokenProvider {
26+
return &ExternalTokenProvider{
27+
tokenSource: tokenSource,
28+
tokenType: tokenType,
29+
}
30+
}
31+
32+
// GetToken retrieves the token from the external source
33+
func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) {
34+
// Check for cancellation first
35+
if err := ctx.Err(); err != nil {
36+
return nil, fmt.Errorf("external token provider: context cancelled: %w", err)
37+
}
38+
39+
if p.tokenSource == nil {
40+
return nil, fmt.Errorf("external token provider: token source is nil")
41+
}
42+
43+
accessToken, err := p.tokenSource()
44+
if err != nil {
45+
return nil, fmt.Errorf("external token provider: failed to get token: %w", err)
46+
}
47+
48+
return &Token{
49+
AccessToken: accessToken,
50+
TokenType: p.tokenType,
51+
ExpiresAt: time.Time{}, // External tokens don't provide expiry info
52+
}, nil
53+
}
54+
55+
// Name returns the provider name
56+
func (p *ExternalTokenProvider) Name() string {
57+
return "external"
58+
}

auth/tokenprovider/provider.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"time"
7+
)
8+
9+
// TokenProvider is the interface for providing tokens from various sources
10+
type TokenProvider interface {
11+
// GetToken retrieves a valid access token
12+
GetToken(ctx context.Context) (*Token, error)
13+
14+
// Name returns the provider name for logging/debugging
15+
Name() string
16+
}
17+
18+
// Token represents an access token with metadata
19+
type Token struct {
20+
AccessToken string
21+
TokenType string
22+
ExpiresAt time.Time
23+
RefreshToken string
24+
Scopes []string
25+
}
26+
27+
// IsExpired checks if the token has expired
28+
func (t *Token) IsExpired() bool {
29+
if t.ExpiresAt.IsZero() {
30+
return false // No expiry means token doesn't expire
31+
}
32+
// Consider token expired 30 seconds before actual expiry for safety
33+
// This matches the standard buffer used by other Databricks SDKs
34+
return time.Now().Add(30 * time.Second).After(t.ExpiresAt)
35+
}
36+
37+
// SetAuthHeader sets the Authorization header on an HTTP request
38+
func (t *Token) SetAuthHeader(r *http.Request) {
39+
tokenType := t.TokenType
40+
if tokenType == "" {
41+
tokenType = "Bearer"
42+
}
43+
r.Header.Set("Authorization", tokenType+" "+t.AccessToken)
44+
}

0 commit comments

Comments
 (0)