Skip to content

Commit 58baa9f

Browse files
authored
Merge pull request #465 from nclaeys/bugfix/support-workload-identity
Rework CustomJWTAuthentication to request the oauth token correctly
2 parents 0abb791 + 101759a commit 58baa9f

File tree

4 files changed

+124
-50
lines changed

4 files changed

+124
-50
lines changed

src/integrationtest/java/com.microsoft.aad.msal4j/CertificateHelper.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ static KeyStore createKeyStore() throws KeyStoreException, NoSuchProviderExcepti
1616
String os = SystemUtils.OS_NAME;
1717
if (os.contains("Mac")) {
1818
return KeyStore.getInstance("KeychainStore");
19+
} else if (os.contains("Linux")) {
20+
return KeyStore.getInstance(KeyStore.getDefaultType());
1921
} else {
2022
return KeyStore.getInstance("Windows-MY", "SunMSCAPI");
2123
}

src/integrationtest/java/com.microsoft.aad.msal4j/ConfidentialClientApplicationUnitT.java

Lines changed: 79 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33

44
package com.microsoft.aad.msal4j;
55

6+
import com.nimbusds.jose.JOSEException;
67
import com.nimbusds.jose.JWSAlgorithm;
78
import com.nimbusds.jose.JWSHeader;
89
import com.nimbusds.jose.crypto.RSASSASigner;
910
import com.nimbusds.jose.util.Base64;
1011
import com.nimbusds.jose.util.Base64URL;
1112
import com.nimbusds.jwt.JWTClaimsSet;
1213
import com.nimbusds.jwt.SignedJWT;
14+
import com.nimbusds.oauth2.sdk.auth.JWTAuthentication;
1315
import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT;
16+
import org.easymock.Capture;
1417
import org.easymock.EasyMock;
1518
import org.powermock.api.easymock.PowerMock;
1619
import org.powermock.core.classloader.annotations.PowerMockIgnore;
@@ -22,13 +25,14 @@
2225

2326
import java.io.IOException;
2427
import java.net.URI;
28+
import java.net.URLEncoder;
2529
import java.security.*;
2630
import java.security.cert.CertificateException;
2731
import java.util.*;
2832
import java.util.concurrent.Future;
2933

30-
import static org.testng.Assert.assertFalse;
31-
import static org.testng.Assert.assertNotNull;
34+
import static org.easymock.EasyMock.*;
35+
import static org.testng.Assert.*;
3236

