Skip to content

Commit 0644297

Browse files
authored
Merge pull request #129 from AzureAD/sagonzal/customHttpClient
Let users configure HTTP client
2 parents f0f154d + 78e47aa commit 0644297

35 files changed

+1136
-763
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
import org.apache.http.Header;
7+
import org.apache.http.client.methods.CloseableHttpResponse;
8+
import org.apache.http.client.methods.HttpGet;
9+
import org.apache.http.client.methods.HttpPost;
10+
import org.apache.http.client.methods.HttpRequestBase;
11+
import org.apache.http.entity.ContentType;
12+
import org.apache.http.entity.StringEntity;
13+
import org.apache.http.impl.client.CloseableHttpClient;
14+
import org.apache.http.impl.client.HttpClients;
15+
import org.apache.http.util.EntityUtils;
16+
17+
import java.io.IOException;
18+
import java.util.Collections;
19+
import java.util.HashMap;
20+
import java.util.List;
21+
import java.util.Map;
22+
23+
class ApacheHttpClientAdapter implements IHttpClient {
24+
25+
private CloseableHttpClient httpClient;
26+
27+
ApacheHttpClientAdapter(){
28+
this.httpClient = HttpClients.createDefault();
29+
}
30+
31+
@Override
32+
public IHttpResponse send(HttpRequest httpRequest) throws Exception {
33+
34+
HttpRequestBase request = buildApacheRequestFromMsalRequest(httpRequest);
35+
CloseableHttpResponse response = httpClient.execute(request);
36+
37+
return buildMsalResponseFromApacheResponse(response);
38+
}
39+
40+
41+
private HttpRequestBase buildApacheRequestFromMsalRequest(HttpRequest httpRequest){
42+
43+
if(httpRequest.httpMethod() == HttpMethod.GET){
44+
return builGetRequest(httpRequest);
45+
} else if(httpRequest.httpMethod() == HttpMethod.POST){
46+
return buildPostRequest(httpRequest);
47+
} else {
48+
throw new IllegalArgumentException("HttpRequest method should be either GET or POST");
49+
}
50+
}
51+
52+
private HttpGet builGetRequest(HttpRequest httpRequest){
53+
HttpGet httpGet = new HttpGet(httpRequest.url().toString());
54+
55+
for(Map.Entry<String, String> entry: httpRequest.headers().entrySet()){
56+
httpGet.setHeader(entry.getKey(), entry.getValue());
57+
}
58+
59+
return httpGet;
60+
}
61+
62+
private HttpPost buildPostRequest(HttpRequest httpRequest){
63+
64+
HttpPost httpPost = new HttpPost(httpRequest.url().toString());
65+
for(Map.Entry<String, String> entry: httpRequest.headers().entrySet()){
66+
httpPost.setHeader(entry.getKey(), entry.getValue());
67+
}
68+
69+
String contentTypeHeaderValue = httpRequest.headerValue("Content-Type");
70+
ContentType contentType = ContentType.getByMimeType(contentTypeHeaderValue);
71+
StringEntity stringEntity = new StringEntity(httpRequest.body(), contentType);
72+
73+
httpPost.setEntity(stringEntity);
74+
return httpPost;
75+
}
76+
77+
private IHttpResponse buildMsalResponseFromApacheResponse(CloseableHttpResponse apacheResponse)
78+
throws IOException {
79+
80+
IHttpResponse httpResponse = new HttpResponse();
81+
((HttpResponse) httpResponse).statusCode(apacheResponse.getStatusLine().getStatusCode());
82+
83+
Map<String, List<String>> headers = new HashMap<>();
84+
for(Header header: apacheResponse.getAllHeaders()){
85+
headers.put(header.getName(), Collections.singletonList(header.getValue()));
86+
}
87+
((HttpResponse) httpResponse).headers(headers);
88+
89+
String responseBody = EntityUtils.toString(apacheResponse.getEntity(), "UTF-8");
90+
((HttpResponse) httpResponse).body(responseBody);
91+
return httpResponse;
92+
}
93+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
import labapi.LabResponse;
7+
import labapi.LabUserProvider;
8+
import labapi.NationalCloud;
9+
import org.testng.Assert;
10+
import org.testng.annotations.BeforeClass;
11+
import org.testng.annotations.Test;
12+
13+
import java.util.Collections;
14+
15+
public class HttpClientIT {
16+
private LabUserProvider labUserProvider;
17+
18+
@BeforeClass
19+
public void setUp() {
20+
labUserProvider = LabUserProvider.getInstance();
21+
}
22+
23+
@Test
24+
public void acquireToken_okHttpClient() throws Exception {
25+
26+
LabResponse labResponse = getManagedUserAccountWithPassword();
27+
assertAcquireTokenCommon(labResponse, new OkHttpClientAdapter());
28+
}
29+
30+
@Test
31+
public void acquireToken_apacheHttpClient() throws Exception {
32+
33+
LabResponse labResponse = getManagedUserAccountWithPassword();
34+
assertAcquireTokenCommon(labResponse, new ApacheHttpClientAdapter());
35+
}
36+
37+
private void assertAcquireTokenCommon(LabResponse labResponse, IHttpClient httpClient)
38+
throws Exception{
39+
PublicClientApplication pca = PublicClientApplication.builder(
40+
labResponse.getAppId()).
41+
authority(TestConstants.ORGANIZATIONS_AUTHORITY).
42+
httpClient(httpClient).
43+
build();
44+
45+
IAuthenticationResult result = pca.acquireToken(UserNamePasswordParameters.
46+
builder(Collections.singleton(TestConstants.GRAPH_DEFAULT_SCOPE),
47+
labResponse.getUser().getUpn(),
48+
labResponse.getUser().getPassword().toCharArray())
49+
.build())
50+
.get();
51+
52+
Assert.assertNotNull(result);
53+
Assert.assertNotNull(result.accessToken());
54+
Assert.assertNotNull(result.idToken());
55+
Assert.assertEquals(labResponse.getUser().getUpn(), result.account().username());
56+
}
57+
58+
private LabResponse getManagedUserAccountWithPassword(){
59+
LabResponse labResponse = labUserProvider.getDefaultUser(
60+
NationalCloud.AZURE_CLOUD,
61+
false);
62+
labUserProvider.getUserPassword(labResponse.getUser());
63+
64+
return labResponse;
65+
}
66+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
import okhttp3.Headers;
7+
import okhttp3.MediaType;
8+
import okhttp3.OkHttpClient;
9+
import okhttp3.Request;
10+
import okhttp3.RequestBody;
11+
import okhttp3.Response;
12+
import okhttp3.ResponseBody;
13+
14+
import java.io.IOException;
15+
16+
class OkHttpClientAdapter implements IHttpClient{
17+
18+
private OkHttpClient client;
19+
20+
OkHttpClientAdapter(){
21+
this.client = new OkHttpClient();
22+
}
23+
24+
@Override
25+
public IHttpResponse send(HttpRequest httpRequest) throws IOException {
26+
27+
Request request = buildOkRequestFromMsalRequest(httpRequest);
28+
29+
Response okHttpResponse= client.newCall(request).execute();
30+
return buildMsalResponseFromOkResponse(okHttpResponse);
31+
}
32+
33+
private Request buildOkRequestFromMsalRequest(HttpRequest httpRequest){
34+
35+
if(httpRequest.httpMethod() == HttpMethod.GET){
36+
return buildGetRequest(httpRequest);
37+
} else if(httpRequest.httpMethod() == HttpMethod.POST){
38+
return buildPostRequest(httpRequest);
39+
} else {
40+
throw new IllegalArgumentException("HttpRequest method should be either GET or POST");
41+
}
42+
}
43+
44+
private Request buildGetRequest(HttpRequest httpRequest){
45+
Headers headers = Headers.of(httpRequest.headers());
46+
47+
return new Request.Builder()
48+
.url(httpRequest.url())
49+
.headers(headers)
50+
.build();
51+
}
52+
53+
private Request buildPostRequest(HttpRequest httpRequest){
54+
Headers headers = Headers.of(httpRequest.headers());
55+
String contentType = httpRequest.headerValue("Content-Type");
56+
MediaType type = MediaType.parse(contentType);
57+
58+
RequestBody requestBody = RequestBody.create(type, httpRequest.body());
59+
60+
return new Request.Builder()
61+
.url(httpRequest.url())
62+
.post(requestBody)
63+
.headers(headers)
64+
.build();
65+
}
66+
67+
private IHttpResponse buildMsalResponseFromOkResponse(Response okHttpResponse) throws IOException{
68+
69+
IHttpResponse httpResponse = new HttpResponse();
70+
((HttpResponse) httpResponse).statusCode(okHttpResponse.code());
71+
72+
ResponseBody body = okHttpResponse.body();
73+
if(body != null){
74+
((HttpResponse) httpResponse).body(body.string());
75+
}
76+
77+
Headers headers = okHttpResponse.headers();
78+
if(headers != null){
79+
((HttpResponse) httpResponse).headers(headers.toMultimap());
80+
}
81+
return httpResponse;
82+
}
83+
}

src/integrationtest/java/com.microsoft.aad.msal4j/UsernamePasswordIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public void acquireTokenWithUsernamePassword_ADFSv2() throws Exception{
7373
assertAcquireTokenCommon(labResponse, password);
7474
}
7575

76-
public void assertAcquireTokenCommon(LabResponse labResponse, String password)
76+
private void assertAcquireTokenCommon(LabResponse labResponse, String password)
7777
throws Exception{
7878
PublicClientApplication pca = PublicClientApplication.builder(
7979
labResponse.getAppId()).

src/integrationtest/java/labapi/KeyVaultSecretsProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ private IClientCredential getClientCredentialFromKeyStore() {
7575
key = (PrivateKey) keystore.getKey(CERTIFICATE_ALIAS, null);
7676
publicCertificate = (X509Certificate) keystore.getCertificate(
7777
CERTIFICATE_ALIAS);
78+
7879
} catch (Exception e){
7980
throw new RuntimeException("Error getting certificate from keystore: " + e.getMessage());
8081
}

src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscovery.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,21 @@ private static String getInstanceDiscoveryEndpoint(String host) {
6262

6363
private static InstanceDiscoveryResponse sendInstanceDiscoveryRequest
6464
(URL authorityUrl, MsalRequest msalRequest,
65-
ServiceBundle serviceBundle) throws Exception {
65+
ServiceBundle serviceBundle) {
6666

6767
String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpoint(authorityUrl.getAuthority()) +
6868
INSTANCE_DISCOVERY_REQUEST_PARAMETERS_TEMPLATE.replace("{authorizeEndpoint}",
6969
getAuthorizeEndpoint(authorityUrl.getAuthority(),
7070
Authority.getTenant(authorityUrl, Authority.detectAuthorityType(authorityUrl))));
7171

72-
String json = HttpHelper.executeHttpRequest
73-
(log, HttpMethod.GET, instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(),
74-
null, msalRequest.requestContext(), serviceBundle);
72+
HttpRequest httpRequest = new HttpRequest(
73+
HttpMethod.GET,
74+
instanceDiscoveryRequestUrl,
75+
msalRequest.headers().getReadonlyHeaderMap());
7576

76-
return JsonHelper.convertJsonToObject(json, InstanceDiscoveryResponse.class);
77+
IHttpResponse httpResponse= HttpHelper.executeHttpRequest(httpRequest, msalRequest.requestContext(), serviceBundle);
78+
79+
return JsonHelper.convertJsonToObject(httpResponse.body(), InstanceDiscoveryResponse.class);
7780
}
7881

7982
private static void validate(InstanceDiscoveryResponse instanceDiscoveryResponse) {

src/main/java/com/microsoft/aad/msal4j/ClientApplicationBase.java

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,13 @@ abstract class ClientApplicationBase implements IClientApplicationBase {
5757
@Getter(AccessLevel.PACKAGE)
5858
private Consumer<List<HashMap<String, String>>> telemetryConsumer;
5959

60-
@Override
61-
public Proxy proxy() {
62-
return this.serviceBundle.getProxy();
63-
}
60+
@Accessors(fluent = true)
61+
@Getter
62+
public Proxy proxy;
6463

65-
@Override
66-
public SSLSocketFactory sslSocketFactory() {
67-
return this.serviceBundle.getSslSocketFactory();
68-
}
64+
@Accessors(fluent = true)
65+
@Getter
66+
public SSLSocketFactory sslSocketFactory;
6967

7068
@Accessors(fluent = true)
7169
@Getter
@@ -159,18 +157,24 @@ AuthenticationResult acquireTokenCommon(MsalRequest msalRequest, Authority reque
159157
headers.getHeaderCorrelationIdValue()));
160158
}
161159

162-
TokenRequest request = new TokenRequest(requestAuthority, msalRequest, serviceBundle);
160+
TokenRequestExecutor requestExecutor = new TokenRequestExecutor(
161+
requestAuthority,
162+
msalRequest,
163+
serviceBundle);
163164

164-
AuthenticationResult result = request.executeOauthRequestAndProcessResponse();
165+
AuthenticationResult result = requestExecutor.executeTokenRequest();
165166

166167
if(authenticationAuthority.authorityType.equals(AuthorityType.B2C)){
167-
tokenCache.saveTokens(request, result, authenticationAuthority.host);
168+
tokenCache.saveTokens(requestExecutor, result, authenticationAuthority.host);
168169
} else {
169170
InstanceDiscoveryMetadataEntry instanceDiscoveryMetadata =
170-
AadInstanceDiscovery.GetMetadataEntry
171-
(requestAuthority.canonicalAuthorityUrl(), validateAuthority, msalRequest, serviceBundle);
171+
AadInstanceDiscovery.GetMetadataEntry(
172+
requestAuthority.canonicalAuthorityUrl(),
173+
validateAuthority,
174+
msalRequest,
175+
serviceBundle);
172176

173-
tokenCache.saveTokens(request, result, instanceDiscoveryMetadata.preferredCache);
177+
tokenCache.saveTokens(requestExecutor, result, instanceDiscoveryMetadata.preferredCache);
174178
}
175179

176180
return result;
@@ -226,6 +230,7 @@ abstract static class Builder<T extends Builder<T>> {
226230
private ExecutorService executorService;
227231
private Proxy proxy;
228232
private SSLSocketFactory sslSocketFactory;
233+
private IHttpClient httpClient;
229234
private Consumer<List<HashMap<String, String>>> telemetryConsumer;
230235
private Boolean onlySendFailureTelemetry = false;
231236
private ITokenCacheAccessAspect tokenCacheAccessAspect;
@@ -344,6 +349,14 @@ public T proxy(Proxy val) {
344349
return self();
345350
}
346351

352+
353+
public T httpClient(IHttpClient val){
354+
validateNotNull("httpClient", val);
355+
356+
httpClient = val;
357+
return self();
358+
}
359+
347360
/**
348361
* Sets SSLSocketFactory to be used by the client application for all network communication.
349362
*
@@ -403,10 +416,13 @@ private static Authority createDefaultAADAuthority() {
403416
correlationId = builder.correlationId;
404417
logPii = builder.logPii;
405418
telemetryConsumer = builder.telemetryConsumer;
419+
proxy = builder.proxy;
420+
sslSocketFactory = builder.sslSocketFactory;
406421
serviceBundle = new ServiceBundle(
407422
builder.executorService,
408-
builder.proxy,
409-
builder.sslSocketFactory,
423+
builder.httpClient == null ?
424+
new DefaultHttpClient(builder.proxy, builder.sslSocketFactory) :
425+
builder.httpClient,
410426
new TelemetryManager(telemetryConsumer, builder.onlySendFailureTelemetry));
411427
authenticationAuthority = builder.authenticationAuthority;
412428
tokenCache = new TokenCache(builder.tokenCacheAccessAspect);

0 commit comments

Comments
 (0)