Skip to content

Commit f504ba2

Browse files
committed
Refactor assertion behavior to be entirely per-request
1 parent e356bb6 commit f504ba2

File tree

4 files changed

+189
-96
lines changed

4 files changed

+189
-96
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCertificate.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,33 @@ public List<String> getEncodedPublicKeyCertificateChain() throws CertificateEnco
5555
return result;
5656
}
5757

58+
/**
59+
* Gets a newly created JWT assertion using the certificate.
60+
* <p>
61+
* This method creates a fresh JWT assertion on each call, which prevents issues
62+
* with token expiration and ensures each request has a unique assertion.
63+
*
64+
* @param authority The authority for which the assertion is being created, must not be null
65+
* @param clientId The client ID of the application, used as the subject of the JWT
66+
* @param sendX5c Whether to include the x5c claim (certificate chain) in the JWT
67+
* @return A JWT assertion for client authentication
68+
* @throws NullPointerException if authority is null
69+
*/
70+
public String getAssertion(Authority authority, String clientId, boolean sendX5c) {
71+
if (authority == null) {
72+
throw new NullPointerException("Authority cannot be null");
73+
}
74+
75+
boolean useSha1 = Authority.detectAuthorityType(authority.canonicalAuthorityUrl()) == AuthorityType.ADFS;
76+
77+
return JwtHelper.buildJwt(
78+
clientId,
79+
this,
80+
authority.selfSignedJwtAudience(),
81+
sendX5c,
82+
useSha1).assertion();
83+
}
84+
5885
static ClientCertificate create(InputStream pkcs12Certificate, String password)
5986
throws KeyStoreException, NoSuchProviderException, NoSuchAlgorithmException,
6087
CertificateException, IOException, UnrecoverableKeyException {

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java

Lines changed: 6 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,14 @@
2121
*/
2222
public class ConfidentialClientApplication extends AbstractClientApplicationBase implements IConfidentialClientApplication {
2323

24-
private ClientCertificate clientCertificate;
25-
private String assertion;
26-
private IClientCredential clientCredential;
27-
String secret;
24+
IClientCredential clientCredential;
25+
private boolean sendX5c;
2826

2927
/** AppTokenProvider creates a Credential from a function that provides access tokens. The function
3028
must be concurrency safe. This is intended only to allow the Azure SDK to cache MSI tokens. It isn't
3129
useful to applications in general because the token provider must implement all authentication logic. */
3230
public Function<AppTokenProviderParameters, CompletableFuture<TokenProviderResult>> appTokenProvider;
3331

34-
private boolean sendX5c;
35-
3632
@Override
3733
public CompletableFuture<IAuthenticationResult> acquireToken(ClientCredentialParameters parameters) {
3834
validateNotNull("parameters", parameters);
@@ -76,80 +72,11 @@ private ConfidentialClientApplication(Builder builder) {
7672

7773
log = LoggerFactory.getLogger(ConfidentialClientApplication.class);
7874

79-
initClientAuthentication(builder.clientCredential);
75+
this.clientCredential = builder.clientCredential;
8076

8177
this.tenant = this.authenticationAuthority.tenant;
8278
}
8379

84-
private void initClientAuthentication(IClientCredential clientCredential) {
85-
validateNotNull("clientCredential", clientCredential);
86-
87-
this.clientCredential = clientCredential;
88-
89-
if (clientCredential instanceof ClientSecret) {
90-
this.secret = ((ClientSecret) clientCredential).clientSecret();
91-
} else if (clientCredential instanceof ClientCertificate) {
92-
this.clientCertificate = (ClientCertificate) clientCredential;
93-
this.assertion = getAssertionString(clientCredential);
94-
} else if (clientCredential instanceof ClientAssertion) {
95-
this.assertion = getAssertionString(clientCredential);
96-
} else {
97-
throw new IllegalArgumentException("Unsupported client credential");
98-
}
99-
}
100-
101-
/**
102-
* Generates a JWT-formatted assertion string based on the provided client credential. Returns null in cases where
103-
* the request for that credential type would not use a JWT assertion (e.g. client secret).
104-
*
105-
* @param clientCredential The client credential to use for token acquisition.
106-
* @return JWT-formatted assertion string
107-
*/
108-
String getAssertionString(IClientCredential clientCredential) {
109-
if (clientCredential instanceof ClientCertificate) {
110-
// Check if the current assertion is null or has expired, and if so create a new one
111-
if (this.assertion == null || hasJwtExpired(this.assertion)) {
112-
boolean useSha1 = Authority.detectAuthorityType(this.authenticationAuthority.canonicalAuthorityUrl()) == AuthorityType.ADFS;
113-
114-
this.assertion = JwtHelper.buildJwt(
115-
clientId(),
116-
clientCertificate,
117-
this.authenticationAuthority.selfSignedJwtAudience(),
118-
sendX5c,
119-
useSha1).assertion();
120-
}
121-
return this.assertion;
122-
} else if (clientCredential instanceof ClientAssertion) {
123-
return ((ClientAssertion) clientCredential).assertion();
124-
} else if (clientCredential instanceof ClientSecret) {
125-
return null;
126-
} else {
127-
throw new IllegalArgumentException("Unsupported client credential");
128-
}
129-
}
130-
131-
//Overload for the common case where the application's default credential was not overridden in the request.
132-
String getAssertionString() {
133-
return this.getAssertionString(this.clientCredential);
134-
}
135-
136-
/**
137-
* Checks if the JWT-formatted assertion has expired by parsing the "exp" claim.
138-
*
139-
* @param jwt JWT string
140-
* @return true if the JWT has expired. Otherwise false
141-
*/
142-
boolean hasJwtExpired(String jwt) {
143-
final Date currentDateTime = new Date(System.currentTimeMillis());
144-
Base64.Decoder decoder = Base64.getUrlDecoder();
145-
146-
String payload = new String(decoder.decode(jwt.split("\\.")[1]));
147-
148-
final Date expirationTime = (Date) JsonHelper.parseJsonToMap(payload).get("exp");
149-
150-
return expirationTime.before(currentDateTime);
151-
}
152-
15380
/**
15481
* Creates instance of Builder of ConfidentialClientApplication
15582
*
@@ -177,6 +104,9 @@ public static class Builder extends AbstractClientApplicationBase.Builder<Builde
177104

178105
private Builder(String clientId, IClientCredential clientCredential) {
179106
super(clientId);
107+
108+
validateNotNull("clientCredential", clientCredential);
109+
180110
this.clientCredential = clientCredential;
181111
}
182112

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -86,25 +86,76 @@ private void addQueryParameters(OAuthHttpRequest oauthHttpRequest) {
8686
String clientID = msalRequest.application().clientId();
8787
queryParameters.put("client_id", clientID);
8888

89-
// If the client application has a client assertion to apply to the request, check if a new client assertion
90-
// was supplied as a request parameter. If so, use the request's assertion instead of the application's
89+
// Add client authentication parameters if this is a confidential client
9190
if (msalRequest.application() instanceof ConfidentialClientApplication) {
92-
if (msalRequest instanceof ClientCredentialRequest && ((ClientCredentialRequest) msalRequest).parameters.clientCredential() != null) {
93-
IClientCredential credential = ((ClientCredentialRequest) msalRequest).parameters.clientCredential();
94-
addJWTBearerAssertionParams(queryParameters, ((ConfidentialClientApplication) msalRequest.application()).getAssertionString(credential));
95-
} else {
96-
if (((ConfidentialClientApplication) msalRequest.application()).getAssertionString() != null) {
97-
addJWTBearerAssertionParams(queryParameters, ((ConfidentialClientApplication) msalRequest.application()).getAssertionString());
98-
} else if (((ConfidentialClientApplication) msalRequest.application()).secret != null) {
99-
// Client secrets have a different parameter than bearer assertions
100-
queryParameters.put("client_secret", ((ConfidentialClientApplication) msalRequest.application()).secret);
101-
}
102-
}
91+
ConfidentialClientApplication application = (ConfidentialClientApplication) msalRequest.application();
92+
93+
// Determine which credential to use - either from the request or from the application
94+
IClientCredential credential = getCredentialToUse(application);
95+
96+
// Add appropriate authentication parameters based on the credential type
97+
addCredentialToRequest(queryParameters, credential, application);
10398
}
10499

105100
oauthHttpRequest.setQuery(StringHelper.serializeQueryParameters(queryParameters));
106101
}
107102

103+
/**
104+
* Determines which credential to use for authentication:
105+
* - If the request is a ClientCredentialRequest with a specified credential, use that
106+
* - Otherwise use the application's credential
107+
*
108+
* @param application The confidential client application
109+
* @return The credential to use, may be null if no credential is available
110+
*/
111+
private IClientCredential getCredentialToUse(ConfidentialClientApplication application) {
112+
if (msalRequest instanceof ClientCredentialRequest &&
113+
((ClientCredentialRequest) msalRequest).parameters.clientCredential() != null) {
114+
return ((ClientCredentialRequest) msalRequest).parameters.clientCredential();
115+
}
116+
return application.clientCredential;
117+
}
118+
119+
/**
120+
* Adds the appropriate authentication parameters to the request based on credential type.
121+
* Handles different credential types (secret, assertion, certificate) by adding the appropriate
122+
* parameters to the request.
123+
*
124+
* @param queryParameters The map of query parameters to add to
125+
* @param credential The credential to use for authentication, may be null
126+
* @param application The confidential client application
127+
*/
128+
private void addCredentialToRequest(Map<String, String> queryParameters,
129+
IClientCredential credential,
130+
ConfidentialClientApplication application) {
131+
if (credential == null) {
132+
return;
133+
}
134+
135+
if (credential instanceof ClientSecret) {
136+
// For client secret, add client_secret parameter
137+
queryParameters.put("client_secret", ((ClientSecret) credential).clientSecret());
138+
} else if (credential instanceof ClientAssertion) {
139+
// For client assertion, add client_assertion and client_assertion_type parameters
140+
addJWTBearerAssertionParams(queryParameters, ((ClientAssertion) credential).assertion());
141+
} else if (credential instanceof ClientCertificate) {
142+
// For client certificate, generate a new assertion and add it to the request
143+
ClientCertificate certificate = (ClientCertificate) credential;
144+
String assertion = certificate.getAssertion(
145+
application.authenticationAuthority,
146+
application.clientId(),
147+
application.sendX5c());
148+
addJWTBearerAssertionParams(queryParameters, assertion);
149+
}
150+
// If credential is of an unknown type, no additional parameters are added
151+
}
152+
153+
/**
154+
* Adds the JWT bearer token assertion parameters to the request
155+
*
156+
* @param queryParameters The map of query parameters to add to
157+
* @param assertion The JWT assertion string
158+
*/
108159
private void addJWTBearerAssertionParams(Map<String, String> queryParameters, String assertion) {
109160
queryParameters.put("client_assertion", assertion);
110161
queryParameters.put("client_assertion_type", ClientAssertion.ASSERTION_TYPE_JWT_BEARER);

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCertificateTest.java

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,25 @@
33

44
package com.microsoft.aad.msal4j;
55

6-
import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT;
76
import com.nimbusds.jwt.SignedJWT;
87
import org.junit.jupiter.api.Test;
98
import org.junit.jupiter.api.TestInstance;
109

1110
import static org.junit.jupiter.api.Assertions.assertEquals;
11+
import static org.junit.jupiter.api.Assertions.assertNotEquals;
1212
import static org.junit.jupiter.api.Assertions.assertNotNull;
1313
import static org.junit.jupiter.api.Assertions.assertNull;
1414
import static org.junit.jupiter.api.Assertions.assertThrows;
15-
import static org.junit.jupiter.api.Assertions.assertTrue;
1615
import static org.mockito.ArgumentMatchers.any;
1716
import static org.mockito.Mockito.*;
1817

1918
import java.math.BigInteger;
2019
import java.security.*;
2120
import java.security.cert.CertificateException;
2221
import java.security.interfaces.RSAPrivateKey;
23-
import java.text.ParseException;
2422
import java.util.*;
23+
import java.net.URLDecoder;
24+
import java.nio.charset.StandardCharsets;
2525

2626
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
2727
class ClientCertificateTest {
@@ -77,12 +77,17 @@ void testIClientCertificateInterface_CredentialFactoryUsesSha256() throws Except
7777
HttpRequest request = parameters.getArgument(0);
7878
String requestBody = request.body();
7979

80-
SignedJWT signedJWT = SignedJWT.parse(cca.getAssertionString());
80+
String clientAssertion = extractClientAssertion(requestBody);
8181

82-
if (requestBody.contains(cca.getAssertionString())
83-
&& signedJWT.getHeader().toJSONObject().containsKey("x5t#S256")) {
84-
return TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues));
82+
if (clientAssertion != null) {
83+
SignedJWT signedJWT = SignedJWT.parse(clientAssertion);
84+
if (signedJWT.getHeader().toJSONObject().containsKey("x5t#S256")) {
85+
return TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues));
86+
}
8587
}
88+
89+
//If the client assertion is null or does not contain the x5t#S256 header,
90+
// that indicates a problem in assertion generation and this test should fail.
8691
return null;
8792
});
8893

