Skip to content

Commit 342dd58

Browse files
wtrockilantoli
andauthored
CLOUDP-280007: Auth Revoke + Error cases + testing (#459)
Co-authored-by: Leo Antoli <[email protected]>
1 parent d01808a commit 342dd58

File tree

5 files changed

+337
-65
lines changed

5 files changed

+337
-65
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@ admin/.openapi-generator
4242
scripts/gh-md-toc
4343
docs/.openapi-generator/*
4444

45+
/.trunk

auth/credentials/api.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,14 @@ import (
1313
//nolint:gosec //url only
1414
const tokenAPIPath = "/api/oauth/token"
1515

16-
// serverURL for atlas API
17-
const serverURL = core.DefaultCloudURL + tokenAPIPath
16+
// revokeAPIPath for revoking OAuth Access Token from server
17+
const revokeAPIPath = "/api/oauth/revoke"
18+
19+
// serverURL for Token Atlas API
20+
const serverTokenURL = core.DefaultCloudURL + tokenAPIPath
21+
22+
// serverURL for Revoke Atlas API
23+
const serverRevokeURL = core.DefaultCloudURL + revokeAPIPath
1824

1925
// AtlasTokenSourceOptions provides set of input arguments
2026
// for creation of credentials.TokenSource interface
@@ -35,11 +41,14 @@ type AtlasTokenSourceOptions struct {
3541
// NewTokenSourceWithOptions initializes an OAuthTokenSource with advanced credentials.AtlasTokenSourceOptions
3642
func NewTokenSourceWithOptions(opts AtlasTokenSourceOptions) TokenSource {
3743
var tokenURL string
44+
var revokeUrl string
3845
if opts.BaseURL != nil {
3946
baseUrlNoSuffix := strings.TrimSuffix(*opts.BaseURL, "/")
4047
tokenURL = baseUrlNoSuffix + tokenAPIPath
48+
revokeUrl = baseUrlNoSuffix + revokeAPIPath
4149
} else {
42-
tokenURL = serverURL
50+
tokenURL = serverTokenURL
51+
revokeUrl = serverRevokeURL
4352
}
4453
var userAgent string
4554
if opts.UserAgent != "" {
@@ -59,6 +68,7 @@ func NewTokenSourceWithOptions(opts AtlasTokenSourceOptions) TokenSource {
5968
clientSecret: opts.ClientSecret,
6069
userAgent: userAgent,
6170
tokenURL: tokenURL,
71+
revokeURL: revokeUrl,
6272
tokenCache: opts.TokenCache,
6373
ctx: ctx,
6474
}

auth/credentials/oauth.go

Lines changed: 127 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"io"
910
"net/http"
1011
"net/url"
1112
"strings"
@@ -19,10 +20,17 @@ type Token struct {
1920
ExpiresIn int `json:"expires_in"`
2021
}
2122

22-
// TokenSource interface allows to fetch valid OAuth Token.
23+
// TokenSource interface allows to fetch and revoke valid OAuth Access Tokens.
2324
type TokenSource interface {
2425
// GetValidToken retrieves the valid Token, refreshing it if necessary.
26+
// Implementation will check if token exist in the cache,
27+
// otherwise we will fetch new token from server and cache it.
28+
// If cached token expired it will be automatically removed and new OAuth Token will be requested
2529
GetValidToken() (*Token, error)
30+
31+
// RevokeToken revokes the Access Token while also removing it from the Access Token Cache.
32+
// When Access Token is expired or missing revoke will return without any action.
33+
RevokeToken() error
2634
}
2735

2836
// OAuthTokenSource manages the OAuth Token fetching and refreshing using a LocalTokenCache.
@@ -31,11 +39,33 @@ type OAuthTokenSource struct {
3139
clientSecret string
3240
userAgent string
3341
tokenURL string
42+
revokeURL string
3443
token *Token
3544
tokenCache LocalTokenCache
3645
ctx context.Context
3746
}
3847

48+
func (c *OAuthTokenSource) RevokeToken() error {
49+
tokenString, err := c.tokenCache.RetrieveToken(c.ctx)
50+
if err != nil {
51+
return err
52+
}
53+
if tokenString != nil && *tokenString != "" {
54+
err := c.revokeTokenInRemoteServer(*tokenString)
55+
if err != nil {
56+
return err
57+
}
58+
59+
// Revoked token can be removed from cache
60+
err = c.tokenCache.SaveToken(c.ctx, "")
61+
if err != nil {
62+
return err
63+
}
64+
}
65+
// No revocation needed for empty tokens.
66+
return nil
67+
}
68+
3969
// GetValidToken retrieves the valid Token, refreshing it if necessary.
4070
func (c *OAuthTokenSource) GetValidToken() (*Token, error) {
4171
// Try to retrieve the Token string from the Token source
@@ -56,7 +86,7 @@ func (c *OAuthTokenSource) GetValidToken() (*Token, error) {
5686

5787
// refreshToken fetches a new Token and saves it using the Token source.
5888
func (c *OAuthTokenSource) refreshToken() (*Token, error) {
59-
newToken, err := c.fetchToken()
89+
newToken, err := c.fetchTokenFromRemoteServer()
6090
if err != nil {
6191
return nil, err
6292
}
@@ -71,50 +101,7 @@ func (c *OAuthTokenSource) refreshToken() (*Token, error) {
71101
return newToken, nil
72102
}
73103

74-
// fetchToken makes a manual POST request to Server (tokenUrl) to fetch the access Token.
75-
func (c *OAuthTokenSource) fetchToken() (*Token, error) {
76-
data := url.Values{}
77-
data.Set("grant_type", "client_credentials")
78-
79-
req, err := http.NewRequest("POST", c.tokenURL, strings.NewReader(data.Encode()))
80-
if err != nil {
81-
return nil, err
82-
}
83-
req.SetBasicAuth(c.clientID, c.clientSecret)
84-
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
85-
req.Header.Set("User-Agent", c.userAgent)
86-
87-
client := &http.Client{}
88-
resp, err := client.Do(req)
89-
if err != nil {
90-
return nil, err
91-
}
92-
defer resp.Body.Close()
93-
94-
if resp.StatusCode != http.StatusOK {
95-
return nil, errors.New("failed to obtain Token, status: " + resp.Status)
96-
}
97-
98-
var tokenResp struct {
99-
AccessToken string `json:"access_token"`
100-
TokenType string `json:"token_type"`
101-
ExpiresIn int `json:"expires_in"`
102-
}
103-
104-
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
105-
return nil, err
106-
}
107-
108-
// Construct the Token with expiry time
109-
token := &Token{
110-
AccessToken: tokenResp.AccessToken,
111-
ExpiresIn: tokenResp.ExpiresIn,
112-
Expiry: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
113-
}
114-
return token, nil
115-
}
116-
117-
// Additional time for Access Tokens to not expire.
104+
// ExpiryDelta is the Additional time for Access Tokens to not expire.
118105
const ExpiryDelta = 10 * time.Second
119106

120107
// expired checks if the Token is close to expiring.
@@ -151,7 +138,7 @@ func parseToken(accessToken string) (*Token, error) {
151138

152139
expiry := time.Unix(tokenData.Exp, 0)
153140
if time.Now().After(expiry) {
154-
return nil, errors.New("Token has expired")
141+
return nil, errors.New("Atlas Cloud Access Token has expired")
155142
}
156143

157144
return &Token{
@@ -160,3 +147,96 @@ func parseToken(accessToken string) (*Token, error) {
160147
ExpiresIn: int(time.Until(expiry).Seconds()),
161148
}, nil
162149
}
150+
151+
// fetchTokenFromRemoteServer makes a manual POST request to Server (tokenUrl) to fetch the access Token.
152+
func (c *OAuthTokenSource) fetchTokenFromRemoteServer() (*Token, error) {
153+
data := url.Values{}
154+
data.Set("grant_type", "client_credentials")
155+
156+
req, err := http.NewRequest("POST", c.tokenURL, strings.NewReader(data.Encode()))
157+
if err != nil {
158+
return nil, err
159+
}
160+
req.SetBasicAuth(c.clientID, c.clientSecret)
161+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
162+
req.Header.Set("User-Agent", c.userAgent)
163+
164+
client := &http.Client{}
165+
resp, err := client.Do(req)
166+
if err != nil {
167+
return nil, err
168+
}
169+
defer resp.Body.Close()
170+
171+
if resp.StatusCode != http.StatusOK {
172+
if resp.StatusCode == http.StatusTooManyRequests {
173+
msg, _ := io.ReadAll(resp.Body)
174+
formattedMessage := fmt.Sprintf("%v %v: HTTP %v Detail: %v Reason: %v",
175+
"POST", c.tokenURL, resp.StatusCode,
176+
"Token request was rate limited", string(msg))
177+
return nil, errors.New(formattedMessage)
178+
}
179+
formattedMessage := fmt.Sprintf("%v %v: HTTP %v Detail: %v Reason: %v",
180+
"POST", c.tokenURL, resp.StatusCode,
181+
"Failed to obtain Access Token when fetching new OAuth Token from remote server",
182+
resp.Header.Get("www-authenticate"))
183+
return nil, errors.New(formattedMessage)
184+
}
185+
// tokenRemoteResponse represents successful response from token endpoint
186+
var tokenRemoteResponse struct {
187+
AccessToken string `json:"access_token"`
188+
TokenType string `json:"token_type"`
189+
ExpiresIn int `json:"expires_in"`
190+
}
191+
192+
if err := json.NewDecoder(resp.Body).Decode(&tokenRemoteResponse); err != nil {
193+
return nil, err
194+
}
195+
196+
// Construct the Token with expiry time
197+
token := &Token{
198+
AccessToken: tokenRemoteResponse.AccessToken,
199+
ExpiresIn: tokenRemoteResponse.ExpiresIn,
200+
Expiry: time.Now().Add(time.Duration(tokenRemoteResponse.ExpiresIn) * time.Second),
201+
}
202+
return token, nil
203+
}
204+
205+
// revokeToken revokes the provided access token by making a POST request to the OAuth revoke endpoint.
206+
func (c *OAuthTokenSource) revokeTokenInRemoteServer(token string) error {
207+
revokeUrl := c.revokeURL
208+
data := url.Values{}
209+
data.Set("token", token)
210+
data.Set("token_type_hint", "access_token")
211+
212+
req, err := http.NewRequest("POST", revokeUrl, strings.NewReader(data.Encode()))
213+
if err != nil {
214+
return err
215+
}
216+
req.SetBasicAuth(c.clientID, c.clientSecret)
217+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
218+
req.Header.Set("User-Agent", c.userAgent)
219+
220+
client := &http.Client{}
221+
resp, err := client.Do(req)
222+
if err != nil {
223+
return err
224+
}
225+
defer resp.Body.Close()
226+
227+
if resp.StatusCode != http.StatusOK {
228+
if resp.StatusCode == http.StatusTooManyRequests {
229+
msg, _ := io.ReadAll(resp.Body)
230+
formattedMessage := fmt.Sprintf("%v %v: HTTP %v Detail: %v Reason: %v",
231+
"POST", c.tokenURL, resp.StatusCode,
232+
"Token Revocation request was rate limited", string(msg))
233+
return errors.New(formattedMessage)
234+
}
235+
formattedMessage := fmt.Sprintf("%v %v: HTTP %v Detail: %v Reason: %v",
236+
"POST", c.tokenURL, resp.StatusCode,
237+
"Failed to revoke Access Token when fetching new OAuth Token from remote server",
238+
resp.Header.Get("www-authenticate"))
239+
return errors.New(formattedMessage)
240+
}
241+
return nil
242+
}

0 commit comments

Comments
 (0)