Skip to content

Commit 2b80425

Browse files
committed
fix(authorization): reload keys when the file changes
If the Kubernetes secret containing the RSA key changed, the container usually isn't restarted, so we have to make sure we don't use an outdated key.
1 parent 59ca7b6 commit 2b80425

File tree

2 files changed

+259
-71
lines changed

2 files changed

+259
-71
lines changed

authorization/src/main/kotlin/org/modelix/authorization/ModelixJWTUtil.kt

Lines changed: 76 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@ import com.nimbusds.jose.crypto.MACSigner
1010
import com.nimbusds.jose.crypto.RSASSASigner
1111
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory
1212
import com.nimbusds.jose.jwk.JWK
13+
import com.nimbusds.jose.jwk.JWKMatcher
14+
import com.nimbusds.jose.jwk.JWKSelector
1315
import com.nimbusds.jose.jwk.JWKSet
1416
import com.nimbusds.jose.jwk.KeyType
1517
import com.nimbusds.jose.jwk.KeyUse
1618
import com.nimbusds.jose.jwk.RSAKey
1719
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator
1820
import com.nimbusds.jose.jwk.source.ImmutableJWKSet
21+
import com.nimbusds.jose.jwk.source.JWKSource
1922
import com.nimbusds.jose.jwk.source.RemoteJWKSet
2023
import com.nimbusds.jose.proc.BadJOSEException
2124
import com.nimbusds.jose.proc.JWSAlgorithmFamilyJWSKeySelector
@@ -47,32 +50,26 @@ import java.util.Base64
4750
import java.util.Date
4851
import java.util.UUID
4952
import javax.crypto.spec.SecretKeySpec
53+
import kotlin.time.Duration
54+
import kotlin.time.Duration.Companion.seconds
5055

5156
class ModelixJWTUtil {
52-
private var hmacKeys = LinkedHashMap<JWSAlgorithm, ByteArray>()
53-
private var rsaPrivateKey: JWK? = null
54-
private var rsaPublicKeys = ArrayList<JWK>()
55-
private val jwksUrls = LinkedHashSet<URL>()
57+
private val hmacKeys = LinkedHashMap<JWSAlgorithm, ByteArray>()
58+
private val jwkSources = ArrayList<JWKSource<SecurityContext>>()
5659
private var expectedKeyId: String? = null
5760
private var ktorClient: HttpClient? = null
5861
var accessControlDataProvider: IAccessControlDataProvider = EmptyAccessControlDataProvider()
5962

6063
private var jwtProcessor: JWTProcessor<SecurityContext>? = null
64+
var fileRefreshTime: Duration = 5.seconds
6165

6266
@Synchronized
6367
private fun getOrCreateJwtProcessor(): JWTProcessor<SecurityContext> {
6468
return jwtProcessor ?: DefaultJWTProcessor<SecurityContext>().also { processor ->
6569
val keySelectors: List<JWSKeySelector<SecurityContext>> = hmacKeys.map { it.toPair() }.map {
6670
SingleKeyJWSKeySelector<SecurityContext>(it.first, SecretKeySpec(it.second, it.first.name))
67-
} + jwksUrls.map {
68-
val client = this.ktorClient
69-
if (client == null) {
70-
JWSAlgorithmFamilyJWSKeySelector.fromJWKSetURL<SecurityContext>(it)
71-
} else {
72-
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(RemoteJWKSet(it, KtorResourceRetriever(client)))
73-
}
74-
} + rsaPublicKeys.map {
75-
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(ImmutableJWKSet(JWKSet(it.toPublicJWK())))
71+
} + jwkSources.map {
72+
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(it)
7673
}
7774

7875
processor.jwsKeySelector = if (keySelectors.size == 1) keySelectors.single() else CompositeJWSKeySelector(keySelectors)
@@ -91,13 +88,19 @@ class ModelixJWTUtil {
9188
}.also { jwtProcessor = it }
9289
}
9390

94-
private fun resetJwtProcess() {
91+
fun getPrivateKey(): JWK? {
92+
return jwkSources.flatMap {
93+
it.get(JWKSelector(JWKMatcher.Builder().privateOnly(true).algorithms(JWSAlgorithm.Family.RSA.toSet()).build()), null)
94+
}.firstOrNull()
95+
}
96+
97+
private fun resetJwtProcessor() {
9598
jwtProcessor = null
9699
}
97100

98101
@Synchronized
99102
fun canVerifyTokens(): Boolean {
100-
return hmacKeys.isNotEmpty() || rsaPublicKeys.isNotEmpty() || jwksUrls.isNotEmpty()
103+
return hmacKeys.isNotEmpty() || jwkSources.isNotEmpty()
101104
}
102105

103106
/**
@@ -110,7 +113,7 @@ class ModelixJWTUtil {
110113

111114
@Synchronized
112115
fun useKtorClient(client: HttpClient) {
113-
resetJwtProcess()
116+
resetJwtProcessor()
114117
this.ktorClient = client.config {
115118
expectSuccess = true
116119
}
@@ -123,8 +126,8 @@ class ModelixJWTUtil {
123126

124127
@Synchronized
125128
fun addJwksUrl(url: URL) {
126-
resetJwtProcess()
127-
jwksUrls += url
129+
resetJwtProcessor()
130+
jwkSources.add(RemoteJWKSet(url, ktorClient?.let { KtorResourceRetriever(it) }))
128131
}
129132

130133
fun setHmac512Key(key: String) {
@@ -140,56 +143,67 @@ class ModelixJWTUtil {
140143
fun addPublicKey(key: JWK) {
141144
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
142145
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
143-
resetJwtProcess()
144-
rsaPublicKeys.add(key)
146+
resetJwtProcessor()
147+
jwkSources.add(ImmutableJWKSet(JWKSet(key.toPublicJWK())))
145148
}
146149

147150
@Synchronized
148151
fun setRSAPrivateKey(key: JWK) {
149152
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
150153
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
151-
resetJwtProcess()
152-
this.rsaPrivateKey = key
153-
addPublicKey(key.toPublicJWK())
154+
resetJwtProcessor()
155+
jwkSources.add(ImmutableJWKSet(JWKSet(listOf(key, key.toPublicJWK()))))
154156
}
155157

156158
@Synchronized
157-
private fun addHmacKey(key: ByteArray, algorithm: JWSAlgorithm) {
158-
resetJwtProcess()
159-
hmacKeys[algorithm] = key
159+
fun addJWK(key: JWK) {
160+
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
161+
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
162+
resetJwtProcessor()
163+
if (key.isPrivate) {
164+
jwkSources.add(ImmutableJWKSet(JWKSet(listOf(key, key.toPublicJWK()))))
165+
} else {
166+
jwkSources.add(ImmutableJWKSet(JWKSet(key)))
167+
}
160168
}
161169

162170
@Synchronized
163-
fun getPublicJWKS(): JWKSet {
164-
return JWKSet(listOfNotNull(rsaPrivateKey)).toPublicJWKSet()
171+
private fun addHmacKey(key: ByteArray, algorithm: JWSAlgorithm) {
172+
resetJwtProcessor()
173+
hmacKeys[algorithm] = key
165174
}
166175

167176
@Synchronized
168177
fun loadKeysFromEnvironment() {
169-
resetJwtProcess()
178+
resetJwtProcessor()
170179
System.getenv().filter { it.key.startsWith("MODELIX_JWK_FILE") }.values.forEach {
171-
File(it).walk().forEach { file ->
172-
when (file.extension) {
173-
"pem" -> loadPemFile(file.readText())
174-
"json" -> loadJwkFile(file.readText())
175-
}
176-
}
180+
loadKeysFromFiles(File(it))
177181
}
178182

179183
// allows multiple URLs (MODELIX_JWK_URI1, MODELIX_JWK_URI2, MODELIX_JWK_URI_MODEL_SERVER, ...)
180184
System.getenv().filter { it.key.startsWith("MODELIX_JWK_URI") }.values
181185
.forEach { addJwksUrl(URI(it).toURL()) }
182186
}
183187

188+
fun loadKeysFromFiles(fileOrFolder: File) {
189+
fileOrFolder.walk().forEach { file ->
190+
when (file.extension) {
191+
"pem" -> jwkSources.add(PemFileJWKSet(file))
192+
"json" -> jwkSources.add(FileJWKSet(file))
193+
}
194+
}
195+
}
196+
184197
@Synchronized
185198
fun createAccessToken(user: String, grantedPermissions: List<String>, additionalTokenContent: (TokenBuilder) -> Unit = {}): String {
186199
val signer: JWSSigner
187200
val algorithm: JWSAlgorithm
188201
val signingKeyId: String?
189-
val jwk = this.rsaPrivateKey
202+
val jwk = getPrivateKey()
190203
if (jwk != null) {
191204
signer = RSASSASigner(jwk.toRSAKey().toRSAPrivateKey())
192-
algorithm = checkNotNull(jwk.algorithm) { "RSA key doesn't specify an algorithm" } as JWSAlgorithm
205+
algorithm = checkNotNull(jwk.algorithm) { "RSA key doesn't specify an algorithm" }
206+
.let { it as? JWSAlgorithm ?: JWSAlgorithm.parse(it.name) }
193207
signingKeyId = checkNotNull(jwk.keyID) { "RSA key doesn't specify a key ID" }
194208
} else {
195209
val entry = checkNotNull(hmacKeys.entries.firstOrNull()) { "No keys for signing provided" }
@@ -273,11 +287,7 @@ class ModelixJWTUtil {
273287
.issueTime(Date())
274288
.algorithm(JWSAlgorithm.RS256)
275289
.generate()
276-
.also { setRSAPrivateKey(it) }
277-
}
278-
279-
fun loadPemFile(fileContent: String): JWK {
280-
return ensureValidKey(JWK.parseFromPEMEncodedObjects(fileContent)).also { loadJwk(it) }
290+
.also { addJWK(it) }
281291
}
282292

283293
private fun ensureValidKey(key: JWK): JWK {
@@ -302,19 +312,6 @@ class ModelixJWTUtil {
302312
return RSAKey.Builder(rsaKey).keyID(keyId).build()
303313
}
304314

305-
fun loadJwkFile(fileContent: String): JWK {
306-
return JWK.parse(fileContent).also { loadJwk(it) }
307-
}
308-
309-
private fun loadJwk(key: JWK) {
310-
resetJwtProcess()
311-
if (key.isPrivate) {
312-
setRSAPrivateKey(key)
313-
} else {
314-
addPublicKey(key)
315-
}
316-
}
317-
318315
@Synchronized
319316
fun verifyToken(token: String) {
320317
getOrCreateJwtProcessor().process(JWTParser.parse(token), null)
@@ -327,6 +324,30 @@ class ModelixJWTUtil {
327324
}
328325
}
329326

327+
private inner class PemFileJWKSet<C : SecurityContext>(pemFile: File) : FileJWKSet<C>(pemFile) {
328+
override fun readFile(): JWKSet {
329+
return JWKSet(ensureValidKey(JWK.parseFromPEMEncodedObjects(file.readText())))
330+
}
331+
}
332+
333+
private open inner class FileJWKSet<C : SecurityContext>(val file: File) : JWKSource<C> {
334+
private var loadedAt: Long = 0
335+
private var cached: JWKSet? = null
336+
337+
open fun readFile(): JWKSet {
338+
return JWKSet(JWK.parse(file.readText()))
339+
}
340+
341+
override fun get(jwkSelector: JWKSelector, context: C?): List<JWK?>? {
342+
val jwks = cached.takeIf { System.currentTimeMillis() - loadedAt < fileRefreshTime.inWholeMilliseconds }
343+
?: readFile().also {
344+
cached = it
345+
loadedAt = System.currentTimeMillis()
346+
}
347+
return jwkSelector.select(jwks)
348+
}
349+
}
350+
330351
companion object {
331352
fun extractUserId(jwt: DecodedJWT): String? {
332353
return jwt.getClaim(KeycloakTokenConstants.EMAIL)?.asString()

0 commit comments

Comments
 (0)