diff --git a/api/src/main/resources/example.application.yaml b/api/src/main/resources/example.application.yaml index 757d7cdb..a8bee3bf 100644 --- a/api/src/main/resources/example.application.yaml +++ b/api/src/main/resources/example.application.yaml @@ -7,6 +7,7 @@ loginsvc: access-exp-time: 15min refresh-exp-time: 9h key-rotation-time: 9h + key-phase-out-time: 30min alg-name: "RS256" #Instead of generating the key in memory #The Below Config allows for the application to fetch keys from AWS Secrets Manager. @@ -18,6 +19,7 @@ loginsvc: #access-exp-time: 15min #refresh-exp-time: 9h #poll-time: 5min + #key-phase-out-time: 30min #alg-name: "RS256" config: # Generates git.properties file for use on info endpoint. diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/SecurityConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/SecurityConfig.scala index 31fa252e..a3c0a073 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/SecurityConfig.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/SecurityConfig.scala @@ -50,7 +50,8 @@ class SecurityConfig @Autowired()(authConfigsProvider: AuthConfigProvider) { "/actuator/**", "/token/refresh", // access+refresh JWT in payload, no auth "/token/public-key-jwks", - "/token/public-key").permitAll() + "/token/public-key", + "/token/public-keys").permitAll() .anyRequest().authenticated() .and() .sessionManagement() diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/auth/ServiceAccountConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/auth/ServiceAccountConfig.scala index 8f3a0233..eaeb12ef 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/config/auth/ServiceAccountConfig.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/auth/ServiceAccountConfig.scala @@ -88,10 +88,15 @@ case class AwsSecretsLdapUserConfig(private val secretName: String, private def getUsernameAndPasswordFromSecret: (String, String) = { try { - val secrets = AwsSecretsUtils.fetchSecret(secretName, region, Array(usernameFieldName, passwordFieldName)) - (secrets(usernameFieldName), secrets(passwordFieldName)) - } - catch { + val secretsOption = AwsSecretsUtils.fetchSecret(secretName, region, Array(usernameFieldName, passwordFieldName)) + + secretsOption.fold( + throw new Exception("Error retrieving username and password from from AWS Secrets Manager") + ) { secrets => + (secrets.secretValue(usernameFieldName), secrets.secretValue(passwordFieldName)) + } + + } catch { case e: Throwable => logger.error(s"Error occurred retrieving account data from AWS Secrets Manager", e) throw e diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfig.scala index c3798486..dbbcb714 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfig.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfig.scala @@ -23,6 +23,7 @@ import za.co.absa.loginsvc.utils.AwsSecretsUtils import java.security.{KeyFactory, KeyPair} import java.security.spec.{PKCS8EncodedKeySpec, X509EncodedKeySpec} +import java.time.Instant import java.util.Base64 import scala.concurrent.duration.FiniteDuration @@ -34,38 +35,57 @@ case class AwsSecretsManagerKeyConfig( algName: String, accessExpTime: FiniteDuration, refreshExpTime: FiniteDuration, - pollTime: Option[FiniteDuration] + pollTime: Option[FiniteDuration], + keyPhaseOutTime: Option[FiniteDuration] ) extends KeyConfig { private val logger = LoggerFactory.getLogger(classOf[AwsSecretsManagerKeyConfig]) override def keyRotationTime : Option[FiniteDuration] = pollTime - override def keyPair(): KeyPair = { + override def keyPair(): (KeyPair, Option[KeyPair]) = { try { - val secrets = AwsSecretsUtils.fetchSecret(secretName, region, Array(privateKeyFieldName, publicKeyFieldName)) - - val publicKeySpec: X509EncodedKeySpec = new X509EncodedKeySpec( - Base64.getDecoder.decode( - secrets(publicKeyFieldName) - ) - ) - val privateKeySpec: PKCS8EncodedKeySpec = new PKCS8EncodedKeySpec( - Base64.getDecoder.decode( - secrets(privateKeyFieldName) - ) + val currentSecretsOption = AwsSecretsUtils.fetchSecret( + secretName, + region, + Array(privateKeyFieldName, publicKeyFieldName) ) - logger.info("Key Data successfully retrieved and parsed from AWS Secrets Manager") + if(currentSecretsOption.isEmpty) + throw new Exception("Error retrieving AWSCURRENT key from from AWS Secrets Manager") - val keyFactory: KeyFactory = KeyFactory.getInstance(jwtAlgorithmToCryptoAlgorithm) - new KeyPair(keyFactory.generatePublic(publicKeySpec), keyFactory.generatePrivate(privateKeySpec)) - } - catch { + val currentKeyPair = createKeyPair(currentSecretsOption.get.secretValue) + logger.info("AWSCURRENT Key Data successfully retrieved and parsed from AWS Secrets Manager") + + val previousSecretsOption = + AwsSecretsUtils.fetchSecret( + secretName, + region, + Array(privateKeyFieldName, publicKeyFieldName), + Some("AWSPREVIOUS") + ) + + val previousKeyPair = previousSecretsOption.flatMap { previousSecrets => + try { + val keys = createKeyPair(previousSecrets.secretValue) + logger.info("AWSPREVIOUS Key Data successfully retrieved and parsed from AWS Secrets Manager") + val exp = keyPhaseOutTime.exists(isExpired(previousSecrets.createTime, _)) + if(exp) { None } + else { Some(keys) } + } catch { + case e: Throwable => + logger.warn(s"Error occurred decoding AWSPREVIOUSKEYS, skipping previous keys.", e) + None + } + } + + (currentKeyPair, previousKeyPair) + } catch { case e: Throwable => logger.error(s"Error occurred retrieving and decoding keys from AWS Secrets Manager", e) throw e } } + override def throwErrors(): Unit = this.validate().throwOnErrors() override def validate(): ConfigValidationResult = { @@ -92,4 +112,26 @@ case class AwsSecretsManagerKeyConfig( super.validate().merge(awsSecretsResultsMerge) } + + private def createKeyPair(secretKeys: Map[String, String]): KeyPair = { + + val publicKeySpec: X509EncodedKeySpec = new X509EncodedKeySpec( + Base64.getDecoder.decode( + secretKeys(publicKeyFieldName) + ) + ) + val privateKeySpec: PKCS8EncodedKeySpec = new PKCS8EncodedKeySpec( + Base64.getDecoder.decode( + secretKeys(privateKeyFieldName) + ) + ) + + val keyFactory: KeyFactory = KeyFactory.getInstance(jwtAlgorithmToCryptoAlgorithm) + new KeyPair(keyFactory.generatePublic(publicKeySpec), keyFactory.generatePrivate(privateKeySpec)) + } + + private def isExpired(creationTime: Instant, finiteDuration: FiniteDuration): Boolean = { + val expirationTime = creationTime.plus(finiteDuration.toMillis, java.time.temporal.ChronoUnit.MILLIS) + Instant.now().isAfter(expirationTime) + } } diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfig.scala index 4a43f1e1..0239a595 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfig.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfig.scala @@ -19,6 +19,8 @@ package za.co.absa.loginsvc.rest.config.jwt import io.jsonwebtoken.SignatureAlgorithm import io.jsonwebtoken.security.Keys import org.slf4j.LoggerFactory +import za.co.absa.loginsvc.rest.config.validation.{ConfigValidationException, ConfigValidationResult} +import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess} import java.security.KeyPair import scala.concurrent.duration.FiniteDuration @@ -27,17 +29,30 @@ case class InMemoryKeyConfig( algName: String, accessExpTime: FiniteDuration, refreshExpTime: FiniteDuration, - keyRotationTime: Option[FiniteDuration] + keyRotationTime: Option[FiniteDuration], + keyPhaseOutTime: Option[FiniteDuration] ) extends KeyConfig { + private var oldKeyPair: Option[KeyPair] = None private val logger = LoggerFactory.getLogger(classOf[InMemoryKeyConfig]) - override def keyPair(): KeyPair = { + override def keyPair(): (KeyPair, Option[KeyPair]) = { logger.info(s"Generating new keys - every ${keyRotationTime.getOrElse("?")}") - Keys.keyPairFor(SignatureAlgorithm.valueOf(algName)) + val newKeyPair = Keys.keyPairFor(SignatureAlgorithm.valueOf(algName)) + val result = (newKeyPair, oldKeyPair) + oldKeyPair = Some(newKeyPair) + result } - override def throwErrors(): Unit = this.validate().throwOnErrors() + override def validate(): ConfigValidationResult = { + val keyPhaseOutTimeResult = if(keyPhaseOutTime.nonEmpty && keyRotationTime.nonEmpty + && keyPhaseOutTime.get > keyRotationTime.get) { + ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be lower than keyRotationTime!")) + } else ConfigValidationSuccess + + super.validate().merge(keyPhaseOutTimeResult) + } + override def throwErrors(): Unit = this.validate().throwOnErrors() } diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/KeyConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/KeyConfig.scala index 38f289fc..ee96104e 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/KeyConfig.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/KeyConfig.scala @@ -22,8 +22,7 @@ import za.co.absa.loginsvc.rest.config.validation.{ConfigValidatable, ConfigVali import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess} import java.security.KeyPair -import java.util.concurrent.TimeUnit -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} trait KeyConfig extends ConfigValidatable { @@ -31,7 +30,8 @@ trait KeyConfig extends ConfigValidatable { def accessExpTime: FiniteDuration def refreshExpTime: FiniteDuration def keyRotationTime: Option[FiniteDuration] - def keyPair(): KeyPair + def keyPhaseOutTime: Option[FiniteDuration] + def keyPair(): (KeyPair, Option[KeyPair]) def throwErrors(): Unit final def jwtAlgorithmToCryptoAlgorithm : String = { @@ -71,16 +71,34 @@ trait KeyConfig extends ConfigValidatable { ConfigValidationError(ConfigValidationException(s"keyRotationTime must be at least ${KeyConfig.minKeyRotationTime}")) } else ConfigValidationSuccess + val keyPhaseOutTimeResult = if (keyPhaseOutTime.nonEmpty && keyPhaseOutTime.get < KeyConfig.minKeyPhaseOutTime) { + ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be at least ${KeyConfig.minKeyPhaseOutTime}")) + } else ConfigValidationSuccess + + val keyPhaseOutWithRotationResult = if (keyPhaseOutTime.nonEmpty && keyRotationTime.isEmpty) { + ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!")) + } else ConfigValidationSuccess + if (keyRotationTime.isEmpty) { logger.warn("keyRotationTime is not set in config, key-pair will not be rotated!") } - algValidation.merge(accessExpTimeResult).merge(refreshExpTimeResult).merge(keyRotationTimeResult) + if(keyPhaseOutTime.isEmpty) { + logger.warn("keyPhaseOutTime is not set in config, the previously used public key will be viewable till rotation!") + } + + algValidation + .merge(accessExpTimeResult) + .merge(refreshExpTimeResult) + .merge(keyRotationTimeResult) + .merge(keyPhaseOutTimeResult) + .merge(keyPhaseOutWithRotationResult) } } object KeyConfig { - val minAccessExpTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS) - val minRefreshExpTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS) - val minKeyRotationTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS) + val minAccessExpTime: FiniteDuration = 10.milliseconds + val minRefreshExpTime: FiniteDuration = 10.milliseconds + val minKeyRotationTime: FiniteDuration = 10.milliseconds + val minKeyPhaseOutTime: FiniteDuration = 10.milliseconds } diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/controller/TokenController.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/controller/TokenController.scala index 5728615d..52503efa 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/controller/TokenController.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/controller/TokenController.scala @@ -28,7 +28,7 @@ import org.springframework.security.core.Authentication import org.springframework.web.bind.annotation._ import org.springframework.web.server.ResponseStatusException import za.co.absa.loginsvc.model.User -import za.co.absa.loginsvc.rest.model.{KerberosUserDetails, PublicKey, TokensWrapper} +import za.co.absa.loginsvc.rest.model.{KerberosUserDetails, PublicKey, PublicKeySet, TokensWrapper} import za.co.absa.loginsvc.rest.service.jwt.JWTService import za.co.absa.loginsvc.utils.OptionUtils.ImplicitBuilderExt @@ -144,16 +144,41 @@ class TokenController @Autowired()(jwtService: JWTService) { ) @ResponseStatus(HttpStatus.OK) def getPublicKey(): CompletableFuture[PublicKey] = { - val publicKey = jwtService.publicKey + val (publicKey, _) = jwtService.publicKeys val publicKeyBase64 = Base64.getEncoder.encodeToString(publicKey.getEncoded) - Future.successful(PublicKey(publicKeyBase64)) } + @Tags(Array(new Tag(name = "token"))) + @Operation( + summary = "Gives payload with the current and previously rotated RSA256 public key", + description = """Alternative to /public-key - exposes current and previous public keys allowing users to verify a JWT after rotation.""", + responses = Array( + new ApiResponse(responseCode = "200", description = "Payload containing current and previous public keys is returned", + content = Array(new Content( + schema = new Schema(implementation = classOf[PublicKey]), + examples = Array(new ExampleObject(value = "{\n \"keys\": [\n {\n \"key\": \"ABCDEFGH1234\"\n}\n]\n}"))) + ) + ) + ) + ) + @GetMapping( + path = Array("/public-keys"), + produces = Array(MediaType.APPLICATION_JSON_VALUE) + ) + @ResponseStatus(HttpStatus.OK) + def getAllPublicKeys(): CompletableFuture[PublicKeySet] = { + val (primaryPublicKey, optionalPublicKey) = jwtService.publicKeys + val currentPublicKey = PublicKey(Base64.getEncoder.encodeToString(primaryPublicKey.getEncoded)) + val previousPublicKey = optionalPublicKey.map(pk => + PublicKey(Base64.getEncoder.encodeToString(pk.getEncoded))) + Future.successful(PublicKeySet(keys = currentPublicKey :: previousPublicKey.toList)) + } + @Tags(Array(new Tag(name = "token"))) @Operation( summary = "Gives payload with the RSA256 public key in JWKS format", - description = "Returns the same information as /token/public-key, but as a JSON Web Key Set", + description = "Returns the same information as /token/public-keys, but as a JSON Web Key Set", responses = Array( new ApiResponse(responseCode = "200", description = "Success", content = Array(new Content(examples = Array(new ExampleObject(value = """{"keys":[{"kty": "EC","crv": "P-256","x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4","y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM","use": "enc","kid": "1"},{"kty": "RSA","n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","e": "AQAB","alg": "RS256","kid": "2011-04-29"}]}"""))))), )) @@ -163,10 +188,10 @@ class TokenController @Autowired()(jwtService: JWTService) { ) @ResponseStatus(HttpStatus.OK) def getPublicKeyJwks(): CompletableFuture[Map[String, AnyRef]] = { - val jwks = jwtService.jwks + val jwk = jwtService.jwks import scala.collection.JavaConverters._ - Future.successful(jwks.toJSONObject(true).asScala.toMap) + Future.successful(jwk.toJSONObject(true).asScala.toMap) } } diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/model/AwsSecret.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/model/AwsSecret.scala new file mode 100644 index 00000000..be128937 --- /dev/null +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/model/AwsSecret.scala @@ -0,0 +1,21 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.model + +import java.time.Instant + +case class AwsSecret(secretValue: Map[String, String], createTime: Instant) diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/model/PublicKey.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/model/PublicKey.scala index 9b2bb719..cd1ee0ce 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/model/PublicKey.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/model/PublicKey.scala @@ -22,6 +22,7 @@ import io.swagger.v3.oas.annotations.media.Schema.RequiredMode case class PublicKey( @JsonProperty("key") - @Schema(example = "ABCDEFGH1234", requiredMode = RequiredMode.REQUIRED) + @Schema(example = "ABCDEFGH1234", requiredMode = RequiredMode.REQUIRED, + description = "The public key currently signing JWTs") key: String ) extends AnyVal diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/model/PublicKeySet.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/model/PublicKeySet.scala new file mode 100644 index 00000000..70f1164a --- /dev/null +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/model/PublicKeySet.scala @@ -0,0 +1,32 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.model + +import com.fasterxml.jackson.annotation.JsonProperty +import io.swagger.v3.oas.annotations.media.Schema +import io.swagger.v3.oas.annotations.media.Schema.RequiredMode + +case class PublicKeySet( + @JsonProperty("keys") + @Schema(requiredMode = RequiredMode.REQUIRED, + description = "The full set public keys including the one currently signing JWTs", + example = """[ + {"key": "ABCDEFGH1234"}, + {"key": "ZYXWVUT9876"}, + ]""") + keys: List[PublicKey] +) extends AnyVal diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/service/jwt/JWTService.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/service/jwt/JWTService.scala index e5fe8747..459fdbdf 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/service/jwt/JWTService.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/service/jwt/JWTService.scala @@ -17,12 +17,13 @@ package za.co.absa.loginsvc.rest.service.jwt import com.nimbusds.jose.JWSAlgorithm -import com.nimbusds.jose.jwk.{JWKSet, KeyUse, RSAKey} +import com.nimbusds.jose.jwk.{JWK, JWKSet, KeyUse, RSAKey} import io.jsonwebtoken._ import org.slf4j.LoggerFactory import org.springframework.beans.factory.annotation.Autowired import org.springframework.stereotype.Service import za.co.absa.loginsvc.model.User +import za.co.absa.loginsvc.rest.config.jwt.InMemoryKeyConfig import za.co.absa.loginsvc.rest.config.provider.JwtConfigProvider import za.co.absa.loginsvc.rest.model.{AccessToken, RefreshToken, Token} import za.co.absa.loginsvc.rest.service.jwt.JWTService.extractUserFrom @@ -32,7 +33,7 @@ import java.security.interfaces.RSAPublicKey import java.security.{KeyPair, PublicKey} import java.time.Instant import java.util.Date -import java.util.concurrent.{Executors, TimeUnit} +import java.util.concurrent.{ScheduledThreadPoolExecutor, ThreadFactory, TimeUnit} import scala.collection.JavaConverters._ import scala.compat.java8.DurationConverters._ import scala.concurrent.duration.FiniteDuration @@ -41,20 +42,18 @@ import scala.concurrent.duration.FiniteDuration class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchService: UserSearchService) { private val logger = LoggerFactory.getLogger(classOf[JWTService]) - private val scheduler = Executors.newSingleThreadScheduledExecutor(r => { - val t = new Thread(r) - t.setDaemon(true) - t + private val scheduler = new ScheduledThreadPoolExecutor(2, new ThreadFactory { + override def newThread(r: Runnable): Thread = { + val t = new Thread(r) + t.setDaemon(true) + t + } }) private val jwtConfig = jwtConfigProvider.getJwtKeyConfig - @volatile private var keyPair: KeyPair = jwtConfig.keyPair() + @volatile private var (primaryKeyPair: KeyPair, optionalKeyPair: Option[KeyPair]) = jwtConfig.keyPair() - if(jwtConfig.keyRotationTime.nonEmpty) - { - val refreshTime = jwtConfig.keyRotationTime.get - scheduleSecretsRefresh(refreshTime) - } + jwtConfig.keyRotationTime.foreach(scheduleSecretsRefresh) def generateAccessToken(user: User, isRefresh: Boolean = false): AccessToken = { val msgIntro = if (isRefresh) "Refreshing" else "Generating new" @@ -82,7 +81,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe }.asJava ) .claim("type", Token.TokenType.Access.toString) - .signWith(keyPair.getPrivate) + .signWith(primaryKeyPair.getPrivate) .compact() AccessToken(tokenContent) @@ -101,7 +100,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe .setExpiration(expiration) .setIssuedAt(issuedAt) .claim("type", Token.TokenType.Refresh.toString) - .signWith(keyPair.getPrivate) + .signWith(primaryKeyPair.getPrivate) .compact() RefreshToken(tokenContent) @@ -110,7 +109,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe def refreshTokens(accessToken: AccessToken, refreshToken: RefreshToken): (AccessToken, RefreshToken) = { val oldAccessJws: Jws[Claims] = Jwts.parserBuilder() .require("type", Token.TokenType.Access.toString) - .setSigningKey(keyPair.getPublic) + .setSigningKey(primaryKeyPair.getPublic) .setClock(() => Date.from(Instant.now().minus(jwtConfig.refreshExpTime.toJava))) // allowing expired access token - up to refresh token validity window .build() .parseClaimsJws(accessToken.token) // checks requirements: type=access, signature, custom validity window @@ -120,7 +119,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe Jwts.parserBuilder() .require("type", Token.TokenType.Refresh.toString) .requireSubject(userFromOldAccessToken.name) - .setSigningKey(keyPair.getPublic) + .setSigningKey(primaryKeyPair.getPublic) .build() .parseClaimsJws(refreshToken.token) // checks username, validity, and signature. @@ -140,13 +139,24 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe (refreshedAccessToken, refreshToken) } - def publicKey: PublicKey = keyPair.getPublic + def publicKeys: (PublicKey, Option[PublicKey]) = { + val currentPublicKey = primaryKeyPair.getPublic + val previousPublicKey = optionalKeyPair.map(_.getPublic) + (currentPublicKey, previousPublicKey) + } - def publicKeyThumbprint: String = rsaPublicKey.getKeyID + def publicKeyThumbprint: String = rsaPublicKey(primaryKeyPair.getPublic).getKeyID def jwks: JWKSet = { - val jwk = rsaPublicKey - new JWKSet(jwk).toPublicJWKSet + val currentJwk = rsaPublicKey(primaryKeyPair.getPublic) + val previousJwk = optionalKeyPair.map(kp => rsaPublicKey(kp.getPublic)) + + val jwkList = previousJwk match { + case Some(previousJwk) => List[JWK](currentJwk, previousJwk) + case None => List[JWK](currentJwk) + } + + new JWKSet(jwkList.asJava) } def close() : Unit = { @@ -165,8 +175,8 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe } } - private def rsaPublicKey: RSAKey = { - publicKey match { + private def rsaPublicKey(key: PublicKey): RSAKey = { + key match { case rsaKey: RSAPublicKey => new RSAKey.Builder(rsaKey) .keyUse(KeyUse.SIGNATURE) .algorithm(JWSAlgorithm.parse(jwtConfig.algName)) @@ -180,9 +190,16 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe val scheduledFuture = scheduler.scheduleAtFixedRate(() => { logger.info("Attempting to Refresh for new Keys") try { - val newKeyPair = jwtConfig.keyPair() + val (newPrimaryKeyPair, newOptionalKeyPair) = jwtConfig.keyPair() logger.info("Keys have been Refreshed") - keyPair = newKeyPair + jwtConfig.keyPhaseOutTime.foreach { kp => { + jwtConfig match { + case i: InMemoryKeyConfig => + scheduleKeyPhaseOut(kp) + } + }} + primaryKeyPair = newPrimaryKeyPair + optionalKeyPair = newOptionalKeyPair } catch { case e: Throwable => @@ -195,12 +212,25 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe ) Runtime.getRuntime.addShutdownHook(new Thread(() => { - scheduledFuture.cancel(false) this.close() })) } + private def scheduleKeyPhaseOut(phaseOutTime: FiniteDuration): Unit = { + val scheduledFuture = scheduler.schedule(new Runnable { + override def run(): Unit = { + logger.info("Phasing out previous KeyPair.") + optionalKeyPair = None + } + }, phaseOutTime.toMillis, TimeUnit.MILLISECONDS) + + Runtime.getRuntime.addShutdownHook(new Thread(() => { + scheduledFuture.cancel(false) + this.close() + })) + } + def getConfiguredRefreshExpDuration: FiniteDuration = jwtConfig.refreshExpTime } diff --git a/api/src/main/scala/za/co/absa/loginsvc/utils/AwsSecretsUtils.scala b/api/src/main/scala/za/co/absa/loginsvc/utils/AwsSecretsUtils.scala index 29027d97..939ed9a5 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/utils/AwsSecretsUtils.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/utils/AwsSecretsUtils.scala @@ -22,11 +22,17 @@ import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider import software.amazon.awssdk.regions.Region import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient import software.amazon.awssdk.services.secretsmanager.model.{GetSecretValueRequest, GetSecretValueResponse} +import za.co.absa.loginsvc.rest.model.AwsSecret object AwsSecretsUtils { private val logger = LoggerFactory.getLogger(getClass) - def fetchSecret(secretName: String, region: String, secretFields: Array[String]): Map[String, String] = { + def fetchSecret( + secretName: String, + region: String, + secretFields: Array[String], + versionStage: Option[String] = None + ): Option[AwsSecret] = { val default = DefaultCredentialsProvider.create @@ -35,7 +41,9 @@ object AwsSecretsUtils { .credentialsProvider(default) .build - val getSecretValueRequest = GetSecretValueRequest.builder.secretId(secretName).build + val getSecretValueRequestBuilder = GetSecretValueRequest.builder.secretId(secretName) + versionStage.foreach(getSecretValueRequestBuilder.versionStage) + val getSecretValueRequest = getSecretValueRequestBuilder.build() try { logger.info("Attempting to fetch secret from AWS Secrets Manager") @@ -44,14 +52,17 @@ object AwsSecretsUtils { logger.info("secret retrieved. Attempting to Parse data") val rootNode: JsonNode = new ObjectMapper().readTree(secret) - secretFields.map(field => { + val secretValues = secretFields.map(field => { field -> rootNode.get(field).asText() }).toMap + val createTime = getSecretValueResponse.createdDate() + + Option(AwsSecret(secretValues, createTime)) } catch { case e: Throwable => - logger.error(s"Error occurred retrieving and parsing secrets from AWS Secrets Manager", e) - throw e + logger.warn(s"Error occurred retrieving and parsing secrets from AWS Secrets Manager", e) + None } } } diff --git a/api/src/test/resources/application.yaml b/api/src/test/resources/application.yaml index d5785638..95b418ef 100644 --- a/api/src/test/resources/application.yaml +++ b/api/src/test/resources/application.yaml @@ -6,6 +6,7 @@ loginsvc: access-exp-time: 15min refresh-exp-time: 10h key-rotation-time: 5sec + key-phase-out-time: 3sec alg-name: "RS256" # Rest Auth Config (AD) diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfigTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfigTest.scala index e7f3735d..22d16419 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfigTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfigTest.scala @@ -21,8 +21,7 @@ import org.scalatest.matchers.should.Matchers import za.co.absa.loginsvc.rest.config.validation.ConfigValidationException import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess} -import java.util.concurrent.TimeUnit -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration._ class AwsSecretsManagerKeyConfigTest extends AnyFlatSpec with Matchers { @@ -31,9 +30,10 @@ class AwsSecretsManagerKeyConfigTest extends AnyFlatSpec with Matchers { "private", "public", "RS256", - FiniteDuration(15, TimeUnit.MINUTES), - FiniteDuration(9, TimeUnit.HOURS), - Option(FiniteDuration(30, TimeUnit.MINUTES))) + 15.minutes, + 9.minutes, + Option(30.minutes), + Option(15.minutes)) "awsSecretsManagerKeyConfig" should "validate expected content" in { awsSecretsManagerKeyConfig.validate() shouldBe ConfigValidationSuccess @@ -45,20 +45,30 @@ class AwsSecretsManagerKeyConfigTest extends AnyFlatSpec with Matchers { } it should "fail on non-negative accessExpTime" in { - awsSecretsManagerKeyConfig.copy(accessExpTime = FiniteDuration(5, TimeUnit.MILLISECONDS)).validate() shouldBe + awsSecretsManagerKeyConfig.copy(accessExpTime = 5.milliseconds).validate() shouldBe ConfigValidationError(ConfigValidationException(s"accessExpTime must be at least ${KeyConfig.minAccessExpTime}")) } it should "fail on non-negative refreshExpTime" in { - awsSecretsManagerKeyConfig.copy(refreshExpTime = FiniteDuration(5, TimeUnit.MILLISECONDS)).validate() shouldBe + awsSecretsManagerKeyConfig.copy(refreshExpTime = 5.milliseconds).validate() shouldBe ConfigValidationError(ConfigValidationException(s"refreshExpTime must be at least ${KeyConfig.minRefreshExpTime}")) } it should "fail on non-negative keyRotationTime" in { - awsSecretsManagerKeyConfig.copy(pollTime = Option(FiniteDuration(5, TimeUnit.MILLISECONDS))).validate() shouldBe + awsSecretsManagerKeyConfig.copy(pollTime = Option(5.milliseconds)).validate() shouldBe ConfigValidationError(ConfigValidationException(s"keyRotationTime must be at least ${KeyConfig.minKeyRotationTime}")) } + it should "fail on non-negative keyPhaseOutTime" in { + awsSecretsManagerKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds)).validate() shouldBe + ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be at least ${KeyConfig.minKeyPhaseOutTime}")) + } + + it should "fail on keyPhaseOutTime being configured without keyRotationTime" in { + awsSecretsManagerKeyConfig.copy(pollTime = None).validate() shouldBe + ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!")) + } + it should "fail on missing value" in { awsSecretsManagerKeyConfig.copy(secretName = null).validate() shouldBe ConfigValidationError(ConfigValidationException("secretName is empty")) diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfigTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfigTest.scala index e275fc24..1144961e 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfigTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfigTest.scala @@ -21,7 +21,6 @@ import org.scalatest.matchers.should.Matchers import za.co.absa.loginsvc.rest.config.validation.ConfigValidationException import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess} -import java.util.concurrent.TimeUnit import scala.concurrent.duration._ class InMemoryKeyConfigTest extends AnyFlatSpec with Matchers { @@ -29,7 +28,8 @@ class InMemoryKeyConfigTest extends AnyFlatSpec with Matchers { val inMemoryKeyConfig: InMemoryKeyConfig = InMemoryKeyConfig("RS256", 15.minutes, 2.hours, - Option(30.minutes)) + Option(30.minutes), + Option(15.minutes)) "inMemoryKeyConfig" should "validate expected content" in { inMemoryKeyConfig.validate() shouldBe ConfigValidationSuccess @@ -41,17 +41,32 @@ class InMemoryKeyConfigTest extends AnyFlatSpec with Matchers { } it should "fail on non-negative accessExpTime" in { - inMemoryKeyConfig.copy(accessExpTime = FiniteDuration(5, TimeUnit.MILLISECONDS)).validate() shouldBe + inMemoryKeyConfig.copy(accessExpTime = 5.milliseconds).validate() shouldBe ConfigValidationError(ConfigValidationException(s"accessExpTime must be at least ${KeyConfig.minAccessExpTime}")) } it should "fail on non-negative refreshExpTime" in { - inMemoryKeyConfig.copy(refreshExpTime = FiniteDuration(5, TimeUnit.MILLISECONDS)).validate() shouldBe + inMemoryKeyConfig.copy(refreshExpTime = 5.milliseconds).validate() shouldBe ConfigValidationError(ConfigValidationException(s"refreshExpTime must be at least ${KeyConfig.minRefreshExpTime}")) } it should "fail on non-negative keyRotationTime" in { - inMemoryKeyConfig.copy(keyRotationTime = Option(FiniteDuration(5, TimeUnit.MILLISECONDS))).validate() shouldBe + inMemoryKeyConfig.copy(keyRotationTime = Option(5.milliseconds), keyPhaseOutTime = None).validate() shouldBe ConfigValidationError(ConfigValidationException(s"keyRotationTime must be at least ${KeyConfig.minKeyRotationTime}")) } + + it should "fail on non-negative keyPhaseOutTime" in { + inMemoryKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds)).validate() shouldBe + ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be at least ${KeyConfig.minKeyPhaseOutTime}")) + } + + it should "fail on keyPhaseOutTime being configured without keyRotationTime" in { + inMemoryKeyConfig.copy(keyRotationTime = None).validate() shouldBe + ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!")) + } + + it should "fail on keyPhaseOutTime being larger than keyRotationTime" in { + inMemoryKeyConfig.copy(keyRotationTime = Option(10.minutes)).validate() shouldBe + ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be lower than keyRotationTime!")) + } } diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/controller/TokenControllerTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/controller/TokenControllerTest.scala index f1be56ab..e13107f4 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/controller/TokenControllerTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/controller/TokenControllerTest.scala @@ -181,7 +181,7 @@ class TokenControllerTest extends AnyFlatSpec it should "return a Base64 encoded public key from JWTService when user is authenticated" in { val publicKey = Keys.keyPairFor(SignatureAlgorithm.RS256).getPublic - when(jwtService.publicKey).thenReturn(publicKey) + when(jwtService.publicKeys).thenReturn((publicKey, None)) val expectedPublicKeyBase64 = Base64.getEncoder.encodeToString(publicKey.getEncoded) @@ -194,7 +194,7 @@ class TokenControllerTest extends AnyFlatSpec it should "return a Base64 encoded public key from JWTService when user is not authenticated" in { val publicKey = Keys.keyPairFor(SignatureAlgorithm.RS256).getPublic - when(jwtService.publicKey).thenReturn(publicKey) + when(jwtService.publicKeys).thenReturn((publicKey, None)) val expectedPublicKeyBase64 = Base64.getEncoder.encodeToString(publicKey.getEncoded) @@ -205,6 +205,98 @@ class TokenControllerTest extends AnyFlatSpec )(auth = None) } + it should "return only a single Base64 encoded public key from JWTService previous and current keys are available" in { + val publicKey = Keys.keyPairFor(SignatureAlgorithm.RS256).getPublic + val secondaryPublicKey = Keys.keyPairFor(SignatureAlgorithm.RS256).getPublic + when(jwtService.publicKeys).thenReturn((publicKey, Option(secondaryPublicKey))) + + val expectedPublicKeyBase64 = Base64.getEncoder.encodeToString(publicKey.getEncoded) + + assertExpectedResponseFields( + "/token/public-key", + Get())( + expectedJsonBody = s"""{"key": "$expectedPublicKeyBase64"}""" + )(auth = None) + } + + behavior of "getAllPublicKeys" + + it should "return a single Base64 encoded public key from JWTService when user is authenticated" in { + val publicKey = Keys.keyPairFor(SignatureAlgorithm.RS256).getPublic + when(jwtService.publicKeys).thenReturn((publicKey, None)) + + val expectedPublicKeyBase64 = Base64.getEncoder.encodeToString(publicKey.getEncoded) + val expectedResponse = + s""" + |{ + | "keys": [ + | { + | "key": "$expectedPublicKeyBase64" + | } + | ] + |} + |""".stripMargin + + assertExpectedResponseFields( + "/token/public-keys", + Get())( + expectedJsonBody = expectedResponse + )(auth = None) + } + + it should "return a single Base64 encoded public key from JWTService when user is not authenticated" in { + val publicKey = Keys.keyPairFor(SignatureAlgorithm.RS256).getPublic + when(jwtService.publicKeys).thenReturn((publicKey, None)) + + val expectedPublicKeyBase64 = Base64.getEncoder.encodeToString(publicKey.getEncoded) + val expectedResponse = + s""" + |{ + | "keys": [ + | { + | "key": "$expectedPublicKeyBase64" + | } + | ] + |} + |""".stripMargin + + + assertExpectedResponseFields( + "/token/public-keys", + Get())( + expectedJsonBody = expectedResponse + )(auth = None) + } + + it should "return both the current and previous Base64 encoded public keys from JWTService" in { + val publicKey = Keys.keyPairFor(SignatureAlgorithm.RS256).getPublic + val secondaryPublicKey = Keys.keyPairFor(SignatureAlgorithm.RS256).getPublic + when(jwtService.publicKeys).thenReturn((publicKey, Option(secondaryPublicKey))) + + val expectedPublicKeyBase64 = Base64.getEncoder.encodeToString(publicKey.getEncoded) + val expectedSecondaryPublicKeyBase64 = Base64.getEncoder.encodeToString(secondaryPublicKey.getEncoded) + val expectedResponse = + s""" + |{ + | "keys": [ + | { + | "key": "$expectedPublicKeyBase64" + | }, + | { + | "key": "$expectedSecondaryPublicKeyBase64" + | } + | ] + |} + |""".stripMargin + + + assertExpectedResponseFields( + "/token/public-keys", + Get())( + expectedJsonBody = expectedResponse + )(auth = None) + } + behavior of "getPublicKeyJwks" it should "return a JWKS from JWTService when user is authenticated" in { diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/service/jwt/JWTServiceTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/service/jwt/JWTServiceTest.scala index 0c8559ef..78324196 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/service/jwt/JWTServiceTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/service/jwt/JWTServiceTest.scala @@ -54,7 +54,7 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { groups = Seq("group2") ) - private def parseJWT(jwt: Token, publicKey: PublicKey = jwtService.publicKey): Try[Jws[Claims]] = Try { + private def parseJWT(jwt: Token, publicKey: PublicKey = jwtService.publicKeys._1): Try[Jws[Claims]] = Try { Jwts.parserBuilder().setSigningKey(publicKey).build().parseClaimsJws(jwt.token) } @@ -205,7 +205,7 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { def customTimedJwtService(accessExpTime: FiniteDuration, refreshExpTime: FiniteDuration): JWTService = { val configP = new JwtConfigProvider { override def getJwtKeyConfig: KeyConfig = InMemoryKeyConfig( - "RS256", accessExpTime, refreshExpTime, None + "RS256", accessExpTime, refreshExpTime, None, None ) } @@ -221,10 +221,10 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { val refreshJwt = customJwtService.generateRefreshToken(userWithGroups) Thread.sleep(3 * 1000) // make sure that access is past due - as set above - parseJWT(accessJwt, customJwtService.publicKey).isFailure shouldBe true // expired + parseJWT(accessJwt, customJwtService.publicKeys._1).isFailure shouldBe true // expired val (refreshedAccessJwt, _) = customJwtService.refreshTokens(accessJwt, refreshJwt) - val parsedRefreshedAccessJWT = parseJWT(refreshedAccessJwt, customJwtService.publicKey) + val parsedRefreshedAccessJWT = parseJWT(refreshedAccessJwt, customJwtService.publicKeys._1) assert(parsedRefreshedAccessJWT.isSuccess) parsedRefreshedAccessJWT match { @@ -247,7 +247,7 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { val refreshJwt = customJwtService.generateRefreshToken(userWithGroups) Thread.sleep(2 * 1000) // make sure that refresh is past due - as set above - parseJWT(refreshJwt, customJwtService.publicKey).isFailure shouldBe true // expired + parseJWT(refreshJwt, customJwtService.publicKeys._1).isFailure shouldBe true // expired an[ExpiredJwtException] should be thrownBy { customJwtService.refreshTokens(accessJwt, refreshJwt) @@ -260,11 +260,11 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { it should "return a JWK that is equivalent to the `publicKey`" in { import scala.collection.JavaConverters._ - val publicKey = jwtService.publicKey + val publicKey = jwtService.publicKeys val jwks = jwtService.jwks val rsaKey = jwks.getKeys.asScala.head.toRSAKey - assert(publicKey == rsaKey.toPublicKey) + assert(publicKey._1 == rsaKey.toPublicKey) } it should "return a JWK with parameters" in { @@ -278,15 +278,36 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { assert(jwk.getKeyUse == KeyUse.SIGNATURE) } + behavior of "keyRotation" + it should "rotate an public and private keys after 5 seconds" in { val initToken = jwtService.generateAccessToken(userWithoutGroups) - val initPublicKey = jwtService.publicKey + val initPublicKey = jwtService.publicKeys Thread.sleep(6 * 1000) val refreshedToken = jwtService.generateAccessToken(userWithoutGroups) assert(parseJWT(initToken).isFailure) assert(parseJWT(refreshedToken).isSuccess) - assert(initPublicKey != jwtService.publicKey) + assert(initPublicKey != jwtService.publicKeys) + assert(initPublicKey._1 != jwtService.publicKeys._1) + assert(initPublicKey._1 == jwtService.publicKeys._2.orNull) + } + + it should "phase out older keys after 8 seconds" in { + val initToken = jwtService.generateAccessToken(userWithoutGroups) + val initPublicKey = jwtService.publicKeys + + Thread.sleep(6 * 1000) + val refreshedToken = jwtService.generateAccessToken(userWithoutGroups) + + assert(parseJWT(initToken).isFailure) + assert(parseJWT(refreshedToken).isSuccess) + assert(initPublicKey != jwtService.publicKeys) + assert(initPublicKey._1 != jwtService.publicKeys._1) + assert(initPublicKey._1 == jwtService.publicKeys._2.orNull) + + Thread.sleep(3 * 1000) + assert(jwtService.publicKeys._2.isEmpty) } } diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 47465f78..91da1df9 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -61,6 +61,7 @@ object Dependencies { lazy val awsSecrets = "software.amazon.awssdk" % "secretsmanager" % "2.20.68" lazy val awsSts = "software.amazon.awssdk" % "sts" % "2.20.69" + lazy val awsSoo = "software.amazon.awssdk" % "sso" % "2.20.69" lazy val servletApi = "javax.servlet" % "javax.servlet-api" % "3.0.1" % Provided @@ -100,6 +101,7 @@ object Dependencies { awsSecrets, awsSts, + awsSoo, springDoc,