Skip to content

Commit f7579f7

Browse files
authored
Fix panic in AcquireTokenSilent for public clients due to nil Credential (#581)
* Added a guard for public client refreshing through credential * refreshing token when refresh_on is present and added test * Added test for authcode for confidential client * updated code based on comment * Update apps/internal/base/base.go
1 parent a35dff7 commit f7579f7

File tree

3 files changed

+140
-2
lines changed

3 files changed

+140
-2
lines changed

apps/confidential/confidential_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,77 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
481481
}
482482
}
483483

484+
func TestAcquireTokenByAuthCodeTokenExpiry(t *testing.T) {
485+
accessToken := "initial-access-token"
486+
newAccessToken := "new-access-token"
487+
homeTenant := "home-tenant"
488+
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(
489+
fmt.Sprintf(`{"uid":"uid","utid":"%s"}`, homeTenant),
490+
))
491+
lmo := "login.microsoftonline.com"
492+
493+
originalTime := base.Now
494+
defer func() {
495+
base.Now = originalTime
496+
}()
497+
498+
cred, err := NewCredFromSecret(fakeSecret)
499+
if err != nil {
500+
t.Fatal(err)
501+
}
502+
503+
mockClient := mock.NewClient()
504+
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "common")))
505+
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000)))
506+
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(newAccessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000)))
507+
508+
client, err := New(fmt.Sprintf(authorityFmt, lmo, "common"), fakeClientID, cred, WithHTTPClient(mockClient), WithInstanceDiscovery(false))
509+
if err != nil {
510+
t.Fatal(err)
511+
}
512+
513+
// Acquire token using auth code
514+
ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope)
515+
if err != nil {
516+
t.Fatal(err)
517+
}
518+
if ar.AccessToken != accessToken {
519+
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
520+
}
521+
522+
account := ar.Account
523+
524+
// First silent call should return cached token
525+
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
526+
if err != nil {
527+
t.Fatal(err)
528+
}
529+
if ar.AccessToken != accessToken {
530+
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
531+
}
532+
if ar.Metadata.TokenSource != base.TokenSourceCache {
533+
t.Fatalf("expected token source %v, got %v", base.TokenSourceCache, ar.Metadata.TokenSource)
534+
}
535+
536+
// Move time forward past RefreshOn (1001 seconds)
537+
fixedTime := time.Now().Add(time.Duration(1001) * time.Second)
538+
base.Now = func() time.Time {
539+
return fixedTime
540+
}
541+
542+
// Second silent call should automatically refresh and return new token
543+
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
544+
if err != nil {
545+
t.Fatal(err)
546+
}
547+
if ar.AccessToken != newAccessToken {
548+
t.Fatalf("expected %q, got %q", newAccessToken, ar.AccessToken)
549+
}
550+
// Verify the token came from the identity provider (refresh), not cache
551+
if ar.Metadata.TokenSource != base.TokenSourceIdentityProvider {
552+
t.Fatalf("expected token source %v, got %v", base.TokenSourceIdentityProvider, ar.Metadata.TokenSource)
553+
}
554+
}
484555
func TestInvalidJsonErrFromResponse(t *testing.T) {
485556
cred, err := NewCredFromSecret(fakeSecret)
486557
if err != nil {

apps/internal/base/base.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,19 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
367367
// If the token is not same, we don't need to refresh it.
368368
// Which means it refreshed.
369369
if str, err := m.Read(ctx, authParams); err == nil && str.AccessToken.Secret == ar.AccessToken {
370-
if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil {
371-
return b.AuthResultFromToken(ctx, authParams, tr)
370+
switch silent.RequestType {
371+
case accesstokens.ATConfidential:
372+
if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil {
373+
return b.AuthResultFromToken(ctx, authParams, tr)
374+
}
375+
case accesstokens.ATPublic:
376+
token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, silent.Credential, storageTokenResponse.RefreshToken)
377+
if err != nil {
378+
return ar, err
379+
}
380+
return b.AuthResultFromToken(ctx, authParams, token)
381+
case accesstokens.ATUnknown:
382+
return ar, errors.New("silent request type cannot be ATUnknown")
372383
}
373384
}
374385
}

apps/public/public_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"time"
1717

1818
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
19+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
1920
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
2021
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
2122
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
@@ -1046,3 +1047,58 @@ func getNewClientWithMockedResponses(
10461047

10471048
return client, nil
10481049
}
1050+
1051+
func TestAcquireTokenSilentWithRefreshOnIsExpired(t *testing.T) {
1052+
accessToken := "*"
1053+
homeTenant := "home-tenant"
1054+
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(
1055+
fmt.Sprintf(`{"uid":"uid","utid":"%s"}`, homeTenant),
1056+
))
1057+
lmo := "login.microsoftonline.com"
1058+
originalTime := base.Now
1059+
defer func() {
1060+
base.Now = originalTime
1061+
}()
1062+
mockClient := mock.NewClient()
1063+
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "common")))
1064+
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000)))
1065+
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody("new-"+accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000)))
1066+
1067+
client, err := New("common",
1068+
WithAuthority(fmt.Sprintf(authorityFmt, lmo, "common")),
1069+
WithHTTPClient(mockClient),
1070+
WithInstanceDiscovery(false))
1071+
if err != nil {
1072+
t.Fatal(err)
1073+
}
1074+
// the auth flow isn't important, we just need to populate the cache
1075+
ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope)
1076+
if err != nil {
1077+
t.Fatal(err)
1078+
}
1079+
if ar.AccessToken != accessToken {
1080+
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
1081+
}
1082+
account := ar.Account
1083+
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
1084+
if err != nil {
1085+
t.Fatal(err)
1086+
}
1087+
if ar.AccessToken != accessToken {
1088+
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
1089+
}
1090+
// moving time forward to expire the current token
1091+
fixedTime := time.Now().Add(time.Duration(36001) * time.Second)
1092+
base.Now = func() time.Time {
1093+
return fixedTime
1094+
}
1095+
// calling the acquire token again
1096+
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
1097+
if err != nil {
1098+
t.Fatal(err)
1099+
}
1100+
if ar.AccessToken != "new-"+accessToken {
1101+
t.Fatalf("expected %q, got %q", "new-"+accessToken, ar.AccessToken)
1102+
}
1103+
1104+
}

0 commit comments

Comments
 (0)