Skip to content

Commit 89e54be

Browse files
If signature validation fails, reload JWKs and retry if new JWKs are found (#88023)
Co-authored-by: Niels Dewulf
1 parent c3e5daa commit 89e54be

File tree

13 files changed

+950
-559
lines changed

13 files changed

+950
-559
lines changed

docs/changelog/88023.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 88023
2+
summary: "If signature validation fails, reload JWKs and retry if new JWKs are found"
3+
area: Authentication
4+
type: enhancement
5+
issues: []

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java

Lines changed: 375 additions & 171 deletions
Large diffs are not rendered by default.

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import com.nimbusds.jose.jwk.JWKSet;
1212
import com.nimbusds.jose.util.JSONObjectUtils;
1313
import com.nimbusds.jwt.JWTClaimsSet;
14-
import com.nimbusds.jwt.SignedJWT;
1514

1615
import org.apache.http.HttpEntity;
1716
import org.apache.http.HttpResponse;
@@ -33,8 +32,9 @@
3332
import org.apache.logging.log4j.Logger;
3433
import org.elasticsearch.ElasticsearchSecurityException;
3534
import org.elasticsearch.SpecialPermission;
36-
import org.elasticsearch.action.support.PlainActionFuture;
35+
import org.elasticsearch.action.ActionListener;
3736
import org.elasticsearch.common.Strings;
37+
import org.elasticsearch.common.hash.MessageDigests;
3838
import org.elasticsearch.common.settings.SecureString;
3939
import org.elasticsearch.common.settings.SettingsException;
4040
import org.elasticsearch.common.ssl.SslConfiguration;
@@ -51,6 +51,7 @@
5151
import java.nio.file.Files;
5252
import java.nio.file.Path;
5353
import java.security.AccessController;
54+
import java.security.MessageDigest;
5455
import java.security.PrivilegedAction;
5556
import java.security.PrivilegedActionException;
5657
import java.security.PrivilegedExceptionAction;
@@ -185,16 +186,25 @@ public static URI parseHttpsUri(final String uriString) {
185186
return null;
186187
}
187188

188-
public static byte[] readUriContents(
189+
public static void readUriContents(
189190
final String jwkSetConfigKeyPkc,
190191
final URI jwkSetPathPkcUri,
191-
final CloseableHttpAsyncClient httpClient
192-
) throws SettingsException {
193-
try {
194-
return JwtUtil.readBytes(httpClient, jwkSetPathPkcUri);
195-
} catch (Exception e) {
196-
throw new SettingsException("Can't get contents for setting [" + jwkSetConfigKeyPkc + "] value [" + jwkSetPathPkcUri + "].", e);
197-
}
192+
final CloseableHttpAsyncClient httpClient,
193+
final ActionListener<byte[]> listener
194+
) {
195+
JwtUtil.readBytes(
196+
httpClient,
197+
jwkSetPathPkcUri,
198+
ActionListener.wrap(
199+
listener::onResponse,
200+
ex -> listener.onFailure(
201+
new SettingsException(
202+
"Can't get contents for setting [" + jwkSetConfigKeyPkc + "] value [" + jwkSetPathPkcUri + "].",
203+
ex
204+
)
205+
)
206+
)
207+
);
198208
}
199209

200210
public static byte[] readFileContents(final String jwkSetConfigKeyPkc, final String jwkSetPathPkc, final Environment environment)
@@ -211,7 +221,7 @@ public static byte[] readFileContents(final String jwkSetConfigKeyPkc, final Str
211221
}
212222

213223
public static String serializeJwkSet(final JWKSet jwkSet, final boolean publicKeysOnly) {
214-
if ((jwkSet == null) || (jwkSet.getKeys().isEmpty())) {
224+
if (jwkSet == null) {
215225
return null;
216226
}
217227
return JSONObjectUtils.toJSONString(jwkSet.toJSONObject(publicKeysOnly));
@@ -262,13 +272,11 @@ public static CloseableHttpAsyncClient createHttpClient(final RealmConfig realmC
262272
}
263273

264274
/**
265-
* Use the HTTP Client to get URL content bytes up to N max bytes.
275+
* Use the HTTP Client to get URL content bytes.
266276
* @param httpClient Configured HTTP/HTTPS client.
267277
* @param uri URI to download.
268-
* @return Byte array of the URI contents up to N max bytes.
269278
*/
270-
public static byte[] readBytes(final CloseableHttpAsyncClient httpClient, final URI uri) {
271-
final PlainActionFuture<byte[]> plainActionFuture = PlainActionFuture.newFuture();
279+
public static void readBytes(final CloseableHttpAsyncClient httpClient, final URI uri, ActionListener<byte[]> listener) {
272280
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
273281
httpClient.execute(new HttpGet(uri), new FutureCallback<>() {
274282
@Override
@@ -278,12 +286,12 @@ public void completed(final HttpResponse result) {
278286
if (statusCode == 200) {
279287
final HttpEntity entity = result.getEntity();
280288
try (InputStream inputStream = entity.getContent()) {
281-
plainActionFuture.onResponse(inputStream.readAllBytes());
289+
listener.onResponse(inputStream.readAllBytes());
282290
} catch (Exception e) {
283-
plainActionFuture.onFailure(e);
291+
listener.onFailure(e);
284292
}
285293
} else {
286-
plainActionFuture.onFailure(
294+
listener.onFailure(
287295
new ElasticsearchSecurityException(
288296
"Get [" + uri + "] failed, status [" + statusCode + "], reason [" + statusLine.getReasonPhrase() + "]."
289297
)
@@ -293,17 +301,16 @@ public void completed(final HttpResponse result) {
293301

294302
@Override
295303
public void failed(Exception e) {
296-
plainActionFuture.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] failed.", e));
304+
listener.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] failed.", e));
297305
}
298306

299307
@Override
300308
public void cancelled() {
301-
plainActionFuture.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] was cancelled."));
309+
listener.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] was cancelled."));
302310
}
303311
});
304312
return null;
305313
});
306-
return plainActionFuture.actionGet();
307314
}
308315

309316
public static Path resolvePath(final Environment environment, final String jwkSetPath) {
@@ -335,14 +342,10 @@ public static SecureString join(final CharSequence delimiter, final CharSequence
335342
* JWSHeader: Header are not support.
336343
* JWTClaimsSet: Claims are supported. Claim keys are prefixed by "jwt_claim_".
337344
* Base64URL: Signature is not supported.
338-
* @param jwt SignedJWT object.
339345
* @return Map of formatted and filtered values to be used as user metadata.
340-
* @throws Exception Parse error.
341346
*/
342-
//
343347
// Values will be filtered by type using isAllowedTypeForClaim().
344-
public static Map<String, Object> toUserMetadata(final SignedJWT jwt) throws Exception {
345-
final JWTClaimsSet claimsSet = jwt.getJWTClaimsSet();
348+
public static Map<String, Object> toUserMetadata(JWTClaimsSet claimsSet) {
346349
return claimsSet.getClaims()
347350
.entrySet()
348351
.stream()
@@ -366,4 +369,10 @@ static boolean isAllowedTypeForClaim(final Object value) {
366369
|| (value instanceof Collection
367370
&& ((Collection<?>) value).stream().allMatch(e -> e instanceof String || e instanceof Boolean || e instanceof Number)));
368371
}
372+
373+
public static byte[] sha256(final CharSequence charSequence) {
374+
final MessageDigest messageDigest = MessageDigests.sha256();
375+
messageDigest.update(charSequence.toString().getBytes(StandardCharsets.UTF_8));
376+
return messageDigest.digest();
377+
}
369378
}

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtValidateUtil.java

Lines changed: 30 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import org.apache.logging.log4j.LogManager;
3434
import org.apache.logging.log4j.Logger;
35+
import org.elasticsearch.ElasticsearchException;
3536
import org.elasticsearch.common.settings.SecureString;
3637

3738
import java.util.Date;
@@ -48,55 +49,6 @@ public class JwtValidateUtil {
4849
null
4950
);
5051

51-
/**
52-
* Validate a SignedJWT. Use iss/aud/alg filters for those claims, JWKSet for signature, and skew seconds for time claims.
53-
* @param jwt Signed JWT to be validated.
54-
* @param allowedIssuer Filter for the "iss" claim.
55-
* @param allowedAudiences Filter for the "aud" claim.
56-
* @param allowedClockSkewSeconds Skew tolerance for the "auth_time", "iat", "nbf", and "exp" claims.
57-
* @param allowedSignatureAlgorithms Filter for the "aud" header.
58-
* @param jwks JWKs of HMAC secret keys or RSA/EC public keys.
59-
* @throws Exception Error for the first validation to fail.
60-
*/
61-
public static void validate(
62-
final SignedJWT jwt,
63-
final String allowedIssuer,
64-
final List<String> allowedAudiences,
65-
final long allowedClockSkewSeconds,
66-
final List<String> allowedSignatureAlgorithms,
67-
final List<JWK> jwks
68-
) throws Exception {
69-
final Date now = new Date();
70-
71-
if (LOGGER.isDebugEnabled()) {
72-
LOGGER.debug(
73-
"Validating JWT, now [{}], alg [{}], issuer [{}], audiences [{}], typ [{}],"
74-
+ " auth_time [{}], iat [{}], nbf [{}], exp [{}], kid [{}], jti [{}]",
75-
now,
76-
jwt.getHeader().getAlgorithm(),
77-
jwt.getJWTClaimsSet().getIssuer(),
78-
jwt.getJWTClaimsSet().getAudience(),
79-
jwt.getHeader().getType(),
80-
jwt.getJWTClaimsSet().getDateClaim("auth_time"),
81-
jwt.getJWTClaimsSet().getIssueTime(),
82-
jwt.getJWTClaimsSet().getNotBeforeTime(),
83-
jwt.getJWTClaimsSet().getExpirationTime(),
84-
jwt.getHeader().getKeyID(),
85-
jwt.getJWTClaimsSet().getJWTID()
86-
);
87-
}
88-
// validate claims before signature, because log messages about rejected claims can be more helpful than rejected signatures
89-
JwtValidateUtil.validateType(jwt);
90-
JwtValidateUtil.validateIssuer(jwt, allowedIssuer);
91-
JwtValidateUtil.validateAudiences(jwt, allowedAudiences);
92-
JwtValidateUtil.validateSignatureAlgorithm(jwt, allowedSignatureAlgorithms);
93-
JwtValidateUtil.validateAuthTime(jwt, now, allowedClockSkewSeconds);
94-
JwtValidateUtil.validateIssuedAtTime(jwt, now, allowedClockSkewSeconds);
95-
JwtValidateUtil.validateNotBeforeTime(jwt, now, allowedClockSkewSeconds);
96-
JwtValidateUtil.validateExpiredTime(jwt, now, allowedClockSkewSeconds);
97-
JwtValidateUtil.validateSignature(jwt, jwks);
98-
}
99-
10052
public static void validateType(final SignedJWT jwt) throws Exception {
10153
final JOSEObjectType jwtHeaderType = jwt.getHeader().getType();
10254
try {
@@ -277,7 +229,10 @@ static void validateExpiredTime(final Date exp, final Date now, final long allow
277229
* @throws Exception Error if JWKs fail to validate the Signed JWT.
278230
*/
279231
public static void validateSignature(final SignedJWT jwt, final List<JWK> jwks) throws Exception {
280-
assert jwks != null && jwks.isEmpty() == false : "Caller must provide a non-empty JWK list";
232+
assert jwks != null : "Verify requires a non-null JWK list";
233+
if (jwks.isEmpty()) {
234+
throw new ElasticsearchException("Verify requires a non-empty JWK list");
235+
}
281236
final String id = jwt.getHeader().getKeyID();
282237
final JWSAlgorithm alg = jwt.getHeader().getAlgorithm();
283238
LOGGER.trace("JWKs [{}], JWT KID [{}], and JWT Algorithm [{}] before filters.", jwks.size(), id, alg.getName());
@@ -305,12 +260,35 @@ public static void validateSignature(final SignedJWT jwt, final List<JWK> jwks)
305260
final List<JWK> jwksStrength = jwksAlg.stream().filter(j -> JwkValidateUtil.isMatch(j, alg.getName())).toList();
306261
LOGGER.debug("JWKs [{}] after Algorithm [{}] match filter.", jwksStrength.size(), alg);
307262

263+
// No JWKs passed the kid, alg, and strength checks, so nothing left to use in verifying the JWT signature
264+
if (jwksStrength.isEmpty()) {
265+
throw new ElasticsearchException("Verify failed because all " + jwks.size() + " provided JWKs were filtered.");
266+
}
267+
308268
for (final JWK jwk : jwksStrength) {
309269
if (jwt.verify(JwtValidateUtil.createJwsVerifier(jwk))) {
310-
return; // VERIFY SUCCEEDED
270+
LOGGER.trace(
271+
"JWT signature validation succeeded with JWK kty=[{}], jwtAlg=[{}], jwtKid=[{}], use=[{}], ops=[{}]",
272+
jwk.getKeyType(),
273+
jwk.getAlgorithm(),
274+
jwk.getKeyID(),
275+
jwk.getKeyUse(),
276+
jwk.getKeyOperations()
277+
);
278+
return;
279+
} else {
280+
LOGGER.trace(
281+
"JWT signature validation failed with JWK kty=[{}], jwtAlg=[{}], jwtKid=[{}], use=[{}], ops={}",
282+
jwk.getKeyType(),
283+
jwk.getAlgorithm(),
284+
jwk.getKeyID(),
285+
jwk.getKeyUse(),
286+
jwk.getKeyOperations() == null ? "[null]" : jwk.getKeyOperations()
287+
);
311288
}
312289
}
313-
throw new Exception("Verify failed using " + jwksStrength.size() + " of " + jwks.size() + " provided JWKs.");
290+
291+
throw new ElasticsearchException("Verify failed using " + jwksStrength.size() + " of " + jwks.size() + " provided JWKs.");
314292
}
315293

316294
public static JWSVerifier createJwsVerifier(final JWK jwk) throws JOSEException {

x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtilTests.java

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
import com.nimbusds.jose.JWSAlgorithm;
1111
import com.nimbusds.jose.jwk.JWK;
1212
import com.nimbusds.jose.jwk.OctetSequenceKey;
13-
import com.nimbusds.jose.util.Base64URL;
1413

1514
import org.apache.logging.log4j.LogManager;
1615
import org.apache.logging.log4j.Logger;
1716
import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
1817

1918
import java.nio.charset.StandardCharsets;
19+
import java.util.Arrays;
20+
import java.util.Collection;
2021
import java.util.List;
2122

2223
import static org.hamcrest.Matchers.anyOf;
@@ -27,46 +28,27 @@ public class JwkValidateUtilTests extends JwtTestCase {
2728

2829
private static final Logger LOGGER = LogManager.getLogger(JwkValidateUtilTests.class);
2930

30-
// HMAC JWKSet setting can use keys from randomJwkHmac()
31-
// HMAC key setting cannot use randomJwkHmac(), it must use randomJwkHmacString()
32-
public void testConvertHmacJwkToStringToJwk() throws Exception {
33-
final JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC));
34-
35-
// Use HMAC random bytes for OIDC JWKSet setting only. Demonstrate encode/decode fails if used in OIDC HMAC key setting.
36-
final OctetSequenceKey hmacKeyRandomBytes = JwtTestCase.randomJwkHmac(jwsAlgorithm);
37-
assertThat(this.hmacEncodeDecodeAsPasswordTestHelper(hmacKeyRandomBytes), is(false));
38-
39-
// Convert HMAC random bytes to UTF8 bytes. This makes it usable as an OIDC HMAC key setting.
40-
final OctetSequenceKey hmacKeyString1 = JwtTestCase.conditionJwkHmacForOidc(hmacKeyRandomBytes);
41-
assertThat(this.hmacEncodeDecodeAsPasswordTestHelper(hmacKeyString1), is(true));
42-
43-
// Generate HMAC UTF8 bytes. This is usable as an OIDC HMAC key setting.
44-
final OctetSequenceKey hmacKeyString2 = JwtTestCase.randomJwkHmacOidc(jwsAlgorithm);
45-
assertThat(this.hmacEncodeDecodeAsPasswordTestHelper(hmacKeyString2), is(true));
31+
// Test decode bytes as UTF8 to String, encode back to UTF8, and compare to original bytes. If same, it is safe for OIDC JWK encode.
32+
static boolean isJwkHmacOidcSafe(final JWK jwk) {
33+
if (jwk instanceof OctetSequenceKey jwkHmac) {
34+
final byte[] rawKeyBytes = jwkHmac.getKeyValue().decode();
35+
return Arrays.equals(rawKeyBytes, new String(rawKeyBytes, StandardCharsets.UTF_8).getBytes(StandardCharsets.UTF_8));
36+
}
37+
return true;
4638
}
4739

48-
private boolean hmacEncodeDecodeAsPasswordTestHelper(final OctetSequenceKey hmacKey) {
49-
final OctetSequenceKey hmacKeyNoAttributes = JwtTestCase.jwkHmacRemoveAttributes(hmacKey);
50-
// Encode input key as Base64(keyBytes) and Utf8String(keyBytes)
51-
final String keyBytesToBase64 = hmacKey.getKeyValue().toString();
52-
final String keyBytesAsUtf8 = hmacKey.getKeyValue().decodeToString();
53-
54-
// Decode Base64(keyBytes) into new key and compare to original. This always works.
55-
final OctetSequenceKey decodeFromBase64 = new OctetSequenceKey.Builder(new Base64URL(keyBytesToBase64)).build();
56-
LOGGER.info("Base64 enc/dec test:\ngen: [" + hmacKey + "]\nenc: [" + keyBytesToBase64 + "]\ndec: [" + decodeFromBase64 + "]\n");
57-
if (decodeFromBase64.equals(hmacKeyNoAttributes) == false) {
58-
return false;
40+
static boolean areJwkHmacOidcSafe(final Collection<JWK> jwks) {
41+
for (final JWK jwk : jwks) {
42+
if (JwkValidateUtilTests.isJwkHmacOidcSafe(jwk) == false) {
43+
return false;
44+
}
5945
}
60-
61-
// Decode Utf8String(keyBytes) into new key and compare to original. Only works for randomJwkHmacString, fails for randomJwkHmac.
62-
final OctetSequenceKey decodeFromUtf8 = new OctetSequenceKey.Builder(keyBytesAsUtf8.getBytes(StandardCharsets.UTF_8)).build();
63-
LOGGER.info("UTF8 enc/dec test:\ngen: [" + hmacKey + "]\nenc: [" + keyBytesAsUtf8 + "]\ndec: [" + decodeFromUtf8 + "]\n");
64-
return decodeFromUtf8.equals(hmacKeyNoAttributes);
46+
return true;
6547
}
6648

6749
public void testComputeBitLengthRsa() throws Exception {
6850
for (final String signatureAlgorithmRsa : JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_RSA) {
69-
final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithmRsa);
51+
final JWK jwk = JwtTestCase.randomJwkRsa(JWSAlgorithm.parse(signatureAlgorithmRsa));
7052
final int minLength = JwkValidateUtil.computeBitLengthRsa(jwk.toRSAKey().toPublicKey());
7153
assertThat(minLength, is(anyOf(equalTo(2048), equalTo(3072))));
7254
}
@@ -86,7 +68,7 @@ public void testAlgsJwksAllPkcNotFiltered() throws Exception {
8668

8769
private void filterJwksAndAlgorithmsTestHelper(final List<String> candidateAlgs) throws JOSEException {
8870
final List<String> algsRandom = randomOfMinUnique(2, candidateAlgs); // duplicates allowed
89-
final List<JwtIssuer.AlgJwkPair> algJwkPairsAll = JwtTestCase.randomJwks(algsRandom);
71+
final List<JwtIssuer.AlgJwkPair> algJwkPairsAll = JwtTestCase.randomJwks(algsRandom, randomBoolean());
9072
final List<JWK> jwks = algJwkPairsAll.stream().map(JwtIssuer.AlgJwkPair::jwk).toList();
9173
final List<String> algsAll = algJwkPairsAll.stream().map(JwtIssuer.AlgJwkPair::alg).toList();
9274
final List<JWK> jwksAll = algJwkPairsAll.stream().map(JwtIssuer.AlgJwkPair::jwk).toList();

0 commit comments

Comments
 (0)