Skip to content

Commit c5a14c4

Browse files
authored
Refactor instance/region discovery (#763)
* Refactor instance/region discovery logic * Address PR feedback * Address PR feedback * Return earlier if instance discovery is disabled
1 parent a3a039f commit c5a14c4

File tree

2 files changed

+51
-38
lines changed

2 files changed

+51
-38
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -62,40 +62,50 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
6262
boolean validateAuthority,
6363
MsalRequest msalRequest,
6464
ServiceBundle serviceBundle) {
65-
String host = authorityUrl.getHost();
6665

67-
if (shouldUseRegionalEndpoint(msalRequest)) {
68-
//Server side telemetry requires the result from region discovery when any part of the region API is used
69-
String detectedRegion = discoverRegion(msalRequest, serviceBundle);
66+
String host = authorityUrl.getHost();
7067

71-
if (msalRequest.application().azureRegion() != null) {
72-
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
68+
//If instanceDiscovery flag set to false, cache a basic instance metadata entry to skip future lookups
69+
if (!msalRequest.application().instanceDiscovery()) {
70+
if (cache.get(host) == null) {
71+
log.debug("Instance discovery set to false, caching a default entry.");
72+
cacheInstanceDiscoveryMetadata(host);
7373
}
74+
return cache.get(host);
75+
}
7476

75-
//If region autodetection is enabled and a specific region not already set,
76-
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
77-
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
78-
&& null != detectedRegion) {
79-
msalRequest.application().azureRegion = detectedRegion;
80-
}
81-
cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion());
82-
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
83-
determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion()));
77+
//If a region was set by an app developer or previously found through autodetection, adjust the authority host to use it
78+
if (shouldUseRegionalEndpoint(msalRequest) && msalRequest.application().azureRegion() != null) {
79+
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
8480
}
8581

86-
InstanceDiscoveryMetadataEntry result = cache.get(host);
82+
//If there is no cached instance metadata, do instance discovery cache the result
83+
if (cache.get(host) == null) {
84+
log.debug("No cached instance metadata, will attempt instance discovery.");
8785

88-
if (result == null) {
89-
if(msalRequest.application().instanceDiscovery() && !instanceDiscoveryFailed){
90-
doInstanceDiscoveryAndCache(authorityUrl, validateAuthority, msalRequest, serviceBundle);
91-
} else {
92-
// instanceDiscovery flag is set to False. Do not perform instanceDiscovery.
93-
return InstanceDiscoveryMetadataEntry.builder().
94-
preferredCache(host).
95-
preferredNetwork(host).
96-
aliases(Collections.singleton(host)).
97-
build();
86+
if (shouldUseRegionalEndpoint(msalRequest)) {
87+
log.debug("Region API used, will attempt to discover Azure region.");
88+
89+
//Server side telemetry requires the result from region discovery when any part of the region API is used
90+
String detectedRegion = discoverRegion(msalRequest, serviceBundle);
91+
92+
//If region autodetection is enabled and a specific region was not already set, set the application's
93+
// region to the discovered region so that future requests can skip the IMDS endpoint call
94+
if (msalRequest.application().azureRegion() == null
95+
&& msalRequest.application().autoDetectRegion()
96+
&& detectedRegion != null) {
97+
log.debug(String.format("Region autodetection found %s, this region will be used for future calls.", detectedRegion));
98+
99+
msalRequest.application().azureRegion = detectedRegion;
100+
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
101+
}
102+
103+
cacheRegionInstanceMetadata(authorityUrl.getHost(), host);
104+
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
105+
determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion()));
98106
}
107+
108+
doInstanceDiscoveryAndCache(authorityUrl, validateAuthority, msalRequest, serviceBundle);
99109
}
100110

101111
return cache.get(host);
@@ -126,7 +136,7 @@ static AadInstanceDiscoveryResponse parseInstanceDiscoveryMetadata(String instan
126136
return aadInstanceDiscoveryResponse;
127137
}
128138

