Skip to content

Commit 01ea825

Browse files
authored
Merge pull request #1250 from modelix/MODELIX-1062-Reload-RSA-keys-from-disk
fix(authorization): reload keys when the file changes
2 parents 9cd9974 + 2b80425 commit 01ea825

File tree

2 files changed

+259
-71
lines changed

2 files changed

+259
-71
lines changed

authorization/src/main/kotlin/org/modelix/authorization/ModelixJWTUtil.kt

Lines changed: 76 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@ import com.nimbusds.jose.crypto.MACSigner
1010
import com.nimbusds.jose.crypto.RSASSASigner
1111
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory
1212
import com.nimbusds.jose.jwk.JWK
13+
import com.nimbusds.jose.jwk.JWKMatcher
14+
import com.nimbusds.jose.jwk.JWKSelector
1315
import com.nimbusds.jose.jwk.JWKSet
1416
import com.nimbusds.jose.jwk.KeyType
1517
import com.nimbusds.jose.jwk.KeyUse
1618
import com.nimbusds.jose.jwk.RSAKey
1719
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator
1820
import com.nimbusds.jose.jwk.source.ImmutableJWKSet
21+
import com.nimbusds.jose.jwk.source.JWKSource
1922
import com.nimbusds.jose.jwk.source.RemoteJWKSet
2023
import com.nimbusds.jose.proc.BadJOSEException
2124
import com.nimbusds.jose.proc.JWSAlgorithmFamilyJWSKeySelector
@@ -47,32 +50,26 @@ import java.util.Base64
4750
import java.util.Date
4851
import java.util.UUID
4952
import javax.crypto.spec.SecretKeySpec
53+
import kotlin.time.Duration
54+
import kotlin.time.Duration.Companion.seconds
5055

5156
class ModelixJWTUtil {
52-
private var hmacKeys = LinkedHashMap<JWSAlgorithm, ByteArray>()
53-
private var rsaPrivateKey: JWK? = null
54-
private var rsaPublicKeys = ArrayList<JWK>()
55-
private val jwksUrls = LinkedHashSet<URL>()
57+
private val hmacKeys = LinkedHashMap<JWSAlgorithm, ByteArray>()
58+
private val jwkSources = ArrayList<JWKSource<SecurityContext>>()
5659
private var expectedKeyId: String? = null
5760
private var ktorClient: HttpClient? = null
5861
var accessControlDataProvider: IAccessControlDataProvider = EmptyAccessControlDataProvider()
5962

6063
private var jwtProcessor: JWTProcessor<SecurityContext>? = null
64+
var fileRefreshTime: Duration = 5.seconds
6165

6266
@Synchronized
6367
private fun getOrCreateJwtProcessor(): JWTProcessor<SecurityContext> {
6468
return jwtProcessor ?: DefaultJWTProcessor<SecurityContext>().also { processor ->
6569
val keySelectors: List<JWSKeySelector<SecurityContext>> = hmacKeys.map { it.toPair() }.map {
6670
SingleKeyJWSKeySelector<SecurityContext>(it.first, SecretKeySpec(it.second, it.first.name))
67-
} + jwksUrls.map {
68-
val client = this.ktorClient
69-
if (client == null) {
70-
JWSAlgorithmFamilyJWSKeySelector.fromJWKSetURL<SecurityContext>(it)
71-
} else {
72-
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(RemoteJWKSet(it, KtorResourceRetriever(client)))
73-
}
74-
} + rsaPublicKeys.map {
75-
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(ImmutableJWKSet(JWKSet(it.toPublicJWK())))
71+
} + jwkSources.map {
72+
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(it)
7673
}
7774

7875
processor.jwsKeySelector = if (keySelectors.size == 1) keySelectors.single() else CompositeJWSKeySelector(keySelectors)
@@ -91,13 +88,19 @@ class ModelixJWTUtil {
9188
}.also { jwtProcessor = it }
9289
}
9390

94-
private fun resetJwtProcess() {
91+
fun getPrivateKey(): JWK? {
92+
return jwkSources.flatMap {
93+
it.get(JWKSelector(JWKMatcher.Builder().privateOnly(true).algorithms(JWSAlgorithm.Family.RSA.toSet()).build()), null)
94+
}.firstOrNull()
95+
}
96+
97+
private fun resetJwtProcessor() {
9598
jwtProcessor = null
9699
}
97100

98101
@Synchronized
99102
fun canVerifyTokens(): Boolean {
100-
return hmacKeys.isNotEmpty() || rsaPublicKeys.isNotEmpty() || jwksUrls.isNotEmpty()
103+
return hmacKeys.isNotEmpty() || jwkSources.isNotEmpty()
101104
}
102105

103106
/**
@@ -110,7 +113,7 @@ class ModelixJWTUtil {
110113

111114
@Synchronized
112115
fun useKtorClient(client: HttpClient) {
113-
resetJwtProcess()
116+
resetJwtProcessor()
114117
this.ktorClient = client.config {
115118
expectSuccess = true
116119
}
@@ -123,8 +126,8 @@ class ModelixJWTUtil {
123126

124127
@Synchronized
125128
fun addJwksUrl(url: URL) {
126-
resetJwtProcess()
127-
jwksUrls += url
129+
resetJwtProcessor()
130+
jwkSources.add(RemoteJWKSet(url, ktorClient?.let { KtorResourceRetriever(it) }))
128131
}
129132

130133
fun setHmac512Key(key: String) {
@@ -140,56 +143,67 @@ class ModelixJWTUtil {
140143
fun addPublicKey(key: JWK) {
141144
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
142145
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
143-
resetJwtProcess()
144-
rsaPublicKeys.add(key)
146+
resetJwtProcessor()
147+
jwkSources.add(ImmutableJWKSet(JWKSet(key.toPublicJWK())))
145148
}
146149

147150
@Synchronized
148151
fun setRSAPrivateKey(key: JWK) {
149152
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
150153
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
151-
resetJwtProcess()
152-
this.rsaPrivateKey = key
153-
addPublicKey(key.toPublicJWK())
154+
resetJwtProcessor()
155+
jwkSources.add(ImmutableJWKSet(JWKSet(listOf(key, key.toPublicJWK()))))
154156
}
155157

156158
@Synchronized
157-
private fun addHmacKey(key: ByteArray, algorithm: JWSAlgorithm) {
158-
resetJwtProcess()
159-
hmacKeys[algorithm] = key
159+
fun addJWK(key: JWK) {
160+
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
161+
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
162+
resetJwtProcessor()
163+
if (key.isPrivate) {
164+
jwkSources.add(ImmutableJWKSet(JWKSet(listOf(key, key.toPublicJWK()))))
165+
} else {
166+
jwkSources.add(ImmutableJWKSet(JWKSet(key)))
167+
}
160168
}
161169

162170
@Synchronized
163-
fun getPublicJWKS(): JWKSet {
164-
return JWKSet(listOfNotNull(rsaPrivateKey)).toPublicJWKSet()
171+
private fun addHmacKey(key: ByteArray, algorithm: JWSAlgorithm) {
172+
resetJwtProcessor()
173+
hmacKeys[algorithm] = key
165174
}
166175

167176
@Synchronized
168177
fun loadKeysFromEnvironment() {
169-
resetJwtProcess()
178+
resetJwtProcessor()
170179
System.getenv().filter { it.key.startsWith("MODELIX_JWK_FILE") }.values.forEach {
171-
File(it).walk().forEach { file ->
172-
when (file.extension) {
173-
"pem" -> loadPemFile(file.readText())
174-
"json" -> loadJwkFile(file.readText())
175-
}
176-
}
180+
loadKeysFromFiles(File(it))
177181
}
178182

179183
// allows multiple URLs (MODELIX_JWK_URI1, MODELIX_JWK_URI2, MODELIX_JWK_URI_MODEL_SERVER, ...)
180184
System.getenv().filter { it.key.startsWith("MODELIX_JWK_URI") }.values
181185
.forEach { addJwksUrl(URI(it).toURL()) }
182186
}
183187

188+
fun loadKeysFromFiles(fileOrFolder: File) {
189+
fileOrFolder.walk().forEach { file ->
190+
when (file.extension) {
191+
"pem" -> jwkSources.add(PemFileJWKSet(file))
192+
"json" -> jwkSources.add(FileJWKSet(file))
193+
}
194+
}
195+
}
196+
184197
@Synchronized
185198
fun createAccessToken(user: String, grantedPermissions: List<String>, additionalTokenContent: (TokenBuilder) -> Unit = {}): String {
186199
val signer: JWSSigner
187200
val algorithm: JWSAlgorithm
188201
val signingKeyId: String?
189-
val jwk = this.rsaPrivateKey
202+
val jwk = getPrivateKey()
190203
if (jwk != null) {
191204
signer = RSASSASigner(jwk.toRSAKey().toRSAPrivateKey())
192-
algorithm = checkNotNull(jwk.algorithm) { "RSA key doesn't specify an algorithm" } as JWSAlgorithm
205+
algorithm = checkNotNull(jwk.algorithm) { "RSA key doesn't specify an algorithm" }
206+
.let { it as? JWSAlgorithm ?: JWSAlgorithm.parse(it.name) }
193207
signingKeyId = checkNotNull(jwk.keyID) { "RSA key doesn't specify a key ID" }
194208
} else {
195209
val entry = checkNotNull(hmacKeys.entries.firstOrNull()) { "No keys for signing provided" }
@@ -273,11 +287,7 @@ class ModelixJWTUtil {
273287
.issueTime(Date())
274288
.algorithm(JWSAlgorithm.RS256)
275289
.generate()
276-
.also { setRSAPrivateKey(it) }
277-
}
278-
279-
fun loadPemFile(fileContent: String): JWK {
280-
return ensureValidKey(JWK.parseFromPEMEncodedObjects(fileContent)).also { loadJwk(it) }
290+
.also { addJWK(it) }
281291
}
282292

283293
private fun ensureValidKey(key: JWK): JWK {
@@ -302,19 +312,6 @@ class ModelixJWTUtil {
302312
return RSAKey.Builder(rsaKey).keyID(keyId).build()
303313
}
304314

305-
fun loadJwkFile(fileContent: String): JWK {
306-
return JWK.parse(fileContent).also { loadJwk(it) }
307-
}
308-
309-
private fun loadJwk(key: JWK) {
310-
resetJwtProcess()
311-
if (key.isPrivate) {
312-
setRSAPrivateKey(key)
313-
} else {
314-
addPublicKey(key)
315-
}
316-
}
317-
318315
@Synchronized
319316
fun verifyToken(token: String) {
320317
getOrCreateJwtProcessor().process(JWTParser.parse(token), null)
@@ -327,6 +324,30 @@ class ModelixJWTUtil {
327324
}
328325
}
329326

327+
private inner class PemFileJWKSet<C : SecurityContext>(pemFile: File) : FileJWKSet<C>(pemFile) {
328+
override fun readFile(): JWKSet {
329+
return JWKSet(ensureValidKey(JWK.parseFromPEMEncodedObjects(file.readText())))
330+
}
331+
}
332+
333+
private open inner class FileJWKSet<C : SecurityContext>(val file: File) : JWKSource<C> {
334+
private var loadedAt: Long = 0
335+
private var cached: JWKSet? = null
336+
337+
open fun readFile(): JWKSet {
338+
return JWKSet(JWK.parse(file.readText()))
339+
}
340+
341+
override fun get(jwkSelector: JWKSelector, context: C?): List<JWK?>? {
342+
val jwks = cached.takeIf { System.currentTimeMillis() - loadedAt < fileRefreshTime.inWholeMilliseconds }
343+
?: readFile().also {
344+
cached = it
345+
loadedAt = System.currentTimeMillis()
346+
}
347+
return jwkSelector.select(jwks)
348+
}
349+
}
350+
330351
companion object {
331352
fun extractUserId(jwt: DecodedJWT): String? {
332353
return jwt.getClaim(KeycloakTokenConstants.EMAIL)?.asString()

0 commit comments

Comments
 (0)