Skip to content

Commit ffca18a

Browse files
committed
fix JWKS KID selection
1 parent 9a13a6f commit ffca18a

File tree

4 files changed

+266
-89
lines changed

4 files changed

+266
-89
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<modelVersion>4.0.0</modelVersion>
1212
<artifactId>oauth-token-manager</artifactId>
13-
<version>1.0.10</version>
13+
<version>1.0.11</version>
1414
<name>OauthTokenManager</name>
1515
<packaging>jar</packaging>
1616

src/main/java/info/unterrainer/oauthtokenmanager/OauthTokenManager.java

Lines changed: 113 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import java.security.PublicKey;
1313
import java.security.spec.RSAPublicKeySpec;
1414
import java.util.Base64;
15+
import java.util.Map;
16+
import java.util.concurrent.ConcurrentHashMap;
1517

1618
import org.keycloak.TokenVerifier;
1719
import org.keycloak.common.VerificationException;
@@ -32,91 +34,131 @@ public class OauthTokenManager {
3234
private final String host;
3335
private final String realm;
3436

35-
private String authUrl;
36-
private PublicKey publicKey = null;
37+
private String jwksUrl;
38+
private final Map<String, PublicKey> publicKeysByKid = new ConcurrentHashMap<>();
39+
private volatile long lastFetchTimestamp = 0L;
3740

38-
public void initPublicKey() {
39-
String correctedHost = host;
40-
String correctedRealm = realm;
41+
private static final long REFRESH_INTERVAL_MS = 6 * 60 * 60 * 1000; // 6 hours cache-validity
4142

42-
if (publicKey != null)
43-
return;
44-
if (!correctedHost.endsWith("/"))
45-
correctedHost += "/";
46-
if (!correctedRealm.startsWith("/"))
47-
correctedRealm = "/" + correctedRealm;
43+
public synchronized void initPublicKeys() {
44+
String correctedHost = host.endsWith("/") ? host : host + "/";
45+
String correctedRealm = realm.startsWith("/") ? realm.substring(1) : realm;
46+
jwksUrl = correctedHost + "realms/" + correctedRealm + "/protocol/openid-connect/certs";
4847

49-
authUrl = correctedHost + "realms" + correctedRealm + "/protocol/openid-connect/certs";
5048
try {
51-
log.info("Getting public key from: [{}]", authUrl);
52-
publicKey = fetchPublicKey(authUrl);
49+
log.info("Fetching JWKS from [{}]", jwksUrl);
50+
ObjectMapper om = new ObjectMapper();
51+
HttpClient client = HttpClient.newHttpClient();
52+
HttpRequest req = HttpRequest.newBuilder().uri(URI.create(jwksUrl)).GET().build();
53+
HttpResponse<String> res = client.send(req, HttpResponse.BodyHandlers.ofString());
54+
if (res.statusCode() >= 300)
55+
throw new IOException("Failed to fetch JWKS: HTTP " + res.statusCode());
56+
57+
JsonNode jwks = om.readTree(res.body());
58+
Map<String, PublicKey> newMap = new ConcurrentHashMap<>();
59+
60+
for (JsonNode key : jwks.withArray("keys")) {
61+
if (!key.has("kid") || !key.has("n") || !key.has("e"))
62+
continue;
63+
String kid = key.get("kid").asText();
64+
String n = key.get("n").asText();
65+
String e = key.get("e").asText();
66+
67+
BigInteger modulus = new BigInteger(1, Base64.getUrlDecoder().decode(n));
68+
BigInteger exponent = new BigInteger(1, Base64.getUrlDecoder().decode(e));
69+
70+
RSAPublicKeySpec spec = new RSAPublicKeySpec(modulus, exponent);
71+
PublicKey pk = KeyFactory.getInstance("RSA").generatePublic(spec);
72+
newMap.put(kid, pk);
73+
}
74+
75+
publicKeysByKid.clear();
76+
publicKeysByKid.putAll(newMap);
77+
lastFetchTimestamp = System.currentTimeMillis();
78+
79+
log.info("Loaded {} JWKS keys from {} (kids={})", newMap.size(), jwksUrl, newMap.keySet());
5380
} catch (Exception e) {
54-
log.error("There was an error fetching the PublicKey from the openIdConnect-server [{}].", authUrl);
55-
throw new IllegalStateException(e);
81+
log.error("Failed to fetch JWKS keys from [{}]", jwksUrl, e);
82+
throw new IllegalStateException("Could not load JWKS from " + jwksUrl, e);
5683
}
5784
}
5885

59-
private PublicKey fetchPublicKey(String jwksUrl) throws Exception {
60-
ObjectMapper objectMapper = new ObjectMapper();
61-
HttpClient client = HttpClient.newHttpClient();
62-
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(jwksUrl)).GET().build();
63-
64-
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
65-
66-
if (response.statusCode() >= 300) {
67-
throw new IOException("Failed to fetch JWKS: HTTP " + response.statusCode());
86+
public String extractKidFromJwt(String jwt) {
87+
try {
88+
String[] parts = jwt.split("\\.");
89+
if (parts.length < 2)
90+
return null;
91+
String headerJson = new String(Base64.getUrlDecoder().decode(parts[0]), StandardCharsets.UTF_8);
92+
JsonNode node = new ObjectMapper().readTree(headerJson);
93+
return node.has("kid") ? node.get("kid").asText() : null;
94+
} catch (Exception e) {
95+
return null;
6896
}
97+
}
6998

70-
JsonNode jwks = objectMapper.readTree(response.body());
71-
// Just take the first key for now.
72-
JsonNode key = jwks.get("keys").get(0);
73-
74-
String modulusBase64 = key.get("n").asText();
75-
String exponentBase64 = key.get("e").asText();
76-
77-
byte[] modulusBytes = Base64.getUrlDecoder().decode(modulusBase64);
78-
byte[] exponentBytes = Base64.getUrlDecoder().decode(exponentBase64);
79-
80-
BigInteger modulus = new BigInteger(1, modulusBytes);
81-
BigInteger exponent = new BigInteger(1, exponentBytes);
82-
83-
RSAPublicKeySpec spec = new RSAPublicKeySpec(modulus, exponent);
84-
KeyFactory factory = KeyFactory.getInstance("RSA");
85-
return factory.generatePublic(spec);
99+
public PublicKey getKeyForKid(String kid) {
100+
if (publicKeysByKid.isEmpty() || System.currentTimeMillis() - lastFetchTimestamp > REFRESH_INTERVAL_MS)
101+
initPublicKeys();
102+
103+
PublicKey pk = publicKeysByKid.get(kid);
104+
if (pk == null) {
105+
log.warn("No cached key for kid='{}'. Refreshing JWKS...", kid);
106+
initPublicKeys();
107+
pk = publicKeysByKid.get(kid);
108+
if (pk == null) {
109+
log.error("JWKS refresh did not contain kid='{}'. Possible misconfiguration or key rotation issue.",
110+
kid);
111+
throw new UnauthorizedException("Unknown key ID: " + kid);
112+
}
113+
}
114+
return pk;
86115
}
87116

88-
/**
89-
* Checks the access token and verifies its signature. If the token is valid,
90-
* returns a tenantId.
91-
*
92-
* @param accessToken
93-
* @return tenantId or null if the token is invalid or not present.
94-
*/
95117
public String checkAccess(String accessToken) {
96118
try {
97119
TokenVerifier<AccessToken> tokenVerifier = persistUserInfoInContext(accessToken);
98120
if (tokenVerifier == null)
99-
throw new UnauthorizedException();
121+
throw new UnauthorizedException("Token could not be parsed.");
100122

101-
initPublicKey();
102-
tokenVerifier.publicKey(publicKey);
103-
try {
104-
tokenVerifier.verifySignature();
105-
} catch (VerificationException e) {
106-
throw new UnauthorizedException(
107-
"Error verifying token from user with publicKey obtained from keycloak.", e);
108-
}
123+
String rawJwt = accessToken.startsWith("Bearer ") ? accessToken.substring(7) : accessToken;
124+
String kid = extractKidFromJwt(rawJwt);
125+
if (kid == null)
126+
throw new UnauthorizedException("Token has no 'kid' header.");
127+
128+
PublicKey pk = getKeyForKid(kid);
109129

110130
try {
131+
tokenVerifier.publicKey(pk);
132+
tokenVerifier.verifySignature();
111133
tokenVerifier.verify();
112-
AccessToken token = tokenVerifier.getToken();
113-
return (String) token.getOtherClaims().get("tenants_read");
114134
} catch (VerificationException e) {
115-
throw new ForbiddenException();
135+
// Retry once after forced JWKS refresh
136+
log.warn("Signature verification failed for kid='{}'. Retrying after JWKS refresh.", kid);
137+
initPublicKeys();
138+
PublicKey refreshedPk = publicKeysByKid.get(kid);
139+
if (refreshedPk == null) {
140+
log.error("Token verification failed after refresh. kid='{}' unknown.", kid);
141+
throw new UnauthorizedException("Invalid token signature. kid=" + kid, e);
142+
}
143+
try {
144+
tokenVerifier.publicKey(refreshedPk);
145+
tokenVerifier.verifySignature();
146+
tokenVerifier.verify();
147+
} catch (VerificationException e2) {
148+
throw new UnauthorizedException("Token signature invalid after refresh (kid=" + kid + ")", e2);
149+
}
116150
}
151+
152+
AccessToken token = tokenVerifier.getToken();
153+
return (String) token.getOtherClaims().get("tenants_read");
154+
155+
} catch (VerificationException e) {
156+
throw new UnauthorizedException("Token verification failed.", e);
157+
} catch (UnauthorizedException | ForbiddenException e) {
158+
throw e;
117159
} catch (Exception e) {
118160
log.error("Error checking token.", e);
119-
throw e;
161+
throw new UnauthorizedException("Error verifying token: " + e.getMessage(), e);
120162
}
121163
}
122164

@@ -129,20 +171,12 @@ private TokenVerifier<AccessToken> persistUserInfoInContext(String authorization
129171

130172
try {
131173
TokenVerifier<AccessToken> tokenVerifier = TokenVerifier.create(authorizationHeader, AccessToken.class);
132-
RemoteOauthToken remoteAccessToken = RemoteOauthToken.builder()
133-
.accessToken(tokenVerifier.getToken())
134-
.build();
135-
if (!remoteAccessToken.getAccessToken().isActive()) {
136-
log.warn("Token is inactive.");
174+
AccessToken token = tokenVerifier.getToken();
175+
if (token == null || !token.isActive()) {
176+
log.warn("Token is inactive or null.");
137177
return null;
138178
}
139-
// Disabled to enable getting token from side-channels like 'localhost'.
140-
/*
141-
* if (!remoteAccessToken.getIssuer().equalsIgnoreCase(authUrl)) {
142-
* log.warn("Token has wrong real-url."); return null; }
143-
*/
144179
return tokenVerifier;
145-
146180
} catch (VerificationException e) {
147181
log.warn("Token was checked and deemed invalid.", e);
148182
return null;
@@ -156,37 +190,29 @@ public LocalOauthTokens getTokensFromCredentials(String clientId, String usernam
156190
public LocalOauthTokens getTokensFromCredentials(String clientId, String clientSecret, String username,
157191
String password) {
158192
try {
159-
String tokenEndpoint = host;
160-
if (!tokenEndpoint.endsWith("/"))
161-
tokenEndpoint += "/";
193+
String tokenEndpoint = host.endsWith("/") ? host : host + "/";
162194
tokenEndpoint += "realms/" + realm + "/protocol/openid-connect/token";
163195

164196
String form = "grant_type=password" + "&client_id=" + URLEncoder.encode(clientId, StandardCharsets.UTF_8)
165197
+ "&username=" + URLEncoder.encode(username, StandardCharsets.UTF_8) + "&password="
166198
+ URLEncoder.encode(password, StandardCharsets.UTF_8);
167-
if (clientSecret != null) {
199+
if (clientSecret != null)
168200
form += "&client_secret=" + URLEncoder.encode(clientSecret, StandardCharsets.UTF_8);
169-
}
170201

171-
HttpRequest request = HttpRequest.newBuilder()
202+
HttpRequest req = HttpRequest.newBuilder()
172203
.uri(URI.create(tokenEndpoint))
173204
.header("Content-Type", "application/x-www-form-urlencoded")
174205
.POST(HttpRequest.BodyPublishers.ofString(form))
175206
.build();
176207

177208
HttpClient client = HttpClient.newHttpClient();
178-
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
179-
180-
if (response.statusCode() >= 300) {
181-
throw new IOException("Token request failed: HTTP " + response.statusCode() + " - " + response.body());
182-
}
209+
HttpResponse<String> res = client.send(req, HttpResponse.BodyHandlers.ofString());
210+
if (res.statusCode() >= 300)
211+
throw new IOException("Token request failed: HTTP " + res.statusCode() + " - " + res.body());
183212

184213
ObjectMapper mapper = new ObjectMapper();
185-
JsonNode json = mapper.readTree(response.body());
214+
JsonNode json = mapper.readTree(res.body());
186215
log.info("Token received successfully.");
187-
log.debug("Access token: {}", json.get("access_token").asText());
188-
log.debug("Refresh token: {}", json.get("refresh_token").asText());
189-
190216
return LocalOauthTokens.builder()
191217
.accessToken(json.get("access_token").asText())
192218
.refreshToken(json.get("refresh_token").asText())
@@ -197,5 +223,4 @@ public LocalOauthTokens getTokensFromCredentials(String clientId, String clientS
197223
throw new IllegalStateException("Unable to get token", e);
198224
}
199225
}
200-
201226
}

0 commit comments

Comments
 (0)