129-
static void cacheInstanceDiscoveryMetadata(String host,
139+
static void cacheInstanceDiscoveryResponse(String host,
130140
AadInstanceDiscoveryResponse aadInstanceDiscoveryResponse) {
131141

132142
if (aadInstanceDiscoveryResponse != null && aadInstanceDiscoveryResponse.metadata() != null) {
@@ -136,6 +146,11 @@ static void cacheInstanceDiscoveryMetadata(String host,
136146
}
137147
}
138148
}
149+
150+
cacheInstanceDiscoveryMetadata(host);
151+
}
152+
153+
static void cacheInstanceDiscoveryMetadata(String host) {
139154
cache.putIfAbsent(host, InstanceDiscoveryMetadataEntry.builder().
140155
preferredCache(host).
141156
preferredNetwork(host).
@@ -164,14 +179,13 @@ private static boolean shouldUseRegionalEndpoint(MsalRequest msalRequest){
164179
return false;
165180
}
166181

167-
static void cacheRegionInstanceMetadata(String host, String region) {
182+
static void cacheRegionInstanceMetadata(String originalHost, String regionalHost) {
168183

169184
Set<String> aliases = new HashSet<>();
170-
aliases.add(host);
171-
String regionalHost = getRegionalizedHost(host, region);
185+
aliases.add(originalHost);
172186

173187
cache.putIfAbsent(regionalHost, InstanceDiscoveryMetadataEntry.builder().
174-
preferredCache(host).
188+
preferredCache(originalHost).
175189
preferredNetwork(regionalHost).
176190
aliases(aliases).
177191
build());
@@ -229,12 +243,10 @@ static AadInstanceDiscoveryResponse sendInstanceDiscoveryRequest(URL authorityUr
229243
MsalRequest msalRequest,
230244
ServiceBundle serviceBundle) {
231245

232-
IHttpResponse httpResponse = null;
233-
234246
String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpoint(authorityUrl) +
235247
formInstanceDiscoveryParameters(authorityUrl);
236248

237-
httpResponse = executeRequest(instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(), msalRequest, serviceBundle);
249+
IHttpResponse httpResponse = executeRequest(instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(), msalRequest, serviceBundle);
238250

239251
AadInstanceDiscoveryResponse response = JsonHelper.convertJsonToObject(httpResponse.body(), AadInstanceDiscoveryResponse.class);
240252

@@ -244,7 +256,8 @@ static AadInstanceDiscoveryResponse sendInstanceDiscoveryRequest(URL authorityUr
244256
throw MsalServiceExceptionFactory.fromHttpResponse(httpResponse);
245257
}
246258
// instance discovery failed due to reasons other than an invalid authority, do not perform instance discovery again in this environment.
247-
instanceDiscoveryFailed = true;
259+
log.debug("Instance discovery failed due to an unknown error, no more instance discovery attempts will be made.");
260+
cacheInstanceDiscoveryMetadata(authorityUrl.getHost());
248261
}
249262

250263
return response;
@@ -295,7 +308,7 @@ static String discoverRegion(MsalRequest msalRequest, ServiceBundle serviceBundl
295308

296309
//Check if the REGION_NAME environment variable has a value for the region
297310
if (System.getenv(REGION_NAME) != null) {
298-
log.info("Region found in environment variable: " + System.getenv(REGION_NAME));
311+
log.info(String.format("Region found in environment variable: %s",System.getenv(REGION_NAME)));
299312
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_ENV_VARIABLE.telemetryValue);
300313

301314
return System.getenv(REGION_NAME);
@@ -351,7 +364,7 @@ private static void doInstanceDiscoveryAndCache(URL authorityUrl,
351364
}
352365
}
353366

354-
cacheInstanceDiscoveryMetadata(authorityUrl.getHost(), aadInstanceDiscoveryResponse);
367+
cacheInstanceDiscoveryResponse(authorityUrl.getHost(), aadInstanceDiscoveryResponse);
355368
}
356369

357370
private static void validate(AadInstanceDiscoveryResponse aadInstanceDiscoveryResponse) {

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractClientApplicationBase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ public T instanceDiscovery(boolean val) {
708708
instanceDiscovery = builder.instanceDiscovery;
709709

710710
if (aadAadInstanceDiscoveryResponse != null) {
711-
AadInstanceDiscoveryProvider.cacheInstanceDiscoveryMetadata(
711+
AadInstanceDiscoveryProvider.cacheInstanceDiscoveryResponse(
712712
authenticationAuthority.host,
713713
aadAadInstanceDiscoveryResponse);
714714
}

0 commit comments

Comments
 (0)