Skip to content

Commit 462d177

Browse files
committed
Added Region auto enable
Added Region auto enable in confidential client and its test
1 parent bf74752 commit 462d177

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

apps/confidential/confidential.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import (
1818
"encoding/pem"
1919
"errors"
2020
"fmt"
21+
"os"
22+
"strings"
2123

2224
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
2325
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
@@ -315,16 +317,21 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client
315317
if err != nil {
316318
return Client{}, err
317319
}
318-
320+
region := os.Getenv("MSAL_FORCE_REGION")
319321
opts := clientOptions{
320322
authority: authority,
321323
// if the caller specified a token provider, it will handle all details of authentication, using Client only as a token cache
322324
disableInstanceDiscovery: cred.tokenProvider != nil,
323325
httpClient: shared.DefaultClient,
326+
azureRegion: region,
324327
}
325328
for _, o := range options {
326329
o(&opts)
327330
}
331+
if strings.EqualFold(opts.azureRegion, "DisableMsalForceRegion") {
332+
opts.azureRegion = ""
333+
}
334+
328335
baseOpts := []base.Option{
329336
base.WithCacheAccessor(opts.accessor),
330337
base.WithClientCapabilities(opts.capabilities),

apps/confidential/confidential_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,76 @@ func TestAcquireTokenByCredential(t *testing.T) {
164164
}
165165
}
166166

167+
func TestRegionAutoEnable(t *testing.T) {
168+
cred, err := NewCredFromSecret(fakeSecret)
169+
if err != nil {
170+
t.Fatal(err)
171+
}
172+
tests := []struct {
173+
region string
174+
envRegion string
175+
}{
176+
{
177+
region: "",
178+
envRegion: "envRegion",
179+
},
180+
{
181+
region: "region",
182+
envRegion: "envRegion",
183+
},
184+
{
185+
region: "DisableMsalForceRegion",
186+
envRegion: "envRegion",
187+
},
188+
}
189+
190+
for _, test := range tests {
191+
lmo := "login.microsoftonline.com"
192+
tenant := "tenant"
193+
mockClient := mock.Client{}
194+
if test.envRegion != "" {
195+
err := os.Setenv("MSAL_FORCE_REGION", test.envRegion)
196+
if err != nil {
197+
t.Fatal(err)
198+
}
199+
}
200+
var client Client
201+
if test.region != "" {
202+
client, err = New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(test.region))
203+
if err != nil {
204+
t.Fatal(err)
205+
}
206+
} else {
207+
client, err = New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
208+
if err != nil {
209+
t.Fatal(err)
210+
}
211+
}
212+
213+
t.Cleanup(func() {
214+
os.Unsetenv("MSAL_FORCE_REGION")
215+
})
216+
if test.region == "" {
217+
if test.envRegion != "" {
218+
if client.base.AuthParams.AuthorityInfo.Region != test.envRegion {
219+
t.Fatalf("wanted %q, got %q", test.envRegion, client.base.AuthParams.AuthorityInfo.Region)
220+
}
221+
}
222+
} else {
223+
if test.region == "DisableMsalForceRegion" {
224+
if client.base.AuthParams.AuthorityInfo.Region != "" {
225+
t.Fatalf("wanted empty, got %q", client.base.AuthParams.AuthorityInfo.Region)
226+
}
227+
} else {
228+
229+
if client.base.AuthParams.AuthorityInfo.Region != test.region {
230+
t.Fatalf("wanted %q, got %q", test.region, client.base.AuthParams.AuthorityInfo.Region)
231+
}
232+
}
233+
}
234+
}
235+
}
236+
167237
func TestAcquireTokenOnBehalfOf(t *testing.T) {
168238
// this test is an offline version of TestOnBehalfOf in integration_test.go
169239
cred, err := NewCredFromSecret(fakeSecret)

0 commit comments

Comments
 (0)