Skip to content

Commit 5881237

Browse files
authored
Exposing necessary types from each client (#561)
* Exposing necessary types from each client * Updated formatting * Update confidential_test.go * Removed TokenSourceUnknown
1 parent fc56b03 commit 5881237

File tree

8 files changed

+81
-58
lines changed

8 files changed

+81
-58
lines changed

apps/confidential/confidential.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ type AuthenticationScheme = authority.AuthenticationScheme
6565

6666
type Account = shared.Account
6767

68+
type TokenSource = base.TokenSource
69+
70+
const (
71+
TokenSourceIdentityProvider = base.TokenSourceIdentityProvider
72+
TokenSourceCache = base.TokenSourceCache
73+
)
74+
6875
// CertFromPEM converts a PEM file (.pem or .key) for use with [NewCredFromCert]. The file
6976
// must contain the public certificate and the private key. If a PEM block is encrypted and
7077
// password is not an empty string, it attempts to decrypt the PEM blocks using the password.

apps/confidential/confidential_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,7 @@ func TestRefreshIn(t *testing.T) {
10971097
if err != nil {
10981098
t.Fatal(err)
10991099
}
1100-
if ar.Metadata.TokenSource != base.Cache && !tt.shouldGetNewToken {
1100+
if ar.Metadata.TokenSource != TokenSourceCache && !tt.shouldGetNewToken {
11011101
t.Fatal("should have returned from cache.")
11021102
}
11031103
if (ar.AccessToken == secondToken) != tt.shouldGetNewToken {

apps/internal/base/base.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,8 @@ type TokenSource int
103103

104104
// These are all the types of token flows.
105105
const (
106-
SourceUnknown TokenSource = 0
107-
IdentityProvider TokenSource = 1
108-
Cache TokenSource = 2
106+
TokenSourceIdentityProvider TokenSource = 0
107+
TokenSourceCache TokenSource = 1
109108
)
110109

111110
// AuthResultFromStorage creates an AuthResult from a storage token response (which is generated from the cache).
@@ -133,7 +132,7 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu
133132
GrantedScopes: grantedScopes,
134133
DeclinedScopes: nil,
135134
Metadata: AuthResultMetadata{
136-
TokenSource: Cache,
135+
TokenSource: TokenSourceCache,
137136
RefreshOn: storageTokenResponse.AccessToken.RefreshOn.T,
138137
},
139138
}, nil
@@ -151,7 +150,7 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco
151150
ExpiresOn: tokenResponse.ExpiresOn,
152151
GrantedScopes: tokenResponse.GrantedScopes.Slice,
153152
Metadata: AuthResultMetadata{
154-
TokenSource: IdentityProvider,
153+
TokenSource: TokenSourceIdentityProvider,
155154
RefreshOn: tokenResponse.RefreshOn.T,
156155
},
157156
}, nil

apps/internal/base/base_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ func TestCreateAuthenticationResult(t *testing.T) {
344344
GrantedScopes: []string{"user.read"},
345345
DeclinedScopes: nil,
346346
Metadata: AuthResultMetadata{
347-
TokenSource: IdentityProvider,
347+
TokenSource: TokenSourceIdentityProvider,
348348
},
349349
},
350350
},
@@ -419,7 +419,7 @@ func TestAuthResultFromStorage(t *testing.T) {
419419
ExpiresOn: future,
420420
GrantedScopes: []string{"profile", "openid", "user.read"},
421421
Metadata: AuthResultMetadata{
422-
TokenSource: Cache,
422+
TokenSource: TokenSourceCache,
423423
},
424424
},
425425
},

apps/managedidentity/managedidentity.go

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ import (
3232
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
3333
)
3434

35+
// AuthResult contains the results of one token acquisition operation.
36+
// For details see https://aka.ms/msal-net-authenticationresult
37+
type AuthResult = base.AuthResult
38+
39+
type TokenSource = base.TokenSource
40+
41+
const (
42+
TokenSourceIdentityProvider = base.TokenSourceIdentityProvider
43+
TokenSourceCache = base.TokenSourceCache
44+
)
45+
3546
const (
3647
// DefaultToIMDS indicates that the source is defaulted to IMDS when no environment variables are set.
3748
DefaultToIMDS Source = "DefaultToIMDS"
@@ -304,7 +315,7 @@ var now = time.Now
304315
//
305316
// Resource: scopes application is requesting access to
306317
// Options: [WithClaims]
307-
func (c Client) AcquireToken(ctx context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) {
318+
func (c Client) AcquireToken(ctx context.Context, resource string, options ...AcquireTokenOption) (AuthResult, error) {
308319
resource = strings.TrimSuffix(resource, "/.default")
309320
o := AcquireTokenOptions{}
310321
for _, option := range options {
@@ -316,7 +327,7 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
316327
if o.claims == "" {
317328
stResp, err := cacheManager.Read(ctx, c.authParams)
318329
if err != nil {
319-
return base.AuthResult{}, err
330+
return AuthResult{}, err
320331
}
321332
ar, err := base.AuthResultFromStorage(stResp)
322333
if err == nil {
@@ -333,7 +344,7 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
333344
return c.getToken(ctx, resource)
334345
}
335346

336-
func (c Client) getToken(ctx context.Context, resource string) (base.AuthResult, error) {
347+
func (c Client) getToken(ctx context.Context, resource string) (AuthResult, error) {
337348
switch c.source {
338349
case AzureArc:
339350
return c.acquireTokenForAzureArc(ctx, resource)
@@ -348,110 +359,110 @@ func (c Client) getToken(ctx context.Context, resource string) (base.AuthResult,
348359
case ServiceFabric:
349360
return c.acquireTokenForServiceFabric(ctx, resource)
350361
default:
351-
return base.AuthResult{}, fmt.Errorf("unsupported source %q", c.source)
362+
return AuthResult{}, fmt.Errorf("unsupported source %q", c.source)
352363
}
353364
}
354365

355-
func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (base.AuthResult, error) {
366+
func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (AuthResult, error) {
356367
req, err := createAppServiceAuthRequest(ctx, c.miType, resource)
357368
if err != nil {
358-
return base.AuthResult{}, err
369+
return AuthResult{}, err
359370
}
360371
tokenResponse, err := c.getTokenForRequest(req, resource)
361372
if err != nil {
362-
return base.AuthResult{}, err
373+
return AuthResult{}, err
363374
}
364375
return authResultFromToken(c.authParams, tokenResponse)
365376
}
366377

367-
func (c Client) acquireTokenForIMDS(ctx context.Context, resource string) (base.AuthResult, error) {
378+
func (c Client) acquireTokenForIMDS(ctx context.Context, resource string) (AuthResult, error) {
368379
req, err := createIMDSAuthRequest(ctx, c.miType, resource)
369380
if err != nil {
370-
return base.AuthResult{}, err
381+
return AuthResult{}, err
371382
}
372383
tokenResponse, err := c.getTokenForRequest(req, resource)
373384
if err != nil {
374-
return base.AuthResult{}, err
385+
return AuthResult{}, err
375386
}
376387
return authResultFromToken(c.authParams, tokenResponse)
377388
}
378389

379-
func (c Client) acquireTokenForCloudShell(ctx context.Context, resource string) (base.AuthResult, error) {
390+
func (c Client) acquireTokenForCloudShell(ctx context.Context, resource string) (AuthResult, error) {
380391
req, err := createCloudShellAuthRequest(ctx, resource)
381392
if err != nil {
382-
return base.AuthResult{}, err
393+
return AuthResult{}, err
383394
}
384395
tokenResponse, err := c.getTokenForRequest(req, resource)
385396
if err != nil {
386-
return base.AuthResult{}, err
397+
return AuthResult{}, err
387398
}
388399
return authResultFromToken(c.authParams, tokenResponse)
389400
}
390401

391-
func (c Client) acquireTokenForAzureML(ctx context.Context, resource string) (base.AuthResult, error) {
402+
func (c Client) acquireTokenForAzureML(ctx context.Context, resource string) (AuthResult, error) {
392403
req, err := createAzureMLAuthRequest(ctx, c.miType, resource)
393404
if err != nil {
394-
return base.AuthResult{}, err
405+
return AuthResult{}, err
395406
}
396407
tokenResponse, err := c.getTokenForRequest(req, resource)
397408
if err != nil {
398-
return base.AuthResult{}, err
409+
return AuthResult{}, err
399410
}
400411
return authResultFromToken(c.authParams, tokenResponse)
401412
}
402413

403-
func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string) (base.AuthResult, error) {
414+
func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string) (AuthResult, error) {
404415
req, err := createServiceFabricAuthRequest(ctx, resource)
405416
if err != nil {
406-
return base.AuthResult{}, err
417+
return AuthResult{}, err
407418
}
408419
tokenResponse, err := c.getTokenForRequest(req, resource)
409420
if err != nil {
410-
return base.AuthResult{}, err
421+
return AuthResult{}, err
411422
}
412423
return authResultFromToken(c.authParams, tokenResponse)
413424
}
414425

415-
func (c Client) acquireTokenForAzureArc(ctx context.Context, resource string) (base.AuthResult, error) {
426+
func (c Client) acquireTokenForAzureArc(ctx context.Context, resource string) (AuthResult, error) {
416427
req, err := createAzureArcAuthRequest(ctx, resource, "")
417428
if err != nil {
418-
return base.AuthResult{}, err
429+
return AuthResult{}, err
419430
}
420431

421432
response, err := c.httpClient.Do(req)
422433
if err != nil {
423-
return base.AuthResult{}, err
434+
return AuthResult{}, err
424435
}
425436
defer response.Body.Close()
426437

427438
if response.StatusCode != http.StatusUnauthorized {
428-
return base.AuthResult{}, fmt.Errorf("expected a 401 response, received %d", response.StatusCode)
439+
return AuthResult{}, fmt.Errorf("expected a 401 response, received %d", response.StatusCode)
429440
}
430441

431442
secret, err := c.getAzureArcSecretKey(response, runtime.GOOS)
432443
if err != nil {
433-
return base.AuthResult{}, err
444+
return AuthResult{}, err
434445
}
435446

436447
secondRequest, err := createAzureArcAuthRequest(ctx, resource, string(secret))
437448
if err != nil {
438-
return base.AuthResult{}, err
449+
return AuthResult{}, err
439450
}
440451

441452
tokenResponse, err := c.getTokenForRequest(secondRequest, resource)
442453
if err != nil {
443-
return base.AuthResult{}, err
454+
return AuthResult{}, err
444455
}
445456
return authResultFromToken(c.authParams, tokenResponse)
446457
}
447458

448-
func authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (base.AuthResult, error) {
459+
func authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (AuthResult, error) {
449460
if cacheManager == nil {
450-
return base.AuthResult{}, errors.New("cache instance is nil")
461+
return AuthResult{}, errors.New("cache instance is nil")
451462
}
452463
account, err := cacheManager.Write(authParams, token)
453464
if err != nil {
454-
return base.AuthResult{}, err
465+
return AuthResult{}, err
455466
}
456467
// if refreshOn is not set, set it to half of the time until expiry if expiry is more than 2 hours away
457468
if token.RefreshOn.T.IsZero() {
@@ -461,7 +472,7 @@ func authResultFromToken(authParams authority.AuthParams, token accesstokens.Tok
461472
}
462473
ar, err := base.NewAuthResult(token, account)
463474
if err != nil {
464-
return base.AuthResult{}, err
475+
return AuthResult{}, err
465476
}
466477
ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
467478
return ar, err

0 commit comments

Comments
 (0)