3337
@PowerMockIgnore({"javax.net.ssl.*"})
3438
@PrepareForTest({ConfidentialClientApplication.class,
@@ -206,29 +210,7 @@ private ClientAssertion buildShortJwt(String clientId,
206210

207211
@Test
208212
public void testClientAssertion_noException() throws Exception{
209-
210-
IClientCertificate certificate = CertificateHelper.getClientCertificate();
211-
212-
final ClientCertificate credential = (ClientCertificate) certificate;
213-
214-
final JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
215-
.issuer("issuer")
216-
.subject("subject")
217-
.build();
218-
219-
SignedJWT jwt;
220-
JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.RS256);
221-
222-
List<Base64> certs = new ArrayList<>();
223-
for (String cert : credential.getEncodedPublicKeyCertificateChain()) {
224-
certs.add(new Base64(cert));
225-
}
226-
builder.x509CertChain(certs);
227-
228-
jwt = new SignedJWT(builder.build(), claimsSet);
229-
final RSASSASigner signer = new RSASSASigner(credential.privateKey());
230-
231-
jwt.sign(signer);
213+
SignedJWT jwt = createClientAssertion("issuer");
232214

233215
ClientAssertion clientAssertion = new ClientAssertion(jwt.serialize());
234216

@@ -245,14 +227,84 @@ public void testClientAssertion_noException() throws Exception{
245227

246228
}
247229

230+
@Test
231+
public void testClientAssertion_acquireToken() throws Exception{
232+
SignedJWT jwt = createClientAssertion("issuer");
233+
234+
ClientAssertion clientAssertion = new ClientAssertion(jwt.serialize());
235+
ConfidentialClientApplication app = ConfidentialClientApplication
236+
.builder(TestConfiguration.AAD_CLIENT_ID, ClientCredentialFactory.createFromClientAssertion(clientAssertion.assertion()))
237+
.authority(TestConfiguration.AAD_TENANT_ENDPOINT)
238+
.build();
239+
240+
String scope = "requestedScope";
241+
ClientCredentialRequest clientCredentialRequest = getClientCredentialRequest(app, scope);
242+
243+
IHttpClient httpClientMock = EasyMock.mock(IHttpClient.class);
244+
Capture<HttpRequest> captureSingleArgument = newCapture();
245+
expect(httpClientMock.send(capture(captureSingleArgument))).andReturn(new HttpResponse());
246+
EasyMock.replay(httpClientMock);
247+
248+
TokenRequestExecutor tokenRequestExecutor = new TokenRequestExecutor(app.authenticationAuthority, clientCredentialRequest, mockedServiceBundle(httpClientMock));
249+
try {
250+
tokenRequestExecutor.executeTokenRequest();
251+
} catch(Exception e) {
252+
//Ignored, we only want to check the request that was send.
253+
}
254+
HttpRequest value = captureSingleArgument.getValue();
255+
String body = value.body();
256+
Assert.assertTrue(body.contains("grant_type=client_credentials"));
257+
Assert.assertTrue(body.contains("client_assertion=" + clientAssertion.assertion()));
258+
Assert.assertTrue(body.contains("client_assertion_type=" + URLEncoder.encode(JWTAuthentication.CLIENT_ASSERTION_TYPE, "utf-8")));
259+
Assert.assertTrue(body.contains("scope=" + URLEncoder.encode("openid profile offline_access " + scope, "utf-8")));
260+
Assert.assertTrue(body.contains("client_id=" + TestConfiguration.AAD_CLIENT_ID));
261+
}
262+
263+
private ServiceBundle mockedServiceBundle(IHttpClient httpClientMock) {
264+
ServiceBundle serviceBundle = new ServiceBundle(
265+
null,
266+
httpClientMock,
267+
new TelemetryManager(null, false));
268+
return serviceBundle;
269+
}
270+
271+
private ClientCredentialRequest getClientCredentialRequest(ConfidentialClientApplication app, String scope) {
272+
Set<String> scopes = new HashSet<>();
273+
scopes.add(scope);
274+
ClientCredentialParameters clientCredentials = ClientCredentialParameters.builder(scopes).tenant(IdToken.TENANT_IDENTIFIER).build();
275+
RequestContext requestContext = new RequestContext(
276+
app,
277+
PublicApi.ACQUIRE_TOKEN_FOR_CLIENT,
278+
clientCredentials);
279+
280+
ClientCredentialRequest clientCredentialRequest =
281+
new ClientCredentialRequest(
282+
clientCredentials,
283+
app,
284+
requestContext);
285+
return clientCredentialRequest;
286+
}
287+
248288
@Test(expectedExceptions = MsalClientException.class)
249289
public void testClientAssertion_throwsException() throws Exception{
290+
SignedJWT jwt = createClientAssertion(null);
291+
292+
ClientAssertion clientAssertion = new ClientAssertion(jwt.serialize());
293+
294+
IClientCredential iClientCredential = ClientCredentialFactory.createFromClientAssertion(
295+
clientAssertion.assertion());
250296

297+
ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID, iClientCredential).authority(TestConfiguration.AAD_TENANT_ENDPOINT).build();
298+
299+
}
300+
301+
private SignedJWT createClientAssertion(String issuer) throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, UnrecoverableKeyException, NoSuchProviderException, JOSEException {
251302
IClientCertificate certificate = CertificateHelper.getClientCertificate();
303+
252304
final ClientCertificate credential = (ClientCertificate) certificate;
253305

254306
final JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
255-
.issuer(null)
307+
.issuer(issuer)
256308
.subject("subject")
257309
.build();
258310

@@ -269,15 +321,7 @@ public void testClientAssertion_throwsException() throws Exception{
269321
final RSASSASigner signer = new RSASSASigner(credential.privateKey());
270322

271323
jwt.sign(signer);
272-
273-
ClientAssertion clientAssertion = new ClientAssertion(jwt.serialize());
274-
275-
IClientCredential iClientCredential = ClientCredentialFactory.createFromClientAssertion(
276-
clientAssertion.assertion());
277-
278-
ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID, iClientCredential).authority(TestConfiguration.AAD_TENANT_ENDPOINT).build();
279-
324+
return jwt;
280325
}
281326

282-
283327
}

src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
public class ConfidentialClientApplication extends AbstractClientApplicationBase implements IConfidentialClientApplication {
3434

3535
private ClientAuthentication clientAuthentication;
36-
private CustomJWTAuthentication customJWTAuthentication;
3736
private boolean clientCertAuthentication = false;
3837
private ClientCertificate clientCertificate;
3938

@@ -137,22 +136,11 @@ private ClientAuthentication createClientAuthFromClientAssertion(
137136
//This library is not supposed to validate Issuer and subject values.
138137
//The next lines of code ensures that exception is not thrown.
139138
if (e.getMessage().contains("Issuer and subject in client JWT assertion must designate the same client identifier")) {
140-
String clientAssertion1 = MultivaluedMapUtils.getFirstValue(map, "client_assertion");
141-
Base64URL[] parts;
142-
try {
143-
parts = JOSEObject.split(clientAssertion1);
144-
145-
SignedJWT signedJWT = new SignedJWT(parts[0], parts[1], parts[2]);
146-
String subjectValue = signedJWT.getJWTClaimsSet().getSubject();
147139
return new CustomJWTAuthentication(
148140
ClientAuthenticationMethod.PRIVATE_KEY_JWT,
149-
new ClientID(subjectValue)
141+
clientAssertion,
142+
new ClientID(clientId())
150143
);
151-
152-
} catch (java.text.ParseException ex) {
153-
log.error("Ideally the system should not reach here. Parse Exception while trying to build CustomJWTAuthentication.");
154-
throw new MsalClientException(e);
155-
}
156144
}
157145
throw new MsalClientException(e);
158146
}

src/main/java/com/microsoft/aad/msal4j/CustomJWTAuthentication.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,59 @@
33

44
package com.microsoft.aad.msal4j;
55

6+
import com.nimbusds.common.contenttype.ContentType;
7+
import com.nimbusds.oauth2.sdk.SerializeException;
68
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
79
import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod;
10+
import com.nimbusds.oauth2.sdk.auth.JWTAuthentication;
811
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
912
import com.nimbusds.oauth2.sdk.id.ClientID;
13+
import com.nimbusds.oauth2.sdk.util.URLUtils;
14+
15+
import java.net.URLEncoder;
16+
import java.util.Collections;
17+
import java.util.HashMap;
18+
import java.util.List;
19+
import java.util.Map;
1020

1121
public class CustomJWTAuthentication extends ClientAuthentication {
22+
private ClientAssertion clientAssertion;
1223

13-
protected CustomJWTAuthentication(ClientAuthenticationMethod method, ClientID clientID) {
24+
protected CustomJWTAuthentication(ClientAuthenticationMethod method, ClientAssertion clientAssertion, ClientID clientID) {
1425
super(method, clientID);
26+
this.clientAssertion = clientAssertion;
1527
}
1628

1729
@Override
1830
public void applyTo(HTTPRequest httpRequest) {
31+
if (httpRequest.getMethod() != HTTPRequest.Method.POST) {
32+
throw new SerializeException("The HTTP request method must be POST");
33+
} else {
34+
ContentType ct = httpRequest.getEntityContentType();
35+
if (ct == null) {
36+
throw new SerializeException("Missing HTTP Content-Type header");
37+
} else if (!ct.matches(ContentType.APPLICATION_URLENCODED)) {
38+
throw new SerializeException("The HTTP Content-Type header must be " + ContentType.APPLICATION_URLENCODED);
39+
} else {
40+
Map<String, List<String>> params = httpRequest.getQueryParameters();
41+
params.putAll(this.toParameters());
42+
String queryString = URLUtils.serializeParameters(params);
43+
httpRequest.setQuery(queryString);
44+
}
45+
}
46+
}
47+
48+
public Map<String, List<String>> toParameters() {
49+
HashMap<String, List<String>> params = new HashMap<>();
50+
51+
try {
52+
params.put("client_assertion", Collections.singletonList(this.clientAssertion.assertion()));
53+
} catch (IllegalStateException var3) {
54+
throw new SerializeException("Couldn't serialize JWT to a client assertion string: " + var3.getMessage(), var3);
55+
}
1956

57+
params.put("client_assertion_type", Collections.singletonList(JWTAuthentication.CLIENT_ASSERTION_TYPE));
58+
params.put("client_id", Collections.singletonList(getClientID().getValue()));
59+
return params;
2060
}
2161
}

0 commit comments

Comments
 (0)