Skip to content

Commit a927fe7

Browse files
committed
fix(test): Update invalidCredsExpectedError to accomodate more than Keycloak auth errors.
1 parent 4718d4e commit a927fe7

File tree

2 files changed

+56
-26
lines changed

2 files changed

+56
-26
lines changed

auth_providers/auth_oauth.go

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,23 @@ func (b *CommandConfigOauth) GetServerConfig() *Server {
442442

443443
// GetAccessToken returns the OAuth2 token source for the given configuration.
444444
func (b *CommandConfigOauth) GetAccessToken() (oauth2.TokenSource, error) {
445+
if b == nil {
446+
return nil, fmt.Errorf("CommandConfigOauth is nil")
447+
}
448+
449+
b.ValidateAuthConfig()
450+
451+
if b.AccessToken != "" && (b.ClientID == "" || b.ClientSecret == "" || b.TokenURL == "") {
452+
log.Printf("[DEBUG] Access token is explicitly set, and no client credentials are provided. Using static token source.")
453+
return oauth2.StaticTokenSource(
454+
&oauth2.Token{
455+
AccessToken: b.AccessToken,
456+
TokenType: DefaultTokenPrefix,
457+
Expiry: b.Expiry,
458+
},
459+
), nil
460+
}
461+
445462
log.Printf("[DEBUG] Getting OAuth2 token source for client ID: %s", b.ClientID)
446463
if b.ClientID == "" || b.ClientSecret == "" || b.TokenURL == "" {
447464
return nil, fmt.Errorf("client ID, client secret, and token URL must be provided")
@@ -463,7 +480,19 @@ func (b *CommandConfigOauth) GetAccessToken() (oauth2.TokenSource, error) {
463480

464481
ctx := context.Background()
465482
log.Printf("[DEBUG] Returning call config.TokenSource() for client ID: %s", b.ClientID)
466-
return config.TokenSource(ctx), nil
483+
tokenSource := config.TokenSource(ctx)
484+
if tokenSource == nil {
485+
return nil, fmt.Errorf("failed to create token source for client ID: %s", b.ClientID)
486+
}
487+
token, tErr := tokenSource.Token()
488+
if tErr != nil {
489+
return nil, fmt.Errorf("failed to retrieve token for client ID %s: %w", b.ClientID, tErr)
490+
}
491+
if token == nil || token.AccessToken == "" {
492+
return nil, fmt.Errorf("received empty OAuth token for client ID: %s", b.ClientID)
493+
}
494+
495+
return tokenSource, nil
467496
}
468497

469498
// RoundTrip executes a single HTTP transaction, adding the OAuth2 token to the request

auth_providers/auth_oauth_test.go

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ func TestCommandConfigOauth_Authenticate(t *testing.T) {
257257
}
258258
fullParamsInvalidPassConfig.WithSkipVerify(true)
259259
invalidCredsExpectedError := []string{
260-
"oauth2", "unauthorized_client", "Invalid client or Invalid client credentials",
260+
"oauth2", "fail", "invalid", "client",
261261
}
262262
authOauthTest(t, "w/ full params & invalid pass", true, fullParamsInvalidPassConfig, invalidCredsExpectedError...)
263263

@@ -391,31 +391,32 @@ func authOauthTest(
391391

392392
// oauth credentials should always generate an access token
393393
oauthToken, tErr := config.GetAccessToken()
394-
if tErr != nil {
395-
t.Errorf("oAuth auth test '%s' failed to get token source with %v", testName, tErr)
396-
t.FailNow()
397-
return
398-
}
399-
if oauthToken == nil {
400-
t.Errorf("oAuth auth test '%s' failed to get token source", testName)
401-
t.FailNow()
402-
return
403-
}
404-
var at *oauth2.Token
405-
var tkErr error
406-
at, tkErr = oauthToken.Token()
407-
if tkErr != nil {
408-
t.Errorf("oAuth auth test '%s' failed to get token source", testName)
409-
t.FailNow()
410-
}
411-
if at == nil || at.AccessToken == "" {
412-
t.Errorf("oAuth auth test '%s' failed to get token source", testName)
413-
t.FailNow()
414-
return
394+
if !allowFail {
395+
if tErr != nil {
396+
t.Errorf("oAuth auth test '%s' failed to get token source with %v", testName, tErr)
397+
t.FailNow()
398+
return
399+
}
400+
if oauthToken == nil {
401+
t.Errorf("oAuth auth test '%s' failed to get token source", testName)
402+
t.FailNow()
403+
return
404+
}
405+
var at *oauth2.Token
406+
var tkErr error
407+
at, tkErr = oauthToken.Token()
408+
if tkErr != nil {
409+
t.Errorf("oAuth auth test '%s' failed to get token source", testName)
410+
t.FailNow()
411+
}
412+
if at == nil || at.AccessToken == "" {
413+
t.Errorf("oAuth auth test '%s' failed to get token source", testName)
414+
t.FailNow()
415+
return
416+
}
417+
//t.Logf("token %s", at.AccessToken)
418+
t.Logf("oAuth auth test '%s' succeeded", testName)
415419
}
416-
//t.Logf("token %s", at.AccessToken)
417-
t.Logf("oAuth auth test '%s' succeeded", testName)
418-
419420
err := config.Authenticate()
420421
if allowFail {
421422
if err == nil {

0 commit comments

Comments
 (0)