@@ -94,6 +99,86 @@ void testIClientCertificateInterface_CredentialFactoryUsesSha256() throws Except
9499
assertEquals("accessTokenSha256", result.accessToken());
95100
}
96101

102+
@Test
103+
void testClientCertificate_GeneratesNewAssertionEachTime() throws Exception {
104+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
105+
List<String> capturedAssertions = new ArrayList<>();
106+
107+
ConfidentialClientApplication cca =
108+
ConfidentialClientApplication.builder("clientId", ClientCredentialFactory.createFromCertificate(TestHelper.getPrivateKey(), TestHelper.getX509Cert()))
109+
.authority("https://login.microsoftonline.com/tenant")
110+
.instanceDiscovery(false)
111+
.validateAuthority(false)
112+
.httpClient(httpClientMock)
113+
.build();
114+
115+
// Mock the HTTP client to capture assertions from each request
116+
when(httpClientMock.send(any(HttpRequest.class))).thenAnswer(parameters -> {
117+
HttpRequest request = parameters.getArgument(0);
118+
String requestBody = request.body();
119+
120+
String clientAssertion = extractClientAssertion(requestBody);
121+
if (clientAssertion != null) {
122+
capturedAssertions.add(clientAssertion);
123+
124+
// Verify it's a valid JWT with proper headers
125+
SignedJWT signedJWT = SignedJWT.parse(clientAssertion);
126+
if (signedJWT.getHeader().toJSONObject().containsKey("x5t#S256")) {
127+
HashMap<String, String> tokenResponseValues = new HashMap<>();
128+
tokenResponseValues.put("access_token", "access_token_" + capturedAssertions.size());
129+
return TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues));
130+
}
131+
}
132+
return null;
133+
});
134+
135+
ClientCredentialParameters parameters = ClientCredentialParameters.builder(Collections.singleton("scopes")).skipCache(true).build();
136+
137+
// Make two token requests
138+
IAuthenticationResult result1 = cca.acquireToken(parameters).get();
139+
IAuthenticationResult result2 = cca.acquireToken(parameters).get();
140+
141+
// Verify two unique assertions were generated
142+
assertEquals(2, capturedAssertions.size(), "Two assertions should have been generated");
143+
assertNotEquals(capturedAssertions.get(0), capturedAssertions.get(1),
144+
"Each token request should generate a unique assertion");
145+
146+
// Optional: Parse and verify JWT properties if needed
147+
SignedJWT jwt1 = SignedJWT.parse(capturedAssertions.get(0));
148+
SignedJWT jwt2 = SignedJWT.parse(capturedAssertions.get(1));
149+
150+
// Different JTI (JWT ID) should be used for each assertion
151+
assertNotEquals(jwt1.getJWTClaimsSet().getJWTID(), jwt2.getJWTClaimsSet().getJWTID(),
152+
"Each assertion should have a unique JTI claim");
153+
154+
// Verify the tokens are different
155+
assertNotEquals(result1.accessToken(), result2.accessToken(),
156+
"The access tokens from each request should be different");
157+
}
158+
159+
/**
160+
* Extracts the client_assertion value from a URL-encoded request body
161+
* @param requestBody The request body string
162+
* @return The extracted client assertion or null if not found
163+
*/
164+
private String extractClientAssertion(String requestBody) {
165+
try {
166+
// Split the request body into key-value pairs
167+
String[] pairs = requestBody.split("&");
168+
for (String pair : pairs) {
169+
// Find the client_assertion parameter
170+
if (pair.startsWith("client_assertion=")) {
171+
// Extract and URL-decode the value
172+
return URLDecoder.decode(pair.substring("client_assertion=".length()), StandardCharsets.UTF_8.toString());
173+
}
174+
}
175+
} catch (Exception e) {
176+
// In case of any parsing errors
177+
System.err.println("Error extracting client assertion: " + e.getMessage());
178+
}
179+
return null;
180+
}
181+
97182
class TestClientCredential implements IClientCertificate {
98183
@Override
99184
public PrivateKey privateKey() {

0 commit comments

Comments
 (0)