@@ -10,12 +10,15 @@ import com.nimbusds.jose.crypto.MACSigner
10
10
import com.nimbusds.jose.crypto.RSASSASigner
11
11
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory
12
12
import com.nimbusds.jose.jwk.JWK
13
+ import com.nimbusds.jose.jwk.JWKMatcher
14
+ import com.nimbusds.jose.jwk.JWKSelector
13
15
import com.nimbusds.jose.jwk.JWKSet
14
16
import com.nimbusds.jose.jwk.KeyType
15
17
import com.nimbusds.jose.jwk.KeyUse
16
18
import com.nimbusds.jose.jwk.RSAKey
17
19
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator
18
20
import com.nimbusds.jose.jwk.source.ImmutableJWKSet
21
+ import com.nimbusds.jose.jwk.source.JWKSource
19
22
import com.nimbusds.jose.jwk.source.RemoteJWKSet
20
23
import com.nimbusds.jose.proc.BadJOSEException
21
24
import com.nimbusds.jose.proc.JWSAlgorithmFamilyJWSKeySelector
@@ -47,32 +50,26 @@ import java.util.Base64
47
50
import java.util.Date
48
51
import java.util.UUID
49
52
import javax.crypto.spec.SecretKeySpec
53
+ import kotlin.time.Duration
54
+ import kotlin.time.Duration.Companion.seconds
50
55
51
56
class ModelixJWTUtil {
52
- private var hmacKeys = LinkedHashMap <JWSAlgorithm , ByteArray >()
53
- private var rsaPrivateKey: JWK ? = null
54
- private var rsaPublicKeys = ArrayList <JWK >()
55
- private val jwksUrls = LinkedHashSet <URL >()
57
+ private val hmacKeys = LinkedHashMap <JWSAlgorithm , ByteArray >()
58
+ private val jwkSources = ArrayList <JWKSource <SecurityContext >>()
56
59
private var expectedKeyId: String? = null
57
60
private var ktorClient: HttpClient ? = null
58
61
var accessControlDataProvider: IAccessControlDataProvider = EmptyAccessControlDataProvider ()
59
62
60
63
private var jwtProcessor: JWTProcessor <SecurityContext >? = null
64
+ var fileRefreshTime: Duration = 5 .seconds
61
65
62
66
@Synchronized
63
67
private fun getOrCreateJwtProcessor (): JWTProcessor <SecurityContext > {
64
68
return jwtProcessor ? : DefaultJWTProcessor <SecurityContext >().also { processor ->
65
69
val keySelectors: List <JWSKeySelector <SecurityContext >> = hmacKeys.map { it.toPair() }.map {
66
70
SingleKeyJWSKeySelector <SecurityContext >(it.first, SecretKeySpec (it.second, it.first.name))
67
- } + jwksUrls.map {
68
- val client = this .ktorClient
69
- if (client == null ) {
70
- JWSAlgorithmFamilyJWSKeySelector .fromJWKSetURL<SecurityContext >(it)
71
- } else {
72
- JWSAlgorithmFamilyJWSKeySelector .fromJWKSource<SecurityContext >(RemoteJWKSet (it, KtorResourceRetriever (client)))
73
- }
74
- } + rsaPublicKeys.map {
75
- JWSAlgorithmFamilyJWSKeySelector .fromJWKSource<SecurityContext >(ImmutableJWKSet (JWKSet (it.toPublicJWK())))
71
+ } + jwkSources.map {
72
+ JWSAlgorithmFamilyJWSKeySelector .fromJWKSource<SecurityContext >(it)
76
73
}
77
74
78
75
processor.jwsKeySelector = if (keySelectors.size == 1 ) keySelectors.single() else CompositeJWSKeySelector (keySelectors)
@@ -91,13 +88,19 @@ class ModelixJWTUtil {
91
88
}.also { jwtProcessor = it }
92
89
}
93
90
94
- private fun resetJwtProcess () {
91
+ fun getPrivateKey (): JWK ? {
92
+ return jwkSources.flatMap {
93
+ it.get(JWKSelector (JWKMatcher .Builder ().privateOnly(true ).algorithms(JWSAlgorithm .Family .RSA .toSet()).build()), null )
94
+ }.firstOrNull()
95
+ }
96
+
97
+ private fun resetJwtProcessor () {
95
98
jwtProcessor = null
96
99
}
97
100
98
101
@Synchronized
99
102
fun canVerifyTokens (): Boolean {
100
- return hmacKeys.isNotEmpty() || rsaPublicKeys.isNotEmpty() || jwksUrls .isNotEmpty()
103
+ return hmacKeys.isNotEmpty() || jwkSources .isNotEmpty()
101
104
}
102
105
103
106
/* *
@@ -110,7 +113,7 @@ class ModelixJWTUtil {
110
113
111
114
@Synchronized
112
115
fun useKtorClient (client : HttpClient ) {
113
- resetJwtProcess ()
116
+ resetJwtProcessor ()
114
117
this .ktorClient = client.config {
115
118
expectSuccess = true
116
119
}
@@ -123,8 +126,8 @@ class ModelixJWTUtil {
123
126
124
127
@Synchronized
125
128
fun addJwksUrl (url : URL ) {
126
- resetJwtProcess ()
127
- jwksUrls + = url
129
+ resetJwtProcessor ()
130
+ jwkSources.add( RemoteJWKSet (url, ktorClient?. let { KtorResourceRetriever (it) }))
128
131
}
129
132
130
133
fun setHmac512Key (key : String ) {
@@ -140,56 +143,67 @@ class ModelixJWTUtil {
140
143
fun addPublicKey (key : JWK ) {
141
144
requireNotNull(key.keyID) { " Key doesn't specify a key ID: $key " }
142
145
requireNotNull(key.algorithm) { " Key doesn't specify an algorithm: $key " }
143
- resetJwtProcess ()
144
- rsaPublicKeys .add(key)
146
+ resetJwtProcessor ()
147
+ jwkSources .add(ImmutableJWKSet ( JWKSet ( key.toPublicJWK())) )
145
148
}
146
149
147
150
@Synchronized
148
151
fun setRSAPrivateKey (key : JWK ) {
149
152
requireNotNull(key.keyID) { " Key doesn't specify a key ID: $key " }
150
153
requireNotNull(key.algorithm) { " Key doesn't specify an algorithm: $key " }
151
- resetJwtProcess()
152
- this .rsaPrivateKey = key
153
- addPublicKey(key.toPublicJWK())
154
+ resetJwtProcessor()
155
+ jwkSources.add(ImmutableJWKSet (JWKSet (listOf (key, key.toPublicJWK()))))
154
156
}
155
157
156
158
@Synchronized
157
- private fun addHmacKey (key : ByteArray , algorithm : JWSAlgorithm ) {
158
- resetJwtProcess()
159
- hmacKeys[algorithm] = key
159
+ fun addJWK (key : JWK ) {
160
+ requireNotNull(key.keyID) { " Key doesn't specify a key ID: $key " }
161
+ requireNotNull(key.algorithm) { " Key doesn't specify an algorithm: $key " }
162
+ resetJwtProcessor()
163
+ if (key.isPrivate) {
164
+ jwkSources.add(ImmutableJWKSet (JWKSet (listOf (key, key.toPublicJWK()))))
165
+ } else {
166
+ jwkSources.add(ImmutableJWKSet (JWKSet (key)))
167
+ }
160
168
}
161
169
162
170
@Synchronized
163
- fun getPublicJWKS (): JWKSet {
164
- return JWKSet (listOfNotNull(rsaPrivateKey)).toPublicJWKSet()
171
+ private fun addHmacKey (key : ByteArray , algorithm : JWSAlgorithm ) {
172
+ resetJwtProcessor()
173
+ hmacKeys[algorithm] = key
165
174
}
166
175
167
176
@Synchronized
168
177
fun loadKeysFromEnvironment () {
169
- resetJwtProcess ()
178
+ resetJwtProcessor ()
170
179
System .getenv().filter { it.key.startsWith(" MODELIX_JWK_FILE" ) }.values.forEach {
171
- File (it).walk().forEach { file ->
172
- when (file.extension) {
173
- " pem" -> loadPemFile(file.readText())
174
- " json" -> loadJwkFile(file.readText())
175
- }
176
- }
180
+ loadKeysFromFiles(File (it))
177
181
}
178
182
179
183
// allows multiple URLs (MODELIX_JWK_URI1, MODELIX_JWK_URI2, MODELIX_JWK_URI_MODEL_SERVER, ...)
180
184
System .getenv().filter { it.key.startsWith(" MODELIX_JWK_URI" ) }.values
181
185
.forEach { addJwksUrl(URI (it).toURL()) }
182
186
}
183
187
188
+ fun loadKeysFromFiles (fileOrFolder : File ) {
189
+ fileOrFolder.walk().forEach { file ->
190
+ when (file.extension) {
191
+ " pem" -> jwkSources.add(PemFileJWKSet (file))
192
+ " json" -> jwkSources.add(FileJWKSet (file))
193
+ }
194
+ }
195
+ }
196
+
184
197
@Synchronized
185
198
fun createAccessToken (user : String , grantedPermissions : List <String >, additionalTokenContent : (TokenBuilder ) -> Unit = {}): String {
186
199
val signer: JWSSigner
187
200
val algorithm: JWSAlgorithm
188
201
val signingKeyId: String?
189
- val jwk = this .rsaPrivateKey
202
+ val jwk = getPrivateKey()
190
203
if (jwk != null ) {
191
204
signer = RSASSASigner (jwk.toRSAKey().toRSAPrivateKey())
192
- algorithm = checkNotNull(jwk.algorithm) { " RSA key doesn't specify an algorithm" } as JWSAlgorithm
205
+ algorithm = checkNotNull(jwk.algorithm) { " RSA key doesn't specify an algorithm" }
206
+ .let { it as ? JWSAlgorithm ? : JWSAlgorithm .parse(it.name) }
193
207
signingKeyId = checkNotNull(jwk.keyID) { " RSA key doesn't specify a key ID" }
194
208
} else {
195
209
val entry = checkNotNull(hmacKeys.entries.firstOrNull()) { " No keys for signing provided" }
@@ -273,11 +287,7 @@ class ModelixJWTUtil {
273
287
.issueTime(Date ())
274
288
.algorithm(JWSAlgorithm .RS256 )
275
289
.generate()
276
- .also { setRSAPrivateKey(it) }
277
- }
278
-
279
- fun loadPemFile (fileContent : String ): JWK {
280
- return ensureValidKey(JWK .parseFromPEMEncodedObjects(fileContent)).also { loadJwk(it) }
290
+ .also { addJWK(it) }
281
291
}
282
292
283
293
private fun ensureValidKey (key : JWK ): JWK {
@@ -302,19 +312,6 @@ class ModelixJWTUtil {
302
312
return RSAKey .Builder (rsaKey).keyID(keyId).build()
303
313
}
304
314
305
- fun loadJwkFile (fileContent : String ): JWK {
306
- return JWK .parse(fileContent).also { loadJwk(it) }
307
- }
308
-
309
- private fun loadJwk (key : JWK ) {
310
- resetJwtProcess()
311
- if (key.isPrivate) {
312
- setRSAPrivateKey(key)
313
- } else {
314
- addPublicKey(key)
315
- }
316
- }
317
-
318
315
@Synchronized
319
316
fun verifyToken (token : String ) {
320
317
getOrCreateJwtProcessor().process(JWTParser .parse(token), null )
@@ -327,6 +324,30 @@ class ModelixJWTUtil {
327
324
}
328
325
}
329
326
327
+ private inner class PemFileJWKSet <C : SecurityContext >(pemFile : File ) : FileJWKSet<C>(pemFile) {
328
+ override fun readFile (): JWKSet {
329
+ return JWKSet (ensureValidKey(JWK .parseFromPEMEncodedObjects(file.readText())))
330
+ }
331
+ }
332
+
333
+ private open inner class FileJWKSet <C : SecurityContext >(val file : File ) : JWKSource<C> {
334
+ private var loadedAt: Long = 0
335
+ private var cached: JWKSet ? = null
336
+
337
+ open fun readFile (): JWKSet {
338
+ return JWKSet (JWK .parse(file.readText()))
339
+ }
340
+
341
+ override fun get (jwkSelector : JWKSelector , context : C ? ): List <JWK ?>? {
342
+ val jwks = cached.takeIf { System .currentTimeMillis() - loadedAt < fileRefreshTime.inWholeMilliseconds }
343
+ ? : readFile().also {
344
+ cached = it
345
+ loadedAt = System .currentTimeMillis()
346
+ }
347
+ return jwkSelector.select(jwks)
348
+ }
349
+ }
350
+
330
351
companion object {
331
352
fun extractUserId (jwt : DecodedJWT ): String? {
332
353
return jwt.getClaim(KeycloakTokenConstants .EMAIL )?.asString()
0 commit comments