Skip to content

Commit a9d3ad3

Browse files
authored
Region discovery support (#343)
* Add Azure regional support * Refactor * Add logs for success/failure to find regional info * Extra log
1 parent db43c66 commit a9d3ad3

File tree

3 files changed

+167
-8
lines changed

3 files changed

+167
-8
lines changed

src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java

Lines changed: 94 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,36 @@
33

44
package com.microsoft.aad.msal4j;
55

6+
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
7+
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
10+
611
import java.net.URL;
712
import java.util.Arrays;
813
import java.util.Collections;
914
import java.util.Set;
1015
import java.util.TreeSet;
16+
import java.util.Map;
17+
import java.util.HashMap;
1118
import java.util.concurrent.ConcurrentHashMap;
1219

1320
class AadInstanceDiscoveryProvider {
1421

1522
private final static String DEFAULT_TRUSTED_HOST = "login.microsoftonline.com";
1623
private final static String AUTHORIZE_ENDPOINT_TEMPLATE = "https://{host}/{tenant}/oauth2/v2.0/authorize";
1724
private final static String INSTANCE_DISCOVERY_ENDPOINT_TEMPLATE = "https://{host}/common/discovery/instance";
25+
private final static String INSTANCE_DISCOVERY_ENDPOINT_TEMPLATE_WITH_REGION = "https://{region}.{host}/common/discovery/instance";
1826
private final static String INSTANCE_DISCOVERY_REQUEST_PARAMETERS_TEMPLATE = "?api-version=1.1&authorization_endpoint={authorizeEndpoint}";
27+
private final static String REGION_NAME = "REGION_NAME";
28+
// For information of the current api-version refer: https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service#versioning
29+
private final static String DEFAULT_API_VERSION = "2020-06-01";
30+
private final static String IMDS_ENDPOINT = "https://169.254.169.254/metadata/instance/compute/location?" + DEFAULT_API_VERSION + "&format=text";
1931

2032
final static TreeSet<String> TRUSTED_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
2133

34+
private final static Logger log = LoggerFactory.getLogger(HttpHelper.class);
35+
2236
static ConcurrentHashMap<String, InstanceDiscoveryMetadataEntry> cache = new ConcurrentHashMap<>();
2337

2438
static {
@@ -102,26 +116,98 @@ private static String getInstanceDiscoveryEndpoint(String host) {
102116
replace("{host}", discoveryHost);
103117
}
104118

119+
private static String getInstanceDiscoveryEndpointWithRegion(String host, String region) {
120+
121+
String discoveryHost = TRUSTED_HOSTS_SET.contains(host) ? host : DEFAULT_TRUSTED_HOST;
122+
123+
return INSTANCE_DISCOVERY_ENDPOINT_TEMPLATE_WITH_REGION.
124+
replace("{region}", region).
125+
replace("{host}", discoveryHost);
126+
}
127+
105128
private static AadInstanceDiscoveryResponse sendInstanceDiscoveryRequest(URL authorityUrl,
106129
MsalRequest msalRequest,
107130
ServiceBundle serviceBundle) {
108131

109-
String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpoint(authorityUrl.getAuthority()) +
110-
INSTANCE_DISCOVERY_REQUEST_PARAMETERS_TEMPLATE.replace("{authorizeEndpoint}",
111-
getAuthorizeEndpoint(authorityUrl.getAuthority(),
112-
Authority.getTenant(authorityUrl, Authority.detectAuthorityType(authorityUrl))));
132+
String region = StringHelper.EMPTY_STRING;
133+
IHttpResponse httpResponse = null;
134+
135+
//If the autoDetectRegion parameter in the request is set, attempt to discover the region
136+
if (msalRequest.application().autoDetectRegion()) {
137+
region = discoverRegion(msalRequest, serviceBundle);
138+
}
139+
140+
//If the region is known, attempt to make instance discovery request with region endpoint
141+
if (!region.isEmpty()) {
142+
String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpointWithRegion(authorityUrl.getAuthority(), region) +
143+
formInstanceDiscoveryParameters(authorityUrl);
144+
145+
httpResponse = httpRequest(instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(), msalRequest, serviceBundle);
146+
}
147+
148+
//If the region is unknown or the instance discovery failed at the region endpoint, try the global endpoint
149+
if (region.isEmpty() || httpResponse == null || httpResponse.statusCode() != HTTPResponse.SC_OK) {
150+
if (!region.isEmpty()) {
151+
log.warn("Could not retrieve regional instance discovery metadata, falling back to global endpoint");
152+
}
153+
154+
String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpoint(authorityUrl.getAuthority()) +
155+
formInstanceDiscoveryParameters(authorityUrl);
156+
157+
httpResponse = httpRequest(instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(), msalRequest, serviceBundle);
158+
}
113159

160+
return JsonHelper.convertJsonToObject(httpResponse.body(), AadInstanceDiscoveryResponse.class);
161+
}
162+
163+
private static String formInstanceDiscoveryParameters(URL authorityUrl) {
164+
return INSTANCE_DISCOVERY_REQUEST_PARAMETERS_TEMPLATE.replace("{authorizeEndpoint}",
165+
getAuthorizeEndpoint(authorityUrl.getAuthority(),
166+
Authority.getTenant(authorityUrl, Authority.detectAuthorityType(authorityUrl))));
167+
}
168+
169+
private static IHttpResponse httpRequest(String requestUrl, Map<String, String> headers, MsalRequest msalRequest, ServiceBundle serviceBundle) {
114170
HttpRequest httpRequest = new HttpRequest(
115171
HttpMethod.GET,
116-
instanceDiscoveryRequestUrl,
117-
msalRequest.headers().getReadonlyHeaderMap());
172+
requestUrl,
173+
headers);
118174

119-
IHttpResponse httpResponse= HttpHelper.executeHttpRequest(
175+
return HttpHelper.executeHttpRequest(
120176
httpRequest,
121177
msalRequest.requestContext(),
122178
serviceBundle);
179+
}
123180

124-
return JsonHelper.convertJsonToObject(httpResponse.body(), AadInstanceDiscoveryResponse.class);
181+
private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serviceBundle) {
182+
183+
//Check if the REGION_NAME environment variable has a value for the region
184+
if (System.getenv(REGION_NAME) != null) {
185+
log.info("Region found in environment variable: " + System.getenv(REGION_NAME));
186+
187+
return System.getenv(REGION_NAME);
188+
}
189+
190+
try {
191+
//Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
192+
Map<String, String> headers = new HashMap<>();
193+
headers.put("Metadata", "true");
194+
IHttpResponse httpResponse = httpRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle);
195+
196+
//If call to IMDS endpoint was successful, return region from response body
197+
if (httpResponse.statusCode() == HTTPResponse.SC_OK && !httpResponse.body().isEmpty()) {
198+
log.info("Region retrieved from IMDS endpoint: " + httpResponse.body());
199+
200+
return httpResponse.body();
201+
}
202+
203+
log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode()));
204+
205+
return StringHelper.EMPTY_STRING;
206+
} catch (Exception e) {
207+
//IMDS call failed, cannot find region
208+
log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage()));
209+
return StringHelper.EMPTY_STRING;
210+
}
125211
}
126212

127213
private static void doInstanceDiscoveryAndCache(URL authorityUrl,

src/main/java/com/microsoft/aad/msal4j/AbstractClientApplicationBase.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ abstract class AbstractClientApplicationBase implements IClientApplicationBase {
9696
@Getter
9797
private String clientCapabilities;
9898

99+
@Accessors(fluent = true)
100+
@Getter
101+
private boolean autoDetectRegion;
102+
99103
@Override
100104
public CompletableFuture<IAuthenticationResult> acquireToken(AuthorizationCodeParameters parameters) {
101105

@@ -292,6 +296,7 @@ abstract static class Builder<T extends Builder<T>> {
292296
private ITokenCacheAccessAspect tokenCacheAccessAspect;
293297
private AadInstanceDiscoveryResponse aadInstanceDiscoveryResponse;
294298
private String clientCapabilities;
299+
private boolean autoDetectRegion;
295300
private Integer connectTimeoutForDefaultHttpClient;
296301
private Integer readTimeoutForDefaultHttpClient;
297302

@@ -573,6 +578,22 @@ public T clientCapabilities(Set<String> capabilities) {
573578
return self();
574579
}
575580

581+
/**
582+
* Indicates that the library should attempt to discover the Azure region the application is running in when
583+
* fetching the instance discovery metadata.
584+
*
585+
* If the region is found, token requests will be sent to the regional ESTS endpoint rather than the global endpoint.
586+
* If region information could not be found, the library will fall back to using the global endpoint, which is also
587+
* the default behavior if this value is not set.
588+
*
589+
* @param val boolean (default is false)
590+
* @return instance of the Builder on which method was called
591+
*/
592+
public T autoDetectRegion(boolean val) {
593+
autoDetectRegion = val;
594+
return self();
595+
}
596+
576597
abstract AbstractClientApplicationBase build();
577598
}
578599

@@ -599,6 +620,7 @@ public T clientCapabilities(Set<String> capabilities) {
599620
tokenCache = new TokenCache(builder.tokenCacheAccessAspect);
600621
aadAadInstanceDiscoveryResponse = builder.aadInstanceDiscoveryResponse;
601622
clientCapabilities = builder.clientCapabilities;
623+
autoDetectRegion = builder.autoDetectRegion;
602624

603625
if(aadAadInstanceDiscoveryResponse != null){
604626
AadInstanceDiscoveryProvider.cacheInstanceDiscoveryMetadata(

src/test/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryTest.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,55 @@ public void aadInstanceDiscoveryTest_responseSetByDeveloper_invalidJson() throws
135135
.aadInstanceDiscoveryResponse(instanceDiscoveryResponse)
136136
.build();
137137
}
138+
139+
@Test()
140+
public void aadInstanceDiscoveryTest_AutoDetectRegion_NoRegionDetected() throws Exception{
141+
142+
String instanceDiscoveryResponse = TestHelper.readResource(
143+
this.getClass(),
144+
"/instance_discovery_data/aad_instance_discovery_response_valid.json");
145+
146+
PublicClientApplication app = PublicClientApplication.builder("client_id")
147+
.aadInstanceDiscoveryResponse(instanceDiscoveryResponse)
148+
.autoDetectRegion(true)
149+
.build();
150+
151+
AuthorizationCodeParameters parameters = AuthorizationCodeParameters.builder(
152+
"code", new URI("http://my.redirect.com")).build();
153+
154+
MsalRequest msalRequest = new AuthorizationCodeRequest(
155+
parameters,
156+
app,
157+
new RequestContext(app, PublicApi.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE, parameters));
158+
159+
URL authority = new URL(app.authority());
160+
161+
PowerMock.mockStaticPartial(
162+
AadInstanceDiscoveryProvider.class,
163+
"discoverRegion");
164+
165+
PowerMock.expectPrivate(
166+
AadInstanceDiscoveryProvider.class,
167+
"discoverRegion",
168+
msalRequest,
169+
app.getServiceBundle()).andThrow(new AssertionError()).anyTimes();
170+
171+
PowerMock.replay(AadInstanceDiscoveryProvider.class);
172+
173+
InstanceDiscoveryMetadataEntry entry = AadInstanceDiscoveryProvider.getMetadataEntry(
174+
authority,
175+
false,
176+
msalRequest,
177+
app.getServiceBundle());
178+
179+
//Region detection will have been performed in the expected discoverRegion method, but these tests (likely) aren't
180+
// being run in an Azure VM and nstance discovery will fall back to the global endpoint (login.microsoftonline.com)
181+
Assert.assertEquals(entry.preferredNetwork(), "login.microsoftonline.com");
182+
Assert.assertEquals(entry.preferredCache(), "login.windows.net");
183+
Assert.assertEquals(entry.aliases().size(), 4);
184+
Assert.assertTrue(entry.aliases().contains("login.microsoftonline.com"));
185+
Assert.assertTrue(entry.aliases().contains("login.windows.net"));
186+
Assert.assertTrue(entry.aliases().contains("login.microsoft.com"));
187+
Assert.assertTrue(entry.aliases().contains("sts.windows.net"));
188+
}
138189
}

0 commit comments

Comments
 (0)