Skip to content

Commit 5b0e054

Browse files
authored
Support reloading OAuth2 key file (#1441)
1 parent fa2b263 commit 5b0e054

File tree

11 files changed

+216
-598
lines changed

11 files changed

+216
-598
lines changed

oauth2/auth_suite_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,26 @@ func (te *MockTokenExchanger) ExchangeDeviceCode(_ context.Context,
5858
return te.ReturnsTokens, te.ReturnsError
5959
}
6060

61+
type MockGrantProvider struct {
62+
keyFile *KeyFile
63+
}
64+
65+
func (mgp *MockGrantProvider) GetGrant(audience string, opts *ClientCredentialsFlowOptions) (
66+
*AuthorizationGrant, error) {
67+
scopes := []string{mgp.keyFile.Scope}
68+
if opts != nil && len(opts.AdditionalScopes) > 0 {
69+
scopes = append(scopes, opts.AdditionalScopes...)
70+
}
71+
return &AuthorizationGrant{
72+
Type: GrantTypeClientCredentials,
73+
Audience: audience,
74+
ClientID: mgp.keyFile.ClientID,
75+
ClientCredentials: mgp.keyFile,
76+
TokenEndpoint: oidcEndpoints.TokenEndpoint,
77+
Scopes: scopes,
78+
}, nil
79+
}
80+
6181
var oidcEndpoints = OIDCWellKnownEndpoints{
6282
AuthorizationEndpoint: "http://issuer/auth/authorize",
6383
TokenEndpoint: "http://issuer/auth/token",

oauth2/cache/cache.go

Lines changed: 18 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ import (
2323
"time"
2424

2525
"github.com/apache/pulsar-client-go/oauth2"
26-
"github.com/apache/pulsar-client-go/oauth2/store"
27-
2826
"github.com/apache/pulsar-client-go/oauth2/clock"
2927
xoauth2 "golang.org/x/oauth2"
3028
)
@@ -43,24 +41,24 @@ const (
4341
)
4442

4543
// tokenCache implements a cache for the token associated with a specific audience.
46-
// it interacts with the store when the access token is near expiration or invalidated.
4744
// it is advisable to use a token cache instance per audience.
4845
type tokenCache struct {
49-
clock clock.Clock
50-
lock sync.Mutex
51-
store store.Store
52-
audience string
53-
refresher oauth2.AuthorizationGrantRefresher
54-
token *xoauth2.Token
46+
clock clock.Clock
47+
lock sync.Mutex
48+
audience string
49+
token *xoauth2.Token
50+
flow *oauth2.ClientCredentialsFlow
5551
}
5652

57-
func NewDefaultTokenCache(store store.Store, audience string,
58-
refresher oauth2.AuthorizationGrantRefresher) (CachingTokenSource, error) {
53+
func NewDefaultTokenCache(audience string,
54+
flow *oauth2.ClientCredentialsFlow) (CachingTokenSource, error) {
55+
if flow == nil {
56+
return nil, fmt.Errorf("flow cannot be nil")
57+
}
5958
cache := &tokenCache{
60-
clock: clock.RealClock{},
61-
store: store,
62-
audience: audience,
63-
refresher: refresher,
59+
clock: clock.RealClock{},
60+
audience: audience,
61+
flow: flow,
6462
}
6563
return cache, nil
6664
}
@@ -77,56 +75,24 @@ func (t *tokenCache) Token() (*xoauth2.Token, error) {
7775
return t.token, nil
7876
}
7977

80-
// load from the store and use the access token if it isn't expired
81-
grant, err := t.store.LoadGrant(t.audience)
78+
grant, err := t.flow.Authorize(t.audience)
8279
if err != nil {
83-
return nil, fmt.Errorf("LoadGrant: %w", err)
84-
}
85-
t.token = grant.Token
86-
if t.token != nil && t.validateAccessToken(*t.token) {
87-
return t.token, nil
80+
return nil, err
8881
}
89-
90-
// obtain and cache a fresh access token
91-
grant, err = t.refresher.Refresh(grant)
92-
if err != nil {
93-
return nil, fmt.Errorf("RefreshGrant: %w", err)
82+
if grant.Token == nil {
83+
return nil, fmt.Errorf("authorization succeeded but no token was returned")
9484
}
9585
t.token = grant.Token
96-
err = t.store.SaveGrant(t.audience, *grant)
97-
if err != nil {
98-
// TODO log rather than throw
99-
return nil, fmt.Errorf("SaveGrant: %w", err)
100-
}
10186

10287
return t.token, nil
10388
}
10489

105-
// InvalidateToken clears the access token (likely due to a response from the resource server).
106-
// Note that the token within the grant may contain a refresh token which should survive.
90+
// InvalidateToken clears the cached access token (likely due to a response from the resource server).
10791
func (t *tokenCache) InvalidateToken() error {
10892
t.lock.Lock()
10993
defer t.lock.Unlock()
11094

111-
previous := t.token
11295
t.token = nil
113-
114-
// clear from the store the access token that was returned earlier (unless the store has since been updated)
115-
if previous == nil || previous.AccessToken == "" {
116-
return nil
117-
}
118-
grant, err := t.store.LoadGrant(t.audience)
119-
if err != nil {
120-
return fmt.Errorf("LoadGrant: %w", err)
121-
}
122-
if grant.Token != nil && grant.Token.AccessToken == previous.AccessToken {
123-
grant.Token.Expiry = time.Unix(0, 0).Add(expiryDelta)
124-
err = t.store.SaveGrant(t.audience, *grant)
125-
if err != nil {
126-
// TODO log rather than throw
127-
return fmt.Errorf("SaveGrant: %w", err)
128-
}
129-
}
13096
return nil
13197
}
13298

oauth2/client_credentials_flow.go

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@ import (
3030
// ClientCredentialsFlow takes care of the mechanics needed for getting an access
3131
// token using the OAuth 2.0 "Client Credentials Flow"
3232
type ClientCredentialsFlow struct {
33-
options ClientCredentialsFlowOptions
34-
oidcWellKnownEndpoints OIDCWellKnownEndpoints
35-
keyfile *KeyFile
36-
exchanger ClientCredentialsExchanger
37-
clock clock.Clock
33+
options ClientCredentialsFlowOptions
34+
exchanger ClientCredentialsExchanger
35+
grantProvider GrantProvider
36+
clock clock.Clock
3837
}
3938

4039
// ClientCredentialsProvider abstracts getting client credentials
@@ -47,29 +46,24 @@ type ClientCredentialsExchanger interface {
4746
ExchangeClientCredentials(req ClientCredentialsExchangeRequest) (*TokenResult, error)
4847
}
4948

49+
// GrantProvider abstracts the creation of authorization grants from credentials
50+
type GrantProvider interface {
51+
GetGrant(audience string, options *ClientCredentialsFlowOptions) (*AuthorizationGrant, error)
52+
}
53+
5054
type ClientCredentialsFlowOptions struct {
5155
KeyFile string
5256
AdditionalScopes []string
5357
}
5458

55-
func newClientCredentialsFlow(
56-
options ClientCredentialsFlowOptions,
57-
keyfile *KeyFile,
58-
oidcWellKnownEndpoints OIDCWellKnownEndpoints,
59-
exchanger ClientCredentialsExchanger,
60-
clock clock.Clock) *ClientCredentialsFlow {
61-
return &ClientCredentialsFlow{
62-
options: options,
63-
oidcWellKnownEndpoints: oidcWellKnownEndpoints,
64-
keyfile: keyfile,
65-
exchanger: exchanger,
66-
clock: clock,
67-
}
59+
// DefaultGrantProvider provides authorization grants by loading credentials from a key file
60+
type DefaultGrantProvider struct {
6861
}
6962

70-
// NewDefaultClientCredentialsFlow provides an easy way to build up a default
71-
// client credentials flow with all the correct configuration.
72-
func NewDefaultClientCredentialsFlow(options ClientCredentialsFlowOptions) (*ClientCredentialsFlow, error) {
63+
// GetGrant creates an authorization grant by loading credentials from the key file and
64+
// merging the scopes from both the options and the key file configuration
65+
func (p *DefaultGrantProvider) GetGrant(audience string, options *ClientCredentialsFlowOptions) (
66+
*AuthorizationGrant, error) {
7367
credsProvider := NewClientCredentialsProviderFromKeyFile(options.KeyFile)
7468
keyFile, err := credsProvider.GetClientCredentials()
7569
if err != nil {
@@ -80,39 +74,58 @@ func NewDefaultClientCredentialsFlow(options ClientCredentialsFlowOptions) (*Cli
8074
if err != nil {
8175
return nil, err
8276
}
77+
// Merge the scopes of the options AdditionalScopes with the scopes read from the keyFile config
78+
var scopesToAdd []string
79+
if len(options.AdditionalScopes) > 0 {
80+
scopesToAdd = append(scopesToAdd, options.AdditionalScopes...)
81+
}
82+
83+
if keyFile.Scope != "" {
84+
scopesSplit := strings.Fields(keyFile.Scope)
85+
scopesToAdd = append(scopesToAdd, scopesSplit...)
86+
}
87+
88+
return &AuthorizationGrant{
89+
Type: GrantTypeClientCredentials,
90+
Audience: audience,
91+
ClientID: keyFile.ClientID,
92+
ClientCredentials: keyFile,
93+
TokenEndpoint: wellKnownEndpoints.TokenEndpoint,
94+
Scopes: scopesToAdd,
95+
}, nil
96+
}
97+
98+
func newClientCredentialsFlow(
99+
options ClientCredentialsFlowOptions,
100+
exchanger ClientCredentialsExchanger,
101+
grantProvider GrantProvider,
102+
clock clock.Clock) *ClientCredentialsFlow {
103+
return &ClientCredentialsFlow{
104+
options: options,
105+
exchanger: exchanger,
106+
grantProvider: grantProvider,
107+
clock: clock,
108+
}
109+
}
110+
111+
// NewDefaultClientCredentialsFlow provides an easy way to build up a default
112+
// client credentials flow with all the correct configuration.
113+
func NewDefaultClientCredentialsFlow(options ClientCredentialsFlowOptions) (*ClientCredentialsFlow, error) {
83114

84115
tokenRetriever := NewTokenRetriever(&http.Client{})
85116
return newClientCredentialsFlow(
86117
options,
87-
keyFile,
88-
*wellKnownEndpoints,
89118
tokenRetriever,
119+
&DefaultGrantProvider{},
90120
clock.RealClock{}), nil
91121
}
92122

93123
var _ Flow = &ClientCredentialsFlow{}
94124

95125
func (c *ClientCredentialsFlow) Authorize(audience string) (*AuthorizationGrant, error) {
96-
var err error
97-
98-
// Merge the scopes of the options AdditionalScopes with the scopes read from the keyFile config
99-
var scopesToAdd []string
100-
if len(c.options.AdditionalScopes) > 0 {
101-
scopesToAdd = append(scopesToAdd, c.options.AdditionalScopes...)
102-
}
103-
104-
if c.keyfile.Scope != "" {
105-
scopesSplit := strings.Split(c.keyfile.Scope, " ")
106-
scopesToAdd = append(scopesToAdd, scopesSplit...)
107-
}
108-
109-
grant := &AuthorizationGrant{
110-
Type: GrantTypeClientCredentials,
111-
Audience: audience,
112-
ClientID: c.keyfile.ClientID,
113-
ClientCredentials: c.keyfile,
114-
TokenEndpoint: c.oidcWellKnownEndpoints.TokenEndpoint,
115-
Scopes: scopesToAdd,
126+
grant, err := c.grantProvider.GetGrant(audience, &c.options)
127+
if err != nil {
128+
return nil, err
116129
}
117130

118131
// test the credentials and obtain an initial access token

oauth2/client_credentials_flow_test.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,25 +54,26 @@ var _ = ginkgo.Describe("ClientCredentialsFlow", func() {
5454

5555
var mockClock clock.Clock
5656
var mockTokenExchanger *MockTokenExchanger
57+
var mockGrantProvider *MockGrantProvider
5758

5859
ginkgo.BeforeEach(func() {
5960
mockClock = testing.NewFakeClock(time.Unix(0, 0))
6061
expectedTokens := TokenResult{AccessToken: "accessToken", RefreshToken: "refreshToken", ExpiresIn: 1234}
6162
mockTokenExchanger = &MockTokenExchanger{
6263
ReturnsTokens: &expectedTokens,
6364
}
65+
mockGrantProvider = &MockGrantProvider{
66+
keyFile: &clientCredentials,
67+
}
6468
})
6569

6670
ginkgo.It("invokes TokenExchanger with credentials", func() {
67-
additionalScope := "additional_scope"
6871
provider := newClientCredentialsFlow(
6972
ClientCredentialsFlowOptions{
70-
KeyFile: "test_keyfile",
71-
AdditionalScopes: []string{additionalScope},
73+
KeyFile: "test_keyfile",
7274
},
73-
&clientCredentials,
74-
oidcEndpoints,
7575
mockTokenExchanger,
76+
mockGrantProvider,
7677
mockClock,
7778
)
7879

@@ -83,7 +84,7 @@ var _ = ginkgo.Describe("ClientCredentialsFlow", func() {
8384
ClientID: clientCredentials.ClientID,
8485
ClientSecret: clientCredentials.ClientSecret,
8586
Audience: "test_audience",
86-
Scopes: []string{additionalScope, clientCredentials.Scope},
87+
Scopes: []string{clientCredentials.Scope},
8788
}))
8889
})
8990

@@ -92,9 +93,8 @@ var _ = ginkgo.Describe("ClientCredentialsFlow", func() {
9293
ClientCredentialsFlowOptions{
9394
KeyFile: "test_keyfile",
9495
},
95-
&clientCredentials,
96-
oidcEndpoints,
9796
mockTokenExchanger,
97+
mockGrantProvider,
9898
mockClock,
9999
)
100100

@@ -112,9 +112,8 @@ var _ = ginkgo.Describe("ClientCredentialsFlow", func() {
112112
ClientCredentialsFlowOptions{
113113
KeyFile: "test_keyfile",
114114
},
115-
&clientCredentials,
116-
oidcEndpoints,
117115
mockTokenExchanger,
116+
mockGrantProvider,
118117
mockClock,
119118
)
120119

0 commit comments

Comments
 (0)