@@ -21,6 +21,8 @@ import (
2121
2222 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
2323 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
24+ "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
25+ "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
2426 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
2527)
2628
@@ -33,7 +35,7 @@ const (
3335 organizationsAuthority = microsoftAuthorityHost + "organizations/"
3436 microsoftAuthority = microsoftAuthorityHost + "72f988bf-86f1-41af-91ab-2d7cd011db47"
3537 //msIDlabTenantAuthority = microsoftAuthorityHost + "msidlab4.onmicrosoft.com" - Will be needed in the future
36-
38+ msiClientId = "4b7a4b0b-ecb2-409e-879a-1e21a15ddaf6"
3739 // Default values
3840 defaultClientId = "f62c5ae3-bf3a-4af5-afa8-a68b800396e9"
3941 pemFile = "../../../cert.pem"
@@ -48,14 +50,12 @@ func httpRequest(ctx context.Context, url string, query url.Values, accessToken
4850 ctx , cancel = context .WithTimeout (ctx , 10 * time .Second )
4951 defer cancel ()
5052 }
51-
5253 req , err := http .NewRequestWithContext (ctx , "GET" , url , nil )
5354 if err != nil {
5455 return nil , fmt .Errorf ("failed to build new http request: %w" , err )
5556 }
5657 req .Header .Set ("Authorization" , "Bearer " + accessToken )
5758 req .URL .RawQuery = query .Encode ()
58-
5959 resp , err := httpClient .Do (req )
6060 if err != nil {
6161 return nil , fmt .Errorf ("http.Get(%s) failed: %w" , req .URL .String (), err )
@@ -142,6 +142,18 @@ func (l *labClient) labAccessToken() (string, error) {
142142 return result .AccessToken , nil
143143}
144144
145+ func (l * labClient ) getLabResponse (url string , query url.Values ) (string , error ) {
146+ accessToken , err := l .labAccessToken ()
147+ if err != nil {
148+ return "" , fmt .Errorf ("problem getting lab access token: %w" , err )
149+ }
150+ responseBody , err := httpRequest (context .Background (), url , query , accessToken )
151+ if err != nil {
152+ return "" , err
153+ }
154+ return string (responseBody ), nil
155+ }
156+
145157func (l * labClient ) user (ctx context.Context , query url.Values ) (user , error ) {
146158 accessToken , err := l .labAccessToken ()
147159 if err != nil {
@@ -482,6 +494,112 @@ func TestAccountFromCache(t *testing.T) {
482494
483495}
484496
497+ type urlModifierTransport struct {
498+ base http.RoundTripper
499+ modifyFunc func (* http.Request )
500+ }
501+
502+ // RoundTrip implements the http.RoundTripper interface
503+ func (t * urlModifierTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
504+ // Modifying the original resquest to have the updated URL
505+ if t .modifyFunc != nil {
506+ t .modifyFunc (req )
507+ }
508+ return t .base .RoundTrip (req )
509+ }
510+
511+ func TestAcquireMSITokenExchangeForESTSToken (t * testing.T ) {
512+ labC , err := newLabClient ()
513+ if err != nil {
514+ t .Fatal (err )
515+ }
516+ baseUrl := "https://service.msidlab.com/"
517+ resource := "api://azureadtokenexchange"
518+ query := url.Values {}
519+ query .Add ("resource" , "WebApp" )
520+ response , err := labC .getLabResponse (baseUrl + "EnvironmentVariables" , query )
521+ if err != nil {
522+ t .Fatalf ("Failed to get resource env variable: %v" , err )
523+ }
524+ var result map [string ]string
525+ err = json .Unmarshal ([]byte (response ), & result )
526+ if err != nil {
527+ t .Fatalf ("Failed to unmarshal response: %v" , err )
528+ }
529+
530+ for key , value := range result {
531+ if key == "IDENTITY_ENDPOINT" {
532+ value = "https://service.msidlab.com/MSIToken?azureresource=WebApp&uri=" + value
533+ }
534+ t .Setenv (key , value )
535+ }
536+ // Replace your existing http.Client with this one
537+ httpClient := http.Client {
538+ Transport : & urlModifierTransport {
539+ base : http .DefaultTransport ,
540+ modifyFunc : func (req * http.Request ) {
541+ req .URL .Host = "service.msidlab.com"
542+ req .URL .Path = "/MSIToken"
543+ req .URL .Scheme = "https"
544+ req .URL .RawQuery = "azureresource=WebApp&uri=http%3A%2F%2F127.0.0.1%3A41488%2Fmsi%2Ftoken%2F%3Fapi-version%3D2019-08-01%26resource%3Dapi%3A%2F%2Fazureadtokenexchange%26client_id%3D" + msiClientId
545+ accessToken , err := labC .labAccessToken ()
546+ if err != nil {
547+ t .Fatal ("Failed to get access token: " , err )
548+ }
549+ req .Header .Set ("Authorization" , "Bearer " + accessToken )
550+ },
551+ },
552+ }
553+ ctx := context .Background ()
554+ msiClient , err := managedidentity .New (managedidentity .UserAssignedClientID (msiClientId ),
555+ managedidentity .WithHTTPClient (& httpClient ))
556+ if err != nil {
557+ t .Fatalf ("Failed to create MSI client: %v" , err )
558+ }
559+ token , err := msiClient .AcquireToken (ctx , resource )
560+ if err != nil {
561+ t .Fatalf ("Failed to acquire token: %v" , err )
562+ }
563+ if token .AccessToken == "" {
564+ t .Fatal ("Expected non-empty access token" )
565+ }
566+ cred := confidential .NewCredFromAssertionCallback (func (ctx context.Context , opt confidential.AssertionRequestOptions ) (string , error ) {
567+ token , err := msiClient .AcquireToken (ctx , resource )
568+ if err != nil {
569+ t .Fatalf ("Failed to acquire token: %v" , err )
570+ }
571+ return token .AccessToken , nil
572+ })
573+ confidentialClient , err := confidential .New (microsoftAuthority ,
574+ defaultClientId ,
575+ cred ,
576+ confidential .WithInstanceDiscovery (false ))
577+ if err != nil {
578+ t .Fatalf ("Failed to create confidential client: %v" , err )
579+ }
580+
581+ authResult , err := confidentialClient .AcquireTokenByCredential (ctx , []string {"https://msidlabs.vault.azure.net/.default" })
582+ if err != nil {
583+ t .Fatalf ("Failed to acquire token by credential: %v" , err )
584+ }
585+ if authResult .AccessToken == "" {
586+ t .Fatal ("Expected non-empty access token" )
587+ }
588+ if authResult .Metadata .TokenSource != base .TokenSourceIdentityProvider {
589+ t .Fatalf ("Expected token source 'IdentityProvider', got '%d'" , authResult .Metadata .TokenSource )
590+ }
591+ authResult , err = confidentialClient .AcquireTokenSilent (ctx , []string {"https://msidlabs.vault.azure.net/.default" })
592+ if err != nil {
593+ t .Fatalf ("Failed to acquire token by credential: %v" , err )
594+ }
595+ if authResult .AccessToken == "" {
596+ t .Fatal ("Expected non-empty access token" )
597+ }
598+ if authResult .Metadata .TokenSource != base .TokenSourceCache {
599+ t .Fatalf ("Expected token source 'Cache', got '%d'" , authResult .Metadata .TokenSource )
600+ }
601+ }
602+
485603func TestAdfsToken (t * testing.T ) {
486604 if testing .Short () {
487605 t .Skip ("skipping integration test" )
@@ -509,5 +627,4 @@ func TestAdfsToken(t *testing.T) {
509627 if result .AccessToken == "" {
510628 t .Fatal ("TestConfidentialClientwithSecret: on AcquireTokenByCredential(): got AccessToken == '', want AccessToken != ''" )
511629 }
512-
513630}
0 commit comments