Skip to content

Commit daaf16c

Browse files
committed
feat(oauth): Add GetAccessToken to allow authentication without connecting to Command.
1 parent 9563f38 commit daaf16c

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

auth_providers/auth_oauth.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,11 +440,38 @@ func (b *CommandConfigOauth) GetServerConfig() *Server {
440440
return &server
441441
}
442442

443+
// GetAccessToken returns the OAuth2 token source for the given configuration.
444+
func (b *CommandConfigOauth) GetAccessToken() (oauth2.TokenSource, error) {
445+
log.Printf("[DEBUG] Getting OAuth2 token source for client ID: %s", b.ClientID)
446+
if b.ClientID == "" || b.ClientSecret == "" || b.TokenURL == "" {
447+
return nil, fmt.Errorf("client ID, client secret, and token URL must be provided")
448+
}
449+
450+
config := &clientcredentials.Config{
451+
ClientID: b.ClientID,
452+
ClientSecret: b.ClientSecret,
453+
TokenURL: b.TokenURL,
454+
Scopes: b.Scopes,
455+
}
456+
457+
if b.Audience != "" {
458+
log.Printf("[DEBUG] Setting audience for OAuth2 token source: %s", b.Audience)
459+
config.EndpointParams = map[string][]string{
460+
"audience": {b.Audience},
461+
}
462+
}
463+
464+
ctx := context.Background()
465+
log.Printf("[DEBUG] Returning call config.TokenSource() for client ID: %s", b.ClientID)
466+
return config.TokenSource(ctx), nil
467+
}
468+
443469
// RoundTrip executes a single HTTP transaction, adding the OAuth2 token to the request
444470
func (t *oauth2Transport) RoundTrip(req *http.Request) (*http.Response, error) {
445471
log.Printf("[DEBUG] Attempting to get oAuth token from: %s %s", req.Method, req.URL)
446472
token, err := t.src.Token()
447473
if err != nil {
474+
448475
return nil, fmt.Errorf("failed to retrieve OAuth token: %w", err)
449476
}
450477

auth_providers/auth_oauth_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"testing"
2727

2828
"github.com/Keyfactor/keyfactor-auth-client-go/auth_providers"
29+
"golang.org/x/oauth2"
2930
)
3031

3132
func TestOAuthAuthenticator_GetHttpClient(t *testing.T) {
@@ -346,6 +347,18 @@ func TestCommandConfigOauth_Authenticate(t *testing.T) {
346347
authOauthTest(t, "with invalid creds implicit config file", true, invCmdHost, invHostExpectedError...)
347348
}
348349

350+
func TestCommandConfigOauth_GetAccessToken(t *testing.T) {
351+
clientID, clientSecret, tokenURL := exportOAuthEnvVariables()
352+
t.Log("Testing auth with w/ full params variables")
353+
fullParamsConfig := &auth_providers.CommandConfigOauth{
354+
ClientID: clientID,
355+
ClientSecret: clientSecret,
356+
TokenURL: tokenURL,
357+
}
358+
fullParamsConfig.WithSkipVerify(true)
359+
authOauthTest(t, "w/ GetAccessToken w/ full params variables", false, fullParamsConfig)
360+
}
361+
349362
func TestCommandConfigOauth_Build(t *testing.T) {
350363
// Skip test if TEST_KEYFACTOR_AD_AUTH is set to 1 or true
351364
if os.Getenv("TEST_KEYFACTOR_AD_AUTH") == "1" || os.Getenv("TEST_KEYFACTOR_AD_AUTH") == "true" {
@@ -376,6 +389,33 @@ func authOauthTest(
376389
t.Run(
377390
fmt.Sprintf("oAuth Auth Test %s", testName), func(t *testing.T) {
378391

392+
// oauth credentials should always generate an access token
393+
oauthToken, tErr := config.GetAccessToken()
394+
if tErr != nil {
395+
t.Errorf("oAuth auth test '%s' failed to get token source with %v", testName, tErr)
396+
t.FailNow()
397+
return
398+
}
399+
if oauthToken == nil {
400+
t.Errorf("oAuth auth test '%s' failed to get token source", testName)
401+
t.FailNow()
402+
return
403+
}
404+
var at *oauth2.Token
405+
var tkErr error
406+
at, tkErr = oauthToken.Token()
407+
if tkErr != nil {
408+
t.Errorf("oAuth auth test '%s' failed to get token source", testName)
409+
t.FailNow()
410+
}
411+
if at == nil || at.AccessToken == "" {
412+
t.Errorf("oAuth auth test '%s' failed to get token source", testName)
413+
t.FailNow()
414+
return
415+
}
416+
//t.Logf("token %s", at.AccessToken)
417+
t.Logf("oAuth auth test '%s' succeeded", testName)
418+
379419
err := config.Authenticate()
380420
if allowFail {
381421
if err == nil {

0 commit comments

Comments
 (0)