|
3 | 3 |
|
4 | 4 | package com.microsoft.aad.msal4j; |
5 | 5 |
|
| 6 | +import com.nimbusds.oauth2.sdk.http.HTTPResponse; |
| 7 | + |
| 8 | +import org.slf4j.Logger; |
| 9 | +import org.slf4j.LoggerFactory; |
| 10 | + |
6 | 11 | import java.net.URL; |
7 | 12 | import java.util.Arrays; |
8 | 13 | import java.util.Collections; |
9 | 14 | import java.util.Set; |
10 | 15 | import java.util.TreeSet; |
| 16 | +import java.util.Map; |
| 17 | +import java.util.HashMap; |
11 | 18 | import java.util.concurrent.ConcurrentHashMap; |
12 | 19 |
|
13 | 20 | class AadInstanceDiscoveryProvider { |
14 | 21 |
|
15 | 22 | private final static String DEFAULT_TRUSTED_HOST = "login.microsoftonline.com"; |
16 | 23 | private final static String AUTHORIZE_ENDPOINT_TEMPLATE = "https://{host}/{tenant}/oauth2/v2.0/authorize"; |
17 | 24 | private final static String INSTANCE_DISCOVERY_ENDPOINT_TEMPLATE = "https://{host}/common/discovery/instance"; |
| 25 | + private final static String INSTANCE_DISCOVERY_ENDPOINT_TEMPLATE_WITH_REGION = "https://{region}.{host}/common/discovery/instance"; |
18 | 26 | private final static String INSTANCE_DISCOVERY_REQUEST_PARAMETERS_TEMPLATE = "?api-version=1.1&authorization_endpoint={authorizeEndpoint}"; |
| 27 | + private final static String REGION_NAME = "REGION_NAME"; |
| 28 | + // For information of the current api-version refer: https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service#versioning |
| 29 | + private final static String DEFAULT_API_VERSION = "2020-06-01"; |
| 30 | + private final static String IMDS_ENDPOINT = "https://169.254.169.254/metadata/instance/compute/location?" + DEFAULT_API_VERSION + "&format=text"; |
19 | 31 |
|
20 | 32 | final static TreeSet<String> TRUSTED_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER); |
21 | 33 |
|
| 34 | + private final static Logger log = LoggerFactory.getLogger(HttpHelper.class); |
| 35 | + |
22 | 36 | static ConcurrentHashMap<String, InstanceDiscoveryMetadataEntry> cache = new ConcurrentHashMap<>(); |
23 | 37 |
|
24 | 38 | static { |
@@ -102,26 +116,98 @@ private static String getInstanceDiscoveryEndpoint(String host) { |
102 | 116 | replace("{host}", discoveryHost); |
103 | 117 | } |
104 | 118 |
|
| 119 | + private static String getInstanceDiscoveryEndpointWithRegion(String host, String region) { |
| 120 | + |
| 121 | + String discoveryHost = TRUSTED_HOSTS_SET.contains(host) ? host : DEFAULT_TRUSTED_HOST; |
| 122 | + |
| 123 | + return INSTANCE_DISCOVERY_ENDPOINT_TEMPLATE_WITH_REGION. |
| 124 | + replace("{region}", region). |
| 125 | + replace("{host}", discoveryHost); |
| 126 | + } |
| 127 | + |
105 | 128 | private static AadInstanceDiscoveryResponse sendInstanceDiscoveryRequest(URL authorityUrl, |
106 | 129 | MsalRequest msalRequest, |
107 | 130 | ServiceBundle serviceBundle) { |
108 | 131 |
|
109 | | - String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpoint(authorityUrl.getAuthority()) + |
110 | | - INSTANCE_DISCOVERY_REQUEST_PARAMETERS_TEMPLATE.replace("{authorizeEndpoint}", |
111 | | - getAuthorizeEndpoint(authorityUrl.getAuthority(), |
112 | | - Authority.getTenant(authorityUrl, Authority.detectAuthorityType(authorityUrl)))); |
| 132 | + String region = StringHelper.EMPTY_STRING; |
| 133 | + IHttpResponse httpResponse = null; |
| 134 | + |
| 135 | + //If the autoDetectRegion parameter in the request is set, attempt to discover the region |
| 136 | + if (msalRequest.application().autoDetectRegion()) { |
| 137 | + region = discoverRegion(msalRequest, serviceBundle); |
| 138 | + } |
| 139 | + |
| 140 | + //If the region is known, attempt to make instance discovery request with region endpoint |
| 141 | + if (!region.isEmpty()) { |
| 142 | + String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpointWithRegion(authorityUrl.getAuthority(), region) + |
| 143 | + formInstanceDiscoveryParameters(authorityUrl); |
| 144 | + |
| 145 | + httpResponse = httpRequest(instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(), msalRequest, serviceBundle); |
| 146 | + } |
| 147 | + |
| 148 | + //If the region is unknown or the instance discovery failed at the region endpoint, try the global endpoint |
| 149 | + if (region.isEmpty() || httpResponse == null || httpResponse.statusCode() != HTTPResponse.SC_OK) { |
| 150 | + if (!region.isEmpty()) { |
| 151 | + log.warn("Could not retrieve regional instance discovery metadata, falling back to global endpoint"); |
| 152 | + } |
| 153 | + |
| 154 | + String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpoint(authorityUrl.getAuthority()) + |
| 155 | + formInstanceDiscoveryParameters(authorityUrl); |
| 156 | + |
| 157 | + httpResponse = httpRequest(instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(), msalRequest, serviceBundle); |
| 158 | + } |
113 | 159 |
|
| 160 | + return JsonHelper.convertJsonToObject(httpResponse.body(), AadInstanceDiscoveryResponse.class); |
| 161 | + } |
| 162 | + |
| 163 | + private static String formInstanceDiscoveryParameters(URL authorityUrl) { |
| 164 | + return INSTANCE_DISCOVERY_REQUEST_PARAMETERS_TEMPLATE.replace("{authorizeEndpoint}", |
| 165 | + getAuthorizeEndpoint(authorityUrl.getAuthority(), |
| 166 | + Authority.getTenant(authorityUrl, Authority.detectAuthorityType(authorityUrl)))); |
| 167 | + } |
| 168 | + |
| 169 | + private static IHttpResponse httpRequest(String requestUrl, Map<String, String> headers, MsalRequest msalRequest, ServiceBundle serviceBundle) { |
114 | 170 | HttpRequest httpRequest = new HttpRequest( |
115 | 171 | HttpMethod.GET, |
116 | | - instanceDiscoveryRequestUrl, |
117 | | - msalRequest.headers().getReadonlyHeaderMap()); |
| 172 | + requestUrl, |
| 173 | + headers); |
118 | 174 |
|
119 | | - IHttpResponse httpResponse= HttpHelper.executeHttpRequest( |
| 175 | + return HttpHelper.executeHttpRequest( |
120 | 176 | httpRequest, |
121 | 177 | msalRequest.requestContext(), |
122 | 178 | serviceBundle); |
| 179 | + } |
123 | 180 |
|
124 | | - return JsonHelper.convertJsonToObject(httpResponse.body(), AadInstanceDiscoveryResponse.class); |
| 181 | + private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serviceBundle) { |
| 182 | + |
| 183 | + //Check if the REGION_NAME environment variable has a value for the region |
| 184 | + if (System.getenv(REGION_NAME) != null) { |
| 185 | + log.info("Region found in environment variable: " + System.getenv(REGION_NAME)); |
| 186 | + |
| 187 | + return System.getenv(REGION_NAME); |
| 188 | + } |
| 189 | + |
| 190 | + try { |
| 191 | + //Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM) |
| 192 | + Map<String, String> headers = new HashMap<>(); |
| 193 | + headers.put("Metadata", "true"); |
| 194 | + IHttpResponse httpResponse = httpRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle); |
| 195 | + |
| 196 | + //If call to IMDS endpoint was successful, return region from response body |
| 197 | + if (httpResponse.statusCode() == HTTPResponse.SC_OK && !httpResponse.body().isEmpty()) { |
| 198 | + log.info("Region retrieved from IMDS endpoint: " + httpResponse.body()); |
| 199 | + |
| 200 | + return httpResponse.body(); |
| 201 | + } |
| 202 | + |
| 203 | + log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode())); |
| 204 | + |
| 205 | + return StringHelper.EMPTY_STRING; |
| 206 | + } catch (Exception e) { |
| 207 | + //IMDS call failed, cannot find region |
| 208 | + log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage())); |
| 209 | + return StringHelper.EMPTY_STRING; |
| 210 | + } |
125 | 211 | } |
126 | 212 |
|
127 | 213 | private static void doInstanceDiscoveryAndCache(URL authorityUrl, |
|
0 commit comments