Skip to content

Commit df94619

Browse files
authored
CLOUDP-278691: Credentials authentication (#429)
1 parent ce2323e commit df94619

File tree

10 files changed

+643
-10
lines changed

10 files changed

+643
-10
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ test-examples:
2828
.PHONY: fmt
2929
fmt:
3030
@echo "==> Fixing source code with gofmt..."
31-
gofmt -s -w ./**/*.go
31+
gofmt -s -w ./**/**/*.go
3232

3333
.PHONY: lint-fix
3434
lint-fix:

auth/credentials/api.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package credentials
2+
3+
import (
4+
"context"
5+
"net/http"
6+
)
7+
8+
// tokenAPIPath for obtaining OAuth Access Token from server
9+
//
10+
//nolint:gosec //url only
11+
const tokenAPIPath = "/api/oauth/token"
12+
13+
// serverURL for atlas API
14+
const serverURL = "https://cloud.mongodb.com" + tokenAPIPath
15+
16+
// AtlasTokenSourceOptions provides set of input arguments
17+
// for creation of credentials.TokenSource interface
18+
type AtlasTokenSourceOptions struct {
19+
ClientID string
20+
ClientSecret string
21+
// Custom Token source. InMemoryTokenCache being default
22+
TokenCache LocalTokenCache
23+
24+
// Custom context. context.Background() will be used by default
25+
Context *context.Context
26+
// Custom base url for fetching Token using TokenSource. Reserved for internal use.
27+
BaseURL *string
28+
}
29+
30+
// NewTokenSourceWithOptions initializes an OAuthTokenSource with advanced credentials.AtlasTokenSourceOptions
31+
// Use this method to initialize custom OAuth Token Cache (filesystem).
32+
func NewTokenSourceWithOptions(opts AtlasTokenSourceOptions) TokenSource {
33+
var tokenURL string
34+
if opts.BaseURL != nil {
35+
tokenURL = *opts.BaseURL + tokenAPIPath
36+
} else {
37+
tokenURL = serverURL
38+
}
39+
var ctx context.Context
40+
if opts.Context == nil {
41+
ctx = context.Background()
42+
} else {
43+
ctx = *opts.Context
44+
}
45+
46+
return &OAuthTokenSource{
47+
clientID: opts.ClientID,
48+
clientSecret: opts.ClientSecret,
49+
tokenURL: tokenURL,
50+
tokenCache: opts.TokenCache,
51+
ctx: ctx,
52+
}
53+
}
54+
55+
// NewTokenSource initializes OAuth Token Source that provides a way to obtain valid OAuth Tokens.
56+
// See credentials.NewTokenSourceWithOptions for advanced use cases.
57+
func NewTokenSource(clientID, clientSecret string) TokenSource {
58+
return NewTokenSourceWithOptions(AtlasTokenSourceOptions{
59+
ClientID: clientID,
60+
ClientSecret: clientSecret,
61+
TokenCache: &InMemoryTokenCache{},
62+
})
63+
}
64+
65+
// NewHTTPClientWithOAuthToken helper method for creating HTTP client with OAuth authentication support.
66+
// Use this method for performing requests using http.DefaultTransport.
67+
// For more advanced use cases please create your own credentials.OAuthCustomHTTPTransport.
68+
func NewHTTPClientWithOAuthToken(client TokenSource) *http.Client {
69+
return &http.Client{
70+
Transport: &OAuthCustomHTTPTransport{
71+
UnderlyingTransport: http.DefaultTransport,
72+
TokenSource: client,
73+
},
74+
}
75+
}

auth/credentials/doc.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright 2024 MongoDB Inc
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/*
16+
Package credentials provides an SDK internal client_credentials grant implementation https://datatracker.ietf.org/doc/html/rfc6749#section-1.3.4
17+
*/
18+
package credentials

auth/credentials/oauth.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package credentials
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"net/http"
10+
"net/url"
11+
"strings"
12+
"time"
13+
)
14+
15+
// Token represents the internal OAuth2 Token structure
16+
type Token struct {
17+
AccessToken string `json:"access_token"`
18+
Expiry time.Time `json:"expiry,omitempty"`
19+
ExpiresIn int `json:"expires_in"`
20+
}
21+
22+
// TokenSource interface allows to fetch valid OAuth Token.
23+
type TokenSource interface {
24+
// GetValidToken retrieves the valid Token, refreshing it if necessary.
25+
GetValidToken() (*Token, error)
26+
}
27+
28+
// OAuthTokenSource manages the OAuth Token fetching and refreshing using a LocalTokenCache.
29+
type OAuthTokenSource struct {
30+
clientID string
31+
clientSecret string
32+
tokenURL string
33+
token *Token
34+
tokenCache LocalTokenCache
35+
ctx context.Context
36+
}
37+
38+
// GetValidToken retrieves the valid Token, refreshing it if necessary.
39+
func (c *OAuthTokenSource) GetValidToken() (*Token, error) {
40+
// Try to retrieve the Token string from the Token source
41+
tokenString, err := c.tokenCache.RetrieveToken(c.ctx)
42+
if err != nil || tokenString == nil {
43+
return c.refreshToken()
44+
}
45+
46+
// Parse the Token string into the Token structure (mock parse operation)
47+
c.token, err = parseToken(*tokenString)
48+
if err != nil || c.token.expired() {
49+
// Token is invalid or expired, refresh it
50+
return c.refreshToken()
51+
}
52+
53+
return c.token, nil
54+
}
55+
56+
// refreshToken fetches a new Token and saves it using the Token source.
57+
func (c *OAuthTokenSource) refreshToken() (*Token, error) {
58+
newToken, err := c.fetchToken()
59+
if err != nil {
60+
return nil, err
61+
}
62+
63+
// Save the access Token string to the Token source
64+
err = c.tokenCache.SaveToken(c.ctx, newToken.AccessToken)
65+
if err != nil {
66+
return nil, fmt.Errorf("failed to save Token: %w", err)
67+
}
68+
69+
c.token = newToken
70+
return newToken, nil
71+
}
72+
73+
// fetchToken makes a manual POST request to Server (tokenUrl) to fetch the access Token.
74+
func (c *OAuthTokenSource) fetchToken() (*Token, error) {
75+
data := url.Values{}
76+
data.Set("grant_type", "client_credentials")
77+
78+
req, err := http.NewRequest("POST", c.tokenURL, strings.NewReader(data.Encode()))
79+
if err != nil {
80+
return nil, err
81+
}
82+
req.SetBasicAuth(c.clientID, c.clientSecret)
83+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
84+
85+
client := &http.Client{}
86+
resp, err := client.Do(req)
87+
if err != nil {
88+
return nil, err
89+
}
90+
defer resp.Body.Close()
91+
92+
if resp.StatusCode != http.StatusOK {
93+
return nil, errors.New("failed to obtain Token, status: " + resp.Status)
94+
}
95+
96+
var tokenResp struct {
97+
AccessToken string `json:"access_token"`
98+
TokenType string `json:"token_type"`
99+
ExpiresIn int `json:"expires_in"`
100+
}
101+
102+
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
103+
return nil, err
104+
}
105+
106+
// Construct the Token with expiry time
107+
token := &Token{
108+
AccessToken: tokenResp.AccessToken,
109+
ExpiresIn: tokenResp.ExpiresIn,
110+
Expiry: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
111+
}
112+
return token, nil
113+
}
114+
115+
// Additional time for Access Tokens to not expire.
116+
const ExpiryDelta = 10 * time.Second
117+
118+
// expired checks if the Token is close to expiring.
119+
func (t *Token) expired() bool {
120+
if t.Expiry.IsZero() {
121+
return false
122+
}
123+
return t.Expiry.Round(0).Add(-ExpiryDelta).Before(time.Now())
124+
}
125+
126+
// Valid checks if the Token is still valid (present and not expired)
127+
func (t *Token) Valid() bool {
128+
return t != nil && t.AccessToken != "" && !t.expired()
129+
}
130+
131+
// ParseToken extracts expiry details from JWT Token
132+
func parseToken(accessToken string) (*Token, error) {
133+
parts := strings.Split(accessToken, ".")
134+
if len(parts) != 3 {
135+
return nil, errors.New("invalid access Token format")
136+
}
137+
138+
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
139+
if err != nil {
140+
return nil, err
141+
}
142+
143+
var tokenData struct {
144+
Exp int64 `json:"exp"`
145+
}
146+
if err := json.Unmarshal(payload, &tokenData); err != nil {
147+
return nil, err
148+
}
149+
150+
expiry := time.Unix(tokenData.Exp, 0)
151+
if time.Now().After(expiry) {
152+
return nil, errors.New("Token has expired")
153+
}
154+
155+
return &Token{
156+
AccessToken: accessToken,
157+
Expiry: expiry,
158+
ExpiresIn: int(time.Until(expiry).Seconds()),
159+
}, nil
160+
}

0 commit comments

Comments
 (0)