1212import java .security .PublicKey ;
1313import java .security .spec .RSAPublicKeySpec ;
1414import java .util .Base64 ;
15+ import java .util .Map ;
16+ import java .util .concurrent .ConcurrentHashMap ;
1517
1618import org .keycloak .TokenVerifier ;
1719import org .keycloak .common .VerificationException ;
@@ -32,91 +34,131 @@ public class OauthTokenManager {
3234 private final String host ;
3335 private final String realm ;
3436
35- private String authUrl ;
36- private PublicKey publicKey = null ;
37+ private String jwksUrl ;
38+ private final Map <String , PublicKey > publicKeysByKid = new ConcurrentHashMap <>();
39+ private volatile long lastFetchTimestamp = 0L ;
3740
38- public void initPublicKey () {
39- String correctedHost = host ;
40- String correctedRealm = realm ;
41+ private static final long REFRESH_INTERVAL_MS = 6 * 60 * 60 * 1000 ; // 6 hours cache-validity
4142
42- if (publicKey != null )
43- return ;
44- if (!correctedHost .endsWith ("/" ))
45- correctedHost += "/" ;
46- if (!correctedRealm .startsWith ("/" ))
47- correctedRealm = "/" + correctedRealm ;
43+ public synchronized void initPublicKeys () {
44+ String correctedHost = host .endsWith ("/" ) ? host : host + "/" ;
45+ String correctedRealm = realm .startsWith ("/" ) ? realm .substring (1 ) : realm ;
46+ jwksUrl = correctedHost + "realms/" + correctedRealm + "/protocol/openid-connect/certs" ;
4847
49- authUrl = correctedHost + "realms" + correctedRealm + "/protocol/openid-connect/certs" ;
5048 try {
51- log .info ("Getting public key from: [{}]" , authUrl );
52- publicKey = fetchPublicKey (authUrl );
49+ log .info ("Fetching JWKS from [{}]" , jwksUrl );
50+ ObjectMapper om = new ObjectMapper ();
51+ HttpClient client = HttpClient .newHttpClient ();
52+ HttpRequest req = HttpRequest .newBuilder ().uri (URI .create (jwksUrl )).GET ().build ();
53+ HttpResponse <String > res = client .send (req , HttpResponse .BodyHandlers .ofString ());
54+ if (res .statusCode () >= 300 )
55+ throw new IOException ("Failed to fetch JWKS: HTTP " + res .statusCode ());
56+
57+ JsonNode jwks = om .readTree (res .body ());
58+ Map <String , PublicKey > newMap = new ConcurrentHashMap <>();
59+
60+ for (JsonNode key : jwks .withArray ("keys" )) {
61+ if (!key .has ("kid" ) || !key .has ("n" ) || !key .has ("e" ))
62+ continue ;
63+ String kid = key .get ("kid" ).asText ();
64+ String n = key .get ("n" ).asText ();
65+ String e = key .get ("e" ).asText ();
66+
67+ BigInteger modulus = new BigInteger (1 , Base64 .getUrlDecoder ().decode (n ));
68+ BigInteger exponent = new BigInteger (1 , Base64 .getUrlDecoder ().decode (e ));
69+
70+ RSAPublicKeySpec spec = new RSAPublicKeySpec (modulus , exponent );
71+ PublicKey pk = KeyFactory .getInstance ("RSA" ).generatePublic (spec );
72+ newMap .put (kid , pk );
73+ }
74+
75+ publicKeysByKid .clear ();
76+ publicKeysByKid .putAll (newMap );
77+ lastFetchTimestamp = System .currentTimeMillis ();
78+
79+ log .info ("Loaded {} JWKS keys from {} (kids={})" , newMap .size (), jwksUrl , newMap .keySet ());
5380 } catch (Exception e ) {
54- log .error ("There was an error fetching the PublicKey from the openIdConnect-server [{}]. " , authUrl );
55- throw new IllegalStateException (e );
81+ log .error ("Failed to fetch JWKS keys from [{}]" , jwksUrl , e );
82+ throw new IllegalStateException ("Could not load JWKS from " + jwksUrl , e );
5683 }
5784 }
5885
59- private PublicKey fetchPublicKey (String jwksUrl ) throws Exception {
60- ObjectMapper objectMapper = new ObjectMapper ();
61- HttpClient client = HttpClient .newHttpClient ();
62- HttpRequest request = HttpRequest .newBuilder ().uri (URI .create (jwksUrl )).GET ().build ();
63-
64- HttpResponse <String > response = client .send (request , HttpResponse .BodyHandlers .ofString ());
65-
66- if (response .statusCode () >= 300 ) {
67- throw new IOException ("Failed to fetch JWKS: HTTP " + response .statusCode ());
86+ public String extractKidFromJwt (String jwt ) {
87+ try {
88+ String [] parts = jwt .split ("\\ ." );
89+ if (parts .length < 2 )
90+ return null ;
91+ String headerJson = new String (Base64 .getUrlDecoder ().decode (parts [0 ]), StandardCharsets .UTF_8 );
92+ JsonNode node = new ObjectMapper ().readTree (headerJson );
93+ return node .has ("kid" ) ? node .get ("kid" ).asText () : null ;
94+ } catch (Exception e ) {
95+ return null ;
6896 }
97+ }
6998
70- JsonNode jwks = objectMapper . readTree ( response . body ());
71- // Just take the first key for now.
72- JsonNode key = jwks . get ( "keys" ). get ( 0 );
73-
74- String modulusBase64 = key .get ("n" ). asText ( );
75- String exponentBase64 = key . get ( "e" ). asText ();
76-
77- byte [] modulusBytes = Base64 . getUrlDecoder (). decode ( modulusBase64 );
78- byte [] exponentBytes = Base64 . getUrlDecoder (). decode ( exponentBase64 );
79-
80- BigInteger modulus = new BigInteger ( 1 , modulusBytes );
81- BigInteger exponent = new BigInteger ( 1 , exponentBytes );
82-
83- RSAPublicKeySpec spec = new RSAPublicKeySpec ( modulus , exponent );
84- KeyFactory factory = KeyFactory . getInstance ( "RSA" );
85- return factory . generatePublic ( spec ) ;
99+ public PublicKey getKeyForKid ( String kid ) {
100+ if ( publicKeysByKid . isEmpty () || System . currentTimeMillis () - lastFetchTimestamp > REFRESH_INTERVAL_MS )
101+ initPublicKeys ( );
102+
103+ PublicKey pk = publicKeysByKid .get (kid );
104+ if ( pk == null ) {
105+ log . warn ( "No cached key for kid='{}'. Refreshing JWKS..." , kid );
106+ initPublicKeys ( );
107+ pk = publicKeysByKid . get ( kid );
108+ if ( pk == null ) {
109+ log . error ( "JWKS refresh did not contain kid='{}'. Possible misconfiguration or key rotation issue." ,
110+ kid );
111+ throw new UnauthorizedException ( "Unknown key ID: " + kid );
112+ }
113+ }
114+ return pk ;
86115 }
87116
88- /**
89- * Checks the access token and verifies its signature. If the token is valid,
90- * returns a tenantId.
91- *
92- * @param accessToken
93- * @return tenantId or null if the token is invalid or not present.
94- */
95117 public String checkAccess (String accessToken ) {
96118 try {
97119 TokenVerifier <AccessToken > tokenVerifier = persistUserInfoInContext (accessToken );
98120 if (tokenVerifier == null )
99- throw new UnauthorizedException ();
121+ throw new UnauthorizedException ("Token could not be parsed." );
100122
101- initPublicKey ();
102- tokenVerifier .publicKey (publicKey );
103- try {
104- tokenVerifier .verifySignature ();
105- } catch (VerificationException e ) {
106- throw new UnauthorizedException (
107- "Error verifying token from user with publicKey obtained from keycloak." , e );
108- }
123+ String rawJwt = accessToken .startsWith ("Bearer " ) ? accessToken .substring (7 ) : accessToken ;
124+ String kid = extractKidFromJwt (rawJwt );
125+ if (kid == null )
126+ throw new UnauthorizedException ("Token has no 'kid' header." );
127+
128+ PublicKey pk = getKeyForKid (kid );
109129
110130 try {
131+ tokenVerifier .publicKey (pk );
132+ tokenVerifier .verifySignature ();
111133 tokenVerifier .verify ();
112- AccessToken token = tokenVerifier .getToken ();
113- return (String ) token .getOtherClaims ().get ("tenants_read" );
114134 } catch (VerificationException e ) {
115- throw new ForbiddenException ();
135+ // Retry once after forced JWKS refresh
136+ log .warn ("Signature verification failed for kid='{}'. Retrying after JWKS refresh." , kid );
137+ initPublicKeys ();
138+ PublicKey refreshedPk = publicKeysByKid .get (kid );
139+ if (refreshedPk == null ) {
140+ log .error ("Token verification failed after refresh. kid='{}' unknown." , kid );
141+ throw new UnauthorizedException ("Invalid token signature. kid=" + kid , e );
142+ }
143+ try {
144+ tokenVerifier .publicKey (refreshedPk );
145+ tokenVerifier .verifySignature ();
146+ tokenVerifier .verify ();
147+ } catch (VerificationException e2 ) {
148+ throw new UnauthorizedException ("Token signature invalid after refresh (kid=" + kid + ")" , e2 );
149+ }
116150 }
151+
152+ AccessToken token = tokenVerifier .getToken ();
153+ return (String ) token .getOtherClaims ().get ("tenants_read" );
154+
155+ } catch (VerificationException e ) {
156+ throw new UnauthorizedException ("Token verification failed." , e );
157+ } catch (UnauthorizedException | ForbiddenException e ) {
158+ throw e ;
117159 } catch (Exception e ) {
118160 log .error ("Error checking token." , e );
119- throw e ;
161+ throw new UnauthorizedException ( "Error verifying token: " + e . getMessage (), e ) ;
120162 }
121163 }
122164
@@ -129,20 +171,12 @@ private TokenVerifier<AccessToken> persistUserInfoInContext(String authorization
129171
130172 try {
131173 TokenVerifier <AccessToken > tokenVerifier = TokenVerifier .create (authorizationHeader , AccessToken .class );
132- RemoteOauthToken remoteAccessToken = RemoteOauthToken .builder ()
133- .accessToken (tokenVerifier .getToken ())
134- .build ();
135- if (!remoteAccessToken .getAccessToken ().isActive ()) {
136- log .warn ("Token is inactive." );
174+ AccessToken token = tokenVerifier .getToken ();
175+ if (token == null || !token .isActive ()) {
176+ log .warn ("Token is inactive or null." );
137177 return null ;
138178 }
139- // Disabled to enable getting token from side-channels like 'localhost'.
140- /*
141- * if (!remoteAccessToken.getIssuer().equalsIgnoreCase(authUrl)) {
142- * log.warn("Token has wrong real-url."); return null; }
143- */
144179 return tokenVerifier ;
145-
146180 } catch (VerificationException e ) {
147181 log .warn ("Token was checked and deemed invalid." , e );
148182 return null ;
@@ -156,37 +190,29 @@ public LocalOauthTokens getTokensFromCredentials(String clientId, String usernam
156190 public LocalOauthTokens getTokensFromCredentials (String clientId , String clientSecret , String username ,
157191 String password ) {
158192 try {
159- String tokenEndpoint = host ;
160- if (!tokenEndpoint .endsWith ("/" ))
161- tokenEndpoint += "/" ;
193+ String tokenEndpoint = host .endsWith ("/" ) ? host : host + "/" ;
162194 tokenEndpoint += "realms/" + realm + "/protocol/openid-connect/token" ;
163195
164196 String form = "grant_type=password" + "&client_id=" + URLEncoder .encode (clientId , StandardCharsets .UTF_8 )
165197 + "&username=" + URLEncoder .encode (username , StandardCharsets .UTF_8 ) + "&password="
166198 + URLEncoder .encode (password , StandardCharsets .UTF_8 );
167- if (clientSecret != null ) {
199+ if (clientSecret != null )
168200 form += "&client_secret=" + URLEncoder .encode (clientSecret , StandardCharsets .UTF_8 );
169- }
170201
171- HttpRequest request = HttpRequest .newBuilder ()
202+ HttpRequest req = HttpRequest .newBuilder ()
172203 .uri (URI .create (tokenEndpoint ))
173204 .header ("Content-Type" , "application/x-www-form-urlencoded" )
174205 .POST (HttpRequest .BodyPublishers .ofString (form ))
175206 .build ();
176207
177208 HttpClient client = HttpClient .newHttpClient ();
178- HttpResponse <String > response = client .send (request , HttpResponse .BodyHandlers .ofString ());
179-
180- if (response .statusCode () >= 300 ) {
181- throw new IOException ("Token request failed: HTTP " + response .statusCode () + " - " + response .body ());
182- }
209+ HttpResponse <String > res = client .send (req , HttpResponse .BodyHandlers .ofString ());
210+ if (res .statusCode () >= 300 )
211+ throw new IOException ("Token request failed: HTTP " + res .statusCode () + " - " + res .body ());
183212
184213 ObjectMapper mapper = new ObjectMapper ();
185- JsonNode json = mapper .readTree (response .body ());
214+ JsonNode json = mapper .readTree (res .body ());
186215 log .info ("Token received successfully." );
187- log .debug ("Access token: {}" , json .get ("access_token" ).asText ());
188- log .debug ("Refresh token: {}" , json .get ("refresh_token" ).asText ());
189-
190216 return LocalOauthTokens .builder ()
191217 .accessToken (json .get ("access_token" ).asText ())
192218 .refreshToken (json .get ("refresh_token" ).asText ())
@@ -197,5 +223,4 @@ public LocalOauthTokens getTokensFromCredentials(String clientId, String clientS
197223 throw new IllegalStateException ("Unable to get token" , e );
198224 }
199225 }
200-
201226}
0 commit comments