Skip to content

Commit 64053df

Browse files
committed
fix(authorization): cache remote keys
RemoteJWKSet already caches keys from remote URLs, but all instances of key sources weren't reused.
1 parent 2201fa0 commit 64053df

File tree

1 file changed

+60
-29
lines changed

1 file changed

+60
-29
lines changed

authorization/src/main/kotlin/org/modelix/authorization/ModelixJWTUtil.kt

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import com.nimbusds.jose.util.Resource
2727
import com.nimbusds.jwt.JWTClaimsSet
2828
import com.nimbusds.jwt.JWTParser
2929
import com.nimbusds.jwt.proc.DefaultJWTProcessor
30+
import com.nimbusds.jwt.proc.JWTProcessor
3031
import io.ktor.client.HttpClient
3132
import io.ktor.client.request.get
3233
import io.ktor.client.statement.bodyAsText
@@ -46,7 +47,6 @@ import java.util.Base64
4647
import java.util.Date
4748
import java.util.UUID
4849
import javax.crypto.spec.SecretKeySpec
49-
import kotlin.String
5050

5151
class ModelixJWTUtil {
5252
private var hmacKeys = LinkedHashMap<JWSAlgorithm, ByteArray>()
@@ -57,28 +57,73 @@ class ModelixJWTUtil {
5757
private var ktorClient: HttpClient? = null
5858
var accessControlDataProvider: IAccessControlDataProvider = EmptyAccessControlDataProvider()
5959

60+
private var jwtProcessor: JWTProcessor<SecurityContext>? = null
61+
62+
@Synchronized
63+
private fun getOrCreateJwtProcessor(): JWTProcessor<SecurityContext> {
64+
return jwtProcessor ?: DefaultJWTProcessor<SecurityContext>().also { processor ->
65+
val keySelectors: List<JWSKeySelector<SecurityContext>> = hmacKeys.map { it.toPair() }.map {
66+
SingleKeyJWSKeySelector<SecurityContext>(it.first, SecretKeySpec(it.second, it.first.name))
67+
} + jwksUrls.map {
68+
val client = this.ktorClient
69+
if (client == null) {
70+
JWSAlgorithmFamilyJWSKeySelector.fromJWKSetURL<SecurityContext>(it)
71+
} else {
72+
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(RemoteJWKSet(it, KtorResourceRetriever(client)))
73+
}
74+
} + rsaPublicKeys.map {
75+
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(ImmutableJWKSet(JWKSet(it.toPublicJWK())))
76+
}
77+
78+
processor.jwsKeySelector = if (keySelectors.size == 1) keySelectors.single() else CompositeJWSKeySelector(keySelectors)
79+
80+
val expectedKeyId = this.expectedKeyId
81+
if (expectedKeyId != null) {
82+
processor.jwsVerifierFactory = object : DefaultJWSVerifierFactory() {
83+
override fun createJWSVerifier(header: JWSHeader, key: Key): JWSVerifier {
84+
if (header.keyID != expectedKeyId) {
85+
throw BadJOSEException("Invalid key ID. [expected=$expectedKeyId, actual=${header.keyID}]")
86+
}
87+
return super.createJWSVerifier(header, key)
88+
}
89+
}
90+
}
91+
}.also { jwtProcessor = it }
92+
}
93+
94+
private fun resetJwtProcess() {
95+
jwtProcessor = null
96+
}
97+
98+
@Synchronized
6099
fun canVerifyTokens(): Boolean {
61100
return hmacKeys.isNotEmpty() || rsaPublicKeys.isNotEmpty() || jwksUrls.isNotEmpty()
62101
}
63102

64103
/**
65104
* Tokens are only valid if they are signed with this key.
66105
*/
106+
@Synchronized
67107
fun requireKeyId(id: String) {
68108
expectedKeyId = id
69109
}
70110

111+
@Synchronized
71112
fun useKtorClient(client: HttpClient) {
113+
resetJwtProcess()
72114
this.ktorClient = client.config {
73115
expectSuccess = true
74116
}
75117
}
76118

119+
@Synchronized
77120
fun addJwksUrl(url: String) {
78121
addJwksUrl(URI(url).toURL())
79122
}
80123

124+
@Synchronized
81125
fun addJwksUrl(url: URL) {
126+
resetJwtProcess()
82127
jwksUrls += url
83128
}
84129

@@ -91,28 +136,37 @@ class ModelixJWTUtil {
91136
addHmacKey(key.toByteArray().ensureMinSecretLength(algorithm), algorithm)
92137
}
93138

139+
@Synchronized
94140
fun addPublicKey(key: JWK) {
95141
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
96142
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
143+
resetJwtProcess()
97144
rsaPublicKeys.add(key)
98145
}
99146

147+
@Synchronized
100148
fun setRSAPrivateKey(key: JWK) {
101149
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
102150
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
151+
resetJwtProcess()
103152
this.rsaPrivateKey = key
104153
addPublicKey(key.toPublicJWK())
105154
}
106155

156+
@Synchronized
107157
private fun addHmacKey(key: ByteArray, algorithm: JWSAlgorithm) {
158+
resetJwtProcess()
108159
hmacKeys[algorithm] = key
109160
}
110161

162+
@Synchronized
111163
fun getPublicJWKS(): JWKSet {
112164
return JWKSet(listOfNotNull(rsaPrivateKey)).toPublicJWKSet()
113165
}
114166

167+
@Synchronized
115168
fun loadKeysFromEnvironment() {
169+
resetJwtProcess()
116170
System.getenv().filter { it.key.startsWith("MODELIX_JWK_FILE") }.values.forEach {
117171
File(it).walk().forEach { file ->
118172
when (file.extension) {
@@ -127,6 +181,7 @@ class ModelixJWTUtil {
127181
.forEach { addJwksUrl(URI(it).toURL()) }
128182
}
129183

184+
@Synchronized
130185
fun createAccessToken(user: String, grantedPermissions: List<String>, additionalTokenContent: (TokenBuilder) -> Unit = {}): String {
131186
val signer: JWSSigner
132187
val algorithm: JWSAlgorithm
@@ -174,6 +229,7 @@ class ModelixJWTUtil {
174229
return token.claims[ModelixTokenConstants.PERMISSIONS]?.asList(String::class.java)
175230
}
176231

232+
@Synchronized
177233
fun loadGrantedPermissions(token: DecodedJWT, evaluator: PermissionEvaluator) {
178234
val permissions = extractPermissions(token)
179235

@@ -252,42 +308,17 @@ class ModelixJWTUtil {
252308
}
253309

254310
private fun loadJwk(key: JWK) {
311+
resetJwtProcess()
255312
if (key.isPrivate) {
256313
setRSAPrivateKey(key)
257314
} else {
258315
addPublicKey(key)
259316
}
260317
}
261318

319+
@Synchronized
262320
fun verifyToken(token: String) {
263-
DefaultJWTProcessor<SecurityContext>().also { processor ->
264-
val keySelectors: List<JWSKeySelector<SecurityContext>> = hmacKeys.map { it.toPair() }.map {
265-
SingleKeyJWSKeySelector<SecurityContext>(it.first, SecretKeySpec(it.second, it.first.name))
266-
} + jwksUrls.map {
267-
val client = this.ktorClient
268-
if (client == null) {
269-
JWSAlgorithmFamilyJWSKeySelector.fromJWKSetURL<SecurityContext>(it)
270-
} else {
271-
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(RemoteJWKSet(it, KtorResourceRetriever(client)))
272-
}
273-
} + rsaPublicKeys.map {
274-
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(ImmutableJWKSet(JWKSet(it.toPublicJWK())))
275-
}
276-
277-
processor.jwsKeySelector = if (keySelectors.size == 1) keySelectors.single() else CompositeJWSKeySelector(keySelectors)
278-
279-
val expectedKeyId = this.expectedKeyId
280-
if (expectedKeyId != null) {
281-
processor.jwsVerifierFactory = object : DefaultJWSVerifierFactory() {
282-
override fun createJWSVerifier(header: JWSHeader, key: Key): JWSVerifier {
283-
if (header.keyID != expectedKeyId) {
284-
throw BadJOSEException("Invalid key ID. [expected=$expectedKeyId, actual=${header.keyID}]")
285-
}
286-
return super.createJWSVerifier(header, key)
287-
}
288-
}
289-
}
290-
}.process(JWTParser.parse(token), null)
321+
getOrCreateJwtProcessor().process(JWTParser.parse(token), null)
291322
}
292323

293324
class TokenBuilder(private val builder: JWTClaimsSet.Builder) {

0 commit comments

Comments
 (0)