Skip to content

Commit e99019c

Browse files
wtrockigssbznjeroenvervaeke
authored
CLOUDP-282641: Use oauth2 token source (#471)
Co-authored-by: Gustavo Bazan <[email protected]> Co-authored-by: Jeroen Vervaeke <[email protected]>
1 parent 625dd5b commit e99019c

File tree

22 files changed

+346
-1170
lines changed

22 files changed

+346
-1170
lines changed

admin/atlas_client.go

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package admin // import "go.mongodb.org/atlas-sdk/v20241023002/admin"
22

33
import (
4+
"context"
45
"errors"
56
"net/http"
67
"strings"
78

89
"github.com/mongodb-forks/digest"
9-
"go.mongodb.org/atlas-sdk/v20241023002/auth/credentials"
10+
"go.mongodb.org/atlas-sdk/v20241023002/auth"
11+
"go.mongodb.org/atlas-sdk/v20241023002/auth/clientcredentials"
1012
"go.mongodb.org/atlas-sdk/v20241023002/internal/core"
1113
)
1214

@@ -35,7 +37,7 @@ func NewClient(modifiers ...ClientModifier) (*APIClient, error) {
3537
return NewAPIClient(defaultConfig), nil
3638
}
3739

38-
// ClientModifiers lets you create function that controls configuration before creating client.
40+
// ClientModifier lets you create function that controls configuration before creating client.
3941
type ClientModifier func(*Configuration) error
4042

4143
// UseDigestAuth provides Digest authentication for Go SDK.
@@ -53,30 +55,26 @@ func UseDigestAuth(apiKey, apiSecret string) ClientModifier {
5355
}
5456
}
5557

56-
// UseOAuthAuth provides OAuthAuth authentication for Go SDK.
58+
// UseOAuthAuth provides OAuth authentication for Go SDK.
5759
// Method is provided as helper to create a default HTTP client that supports OAuth (Service Accounts) authentication.
58-
// credentials.LocalTokenCache can be supplied to reuse OAuth Token across application restarts.
59-
// Warning: for advanced use cases please use credentials.NewTokenSource directly in your code pass it to UseHTTPClient method.
60+
//
6061
// Warning: any previously set httpClient will be overwritten. To fully customize HttpClient use UseHTTPClient method.
61-
func UseOAuthAuth(clientID, clientSecret string, tokenCache credentials.LocalTokenCache) ClientModifier {
62+
func UseOAuthAuth(ctx context.Context, clientID, clientSecret string) ClientModifier {
63+
ctx2 := ctx
64+
if hc := ctx.Value(auth.HTTPClient); hc == nil {
65+
client := http.DefaultClient
66+
client.Transport = &clientcredentials.Transport{
67+
Base: http.DefaultTransport, UserAgent: core.DefaultUserAgent,
68+
}
69+
ctx2 = context.WithValue(ctx, auth.HTTPClient, client)
70+
}
71+
oauth := clientcredentials.NewConfig(clientID, clientSecret)
6272
return func(c *Configuration) error {
63-
var tokenSource credentials.TokenSource
64-
if tokenCache != nil {
65-
tokenSource = credentials.NewTokenSourceWithOptions(credentials.AtlasTokenSourceOptions{
66-
ClientID: clientID,
67-
ClientSecret: clientSecret,
68-
TokenCache: tokenCache,
69-
BaseURL: &c.Servers[0].URL,
70-
})
71-
} else {
72-
tokenSource = credentials.NewTokenSourceWithOptions(credentials.AtlasTokenSourceOptions{
73-
ClientID: clientID,
74-
ClientSecret: clientSecret,
75-
BaseURL: &c.Servers[0].URL,
76-
})
73+
if len(c.Servers) > 0 {
74+
oauth.TokenURL = c.Servers[0].URL + clientcredentials.TokenAPIPath
75+
oauth.RevokeURL = c.Servers[0].URL + clientcredentials.RevokeAPIPath
7776
}
78-
httpClient := credentials.NewHTTPClientWithOAuthToken(tokenSource)
79-
c.HTTPClient = httpClient
77+
c.HTTPClient = oauth.Client(ctx2)
8078
return nil
8179
}
8280
}
@@ -86,7 +84,7 @@ func UseOAuthAuth(clientID, clientSecret string, tokenCache credentials.LocalTok
8684
// UseHTTPClient set custom http client implementation.
8785
//
8886
// Warning: UseHTTPClient overrides any previously set httpClient including the one set by UseDigestAuth.
89-
// To set a custom http client with HTTP diggest support use:
87+
// To set a custom http client with HTTP digest support use:
9088
//
9189
// transport := digest.NewTransportWithHTTPRoundTripper(apiKey, apiSecret, yourHttpTransport)
9290
// client := UseHTTPClient(transport.Client())

auth/auth.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package auth
2+
3+
import "golang.org/x/oauth2"
4+
5+
// oauth2 alias
6+
var (
7+
// HTTPClient is the context key to use with golang.org/x/net/context's
8+
// WithValue function to associate an *http.Client value with a context.
9+
HTTPClient = oauth2.HTTPClient
10+
// ReuseTokenSource returns a TokenSource which repeatedly returns the
11+
// same token as long as it's valid, starting with t.
12+
// When its cached token is invalid, a new token is obtained from src.
13+
//
14+
// ReuseTokenSource is typically used to reuse tokens from a cache
15+
// (such as a file on disk) between runs of a program, rather than
16+
// obtaining new tokens unnecessarily.
17+
//
18+
// The initial token t may be nil, in which case the TokenSource is
19+
// wrapped in a caching version if it isn't one already. This also
20+
// means it's always safe to wrap ReuseTokenSource around any other
21+
// TokenSource without adverse effects.
22+
ReuseTokenSource = oauth2.ReuseTokenSource
23+
// NewClient creates an *http.Client from a Context and TokenSource.
24+
// The returned client is not valid beyond the lifetime of the context.
25+
//
26+
// Note that if a custom *http.Client is provided via the Context it
27+
// is used only for token acquisition and is not used to configure the
28+
// *http.Client returned from NewClient.
29+
//
30+
// As a special case, if src is nil, a non-OAuth2 client is returned
31+
// using the provided context. This exists to support related OAuth2
32+
// packages.
33+
NewClient = oauth2.NewClient
34+
)
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package clientcredentials
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"golang.org/x/oauth2"
8+
"io"
9+
"net/http"
10+
"net/url"
11+
"strings"
12+
13+
"go.mongodb.org/atlas-sdk/v20241023002/auth"
14+
"go.mongodb.org/atlas-sdk/v20241023002/internal/core"
15+
"golang.org/x/oauth2/clientcredentials"
16+
)
17+
18+
const (
19+
// TokenAPIPath for getting OAuth Access Token from server
20+
TokenAPIPath = "/api/oauth/token" //nolint:gosec //url only
21+
// serverTokenURL for Token Atlas API
22+
serverTokenURL = core.DefaultCloudURL + TokenAPIPath
23+
// RevokeAPIPath for revoking OAuth Access Token from server
24+
RevokeAPIPath = "/api/oauth/revoke"
25+
26+
// serverRevokeURL for Revoke Atlas API
27+
serverRevokeURL = core.DefaultCloudURL + RevokeAPIPath
28+
userAgent = "User-Agent"
29+
)
30+
31+
func NewConfig(clientID, clientSecret string) *Config {
32+
c := &Config{}
33+
c.ClientID = clientID
34+
c.ClientSecret = clientSecret
35+
c.RevokeURL = serverRevokeURL
36+
c.TokenURL = serverTokenURL
37+
c.AuthStyle = oauth2.AuthStyleInHeader
38+
c.userAgent = core.DefaultUserAgent
39+
40+
return c
41+
}
42+
43+
// Config describes a 2-legged OAuth2 flow, with both the
44+
// client application information and the server's endpoint URLs.
45+
//
46+
// NOTE: Config values are used only internally
47+
// and should not be overridden by clients
48+
type Config struct {
49+
clientcredentials.Config
50+
RevokeURL string
51+
userAgent string
52+
}
53+
54+
func (c *Config) Client(ctx context.Context) *http.Client {
55+
client := c.Config.Client(ctx)
56+
client.Transport = &Transport{
57+
Base: client.Transport,
58+
UserAgent: core.DefaultUserAgent,
59+
}
60+
return client
61+
}
62+
63+
// RevokeToken revokes OAuth Token
64+
func (c *Config) RevokeToken(ctx context.Context, t *auth.Token) error {
65+
if c.RevokeURL == "" {
66+
return errors.New("endpoint missing RevokeURL")
67+
}
68+
if !t.Valid() {
69+
return nil // nothing to do
70+
}
71+
v := url.Values{
72+
"token": {t.AccessToken},
73+
"token_type_hint": {"access_token"},
74+
}
75+
76+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.RevokeURL, strings.NewReader(v.Encode()))
77+
if err != nil {
78+
return err
79+
}
80+
req.SetBasicAuth(url.QueryEscape(c.ClientID), url.QueryEscape(c.ClientSecret))
81+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
82+
req.Header.Set(userAgent, c.userAgent)
83+
84+
client := http.DefaultClient
85+
resp, err := client.Do(req)
86+
if err != nil {
87+
return err
88+
}
89+
defer resp.Body.Close()
90+
91+
if resp.StatusCode != http.StatusOK {
92+
if resp.StatusCode == http.StatusTooManyRequests {
93+
msg, _ := io.ReadAll(resp.Body)
94+
formattedMessage := fmt.Sprintf("%s %s: HTTP %v Detail: %v Reason: %v",
95+
http.MethodPost, c.RevokeURL, resp.StatusCode,
96+
"Token Revocation request was rate limited", string(msg))
97+
return errors.New(formattedMessage)
98+
}
99+
formattedMessage := fmt.Sprintf("%s %s: HTTP %v Detail: %v Reason: %v",
100+
http.MethodPost, c.RevokeURL, resp.StatusCode,
101+
"Failed to revoke Access Token when fetching new OAuth Token from remote server",
102+
resp.Header.Get("www-authenticate"))
103+
return errors.New(formattedMessage)
104+
}
105+
return nil
106+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package clientcredentials
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/assert"
11+
"go.mongodb.org/atlas-sdk/v20241023002/auth"
12+
)
13+
14+
// mockOAuthRevokeEndpoint creates a mock OAuth revoke endpoint,
15+
// that simulates token revocation responses.
16+
func mockOAuthRevokeEndpoint(statusCode int) *httptest.Server {
17+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
18+
if r.Method != http.MethodPost || r.FormValue("token") == "" {
19+
http.Error(w, "invalid request", http.StatusBadRequest)
20+
return
21+
}
22+
w.WriteHeader(statusCode)
23+
})
24+
return httptest.NewServer(handler)
25+
}
26+
27+
// Test OAuthTokenSource_RevokeToken_Success tests successful token revocation.
28+
func TestOAuthTokenSource_RevokeToken_Success(t *testing.T) {
29+
mockServer := mockOAuthRevokeEndpoint(http.StatusOK)
30+
defer mockServer.Close()
31+
32+
config := NewConfig("clientID", "clientSecret")
33+
config.RevokeURL = mockServer.URL
34+
expiry := time.Now().Add(1 * time.Hour)
35+
err := config.RevokeToken(context.Background(), &auth.Token{
36+
AccessToken: "test",
37+
Expiry: expiry,
38+
ExpiresIn: expiry.Unix(),
39+
})
40+
assert.NoError(t, err)
41+
}
42+
43+
// TestOAuthTokenSource_RevokeToken_Failure tests token revocation failure due to unauthorized access.
44+
func TestOAuthTokenSource_RevokeToken_Failure(t *testing.T) {
45+
mockServer := mockOAuthRevokeEndpoint(http.StatusUnauthorized)
46+
defer mockServer.Close()
47+
48+
config := NewConfig("clientID", "clientSecret")
49+
expiry := time.Now().Add(1 * time.Hour)
50+
err := config.RevokeToken(context.Background(), &auth.Token{
51+
AccessToken: "test",
52+
Expiry: expiry,
53+
ExpiresIn: expiry.Unix(),
54+
})
55+
assert.Error(t, err)
56+
assert.ErrorContains(t, err, "Failed to revoke Access Token when fetching new OAuth Token from remote server")
57+
}

auth/credentials/doc.go renamed to auth/clientcredentials/doc.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
// limitations under the License.
1414

1515
/*
16-
Package credentials provides an SDK internal client_credentials grant implementation https://datatracker.ietf.org/doc/html/rfc6749#section-1.3.4
16+
Package credentials provide an SDK internal client_credentials grant implementation https://datatracker.ietf.org/doc/html/rfc6749#section-1.3.4
1717
*/
18-
package credentials
18+
package clientcredentials
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package clientcredentials
2+
3+
import (
4+
"net/http"
5+
)
6+
7+
// Transport supplies custom user agent to token requests
8+
type Transport struct {
9+
Base http.RoundTripper
10+
UserAgent string
11+
}
12+
13+
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
14+
if t.UserAgent != "" {
15+
req.Header.Set("User-Agent", t.UserAgent)
16+
}
17+
return t.base().RoundTrip(req)
18+
}
19+
20+
func (t *Transport) base() http.RoundTripper {
21+
if t.Base != nil {
22+
return t.Base
23+
}
24+
return http.DefaultTransport
25+
}

0 commit comments

Comments
 (0)