Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class SecurityConfig @Autowired()(authConfigsProvider: AuthConfigProvider) {
"/actuator/**",
"/token/refresh", // access+refresh JWT in payload, no auth
"/token/public-key-jwks",
"/token/public-key").permitAll()
"/token/public-key",
"/token/public-key/all").permitAll()
.anyRequest().authenticated()
.and()
.sessionManagement()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,14 @@ case class AwsSecretsLdapUserConfig(private val secretName: String,

private def getUsernameAndPasswordFromSecret: (String, String) = {
try {
val secrets = AwsSecretsUtils.fetchSecret(secretName, region, Array(usernameFieldName, passwordFieldName))
val secretsOption = AwsSecretsUtils.fetchSecret(secretName, region, Array(usernameFieldName, passwordFieldName))

if(secretsOption.isEmpty)
throw new Exception("Error retrieving username and password from from AWS Secrets Manager")

val secrets = secretsOption.get
Copy link
Collaborator

@dk1844 dk1844 Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think option processing can be done better without calling Option.get e.g. via

      secretsOption.fold(
        throw new Exception("Error retrieving username and password from from AWS Secrets Manager")
      ){ secrets =>
        (secrets(usernameFieldName), secrets(passwordFieldName))  
      }

or more explicitly via

      secretsOption match {
        case None => throw new Exception("Error retrieving username and password from from AWS Secrets Manager")
        case Some(secrets) =>
          (secrets(usernameFieldName), secrets(passwordFieldName))
      }

(secrets(usernameFieldName), secrets(passwordFieldName))
}
catch {
} catch {
case e: Throwable =>
logger.error(s"Error occurred retrieving account data from AWS Secrets Manager", e)
throw e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,47 @@ case class AwsSecretsManagerKeyConfig(
private val logger = LoggerFactory.getLogger(classOf[AwsSecretsManagerKeyConfig])

override def keyRotationTime : Option[FiniteDuration] = pollTime
override def keyPair(): KeyPair = {
override def keyPair(): (KeyPair, Option[KeyPair]) = {
try {
val secrets = AwsSecretsUtils.fetchSecret(secretName, region, Array(privateKeyFieldName, publicKeyFieldName))

val publicKeySpec: X509EncodedKeySpec = new X509EncodedKeySpec(
Base64.getDecoder.decode(
secrets(publicKeyFieldName)
)
)
val privateKeySpec: PKCS8EncodedKeySpec = new PKCS8EncodedKeySpec(
Base64.getDecoder.decode(
secrets(privateKeyFieldName)
)
val currentSecretsOption = AwsSecretsUtils.fetchSecret(
secretName,
region,
Array(privateKeyFieldName, publicKeyFieldName)
)

logger.info("Key Data successfully retrieved and parsed from AWS Secrets Manager")
if(currentSecretsOption.isEmpty)
throw new Exception("Error retrieving AWSCURRENT key from from AWS Secrets Manager")

val keyFactory: KeyFactory = KeyFactory.getInstance(jwtAlgorithmToCryptoAlgorithm)
new KeyPair(keyFactory.generatePublic(publicKeySpec), keyFactory.generatePrivate(privateKeySpec))
}
catch {
val currentKeyPair = createKeyPair(currentSecretsOption.get)
logger.info("AWSCURRENT Key Data successfully retrieved and parsed from AWS Secrets Manager")

val previousSecretsOption = AwsSecretsUtils.fetchSecret(
secretName,
region,
Array(privateKeyFieldName, publicKeyFieldName),
Some("AWSPREVIOUS")
)

val previousKeyPair = previousSecretsOption.flatMap { previousSecrets =>
try {
val keys = createKeyPair(previousSecrets)
logger.info("AWSPREVIOUS Key Data successfully retrieved and parsed from AWS Secrets Manager")
Some(keys)
} catch {
case e: Throwable =>
logger.warn(s"Error occurred decoding AWSPREVIOUSKEYS, skipping previous keys.", e)
None
}
}

(currentKeyPair, previousKeyPair)
} catch {
case e: Throwable =>
logger.error(s"Error occurred retrieving and decoding keys from AWS Secrets Manager", e)
throw e
}
}

override def throwErrors(): Unit = this.validate().throwOnErrors()

override def validate(): ConfigValidationResult = {
Expand All @@ -92,4 +107,21 @@ case class AwsSecretsManagerKeyConfig(

super.validate().merge(awsSecretsResultsMerge)
}

private def createKeyPair(secretKeys: Map[String, String]): KeyPair = {

val publicKeySpec: X509EncodedKeySpec = new X509EncodedKeySpec(
Base64.getDecoder.decode(
secretKeys(publicKeyFieldName)
)
)
val privateKeySpec: PKCS8EncodedKeySpec = new PKCS8EncodedKeySpec(
Base64.getDecoder.decode(
secretKeys(privateKeyFieldName)
)
)

val keyFactory: KeyFactory = KeyFactory.getInstance(jwtAlgorithmToCryptoAlgorithm)
new KeyPair(keyFactory.generatePublic(publicKeySpec), keyFactory.generatePrivate(privateKeySpec))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ case class InMemoryKeyConfig(
keyRotationTime: Option[FiniteDuration]
) extends KeyConfig {

private var oldKeyPair: Option[KeyPair] = None
private val logger = LoggerFactory.getLogger(classOf[InMemoryKeyConfig])

override def keyPair(): KeyPair = {
override def keyPair(): (KeyPair, Option[KeyPair]) = {
logger.info(s"Generating new keys - every ${keyRotationTime.getOrElse("?")}")
Keys.keyPairFor(SignatureAlgorithm.valueOf(algName))
val newKeyPair = Keys.keyPairFor(SignatureAlgorithm.valueOf(algName))
val result = (newKeyPair, oldKeyPair)
oldKeyPair = Some(newKeyPair)
result
}

override def throwErrors(): Unit = this.validate().throwOnErrors()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ trait KeyConfig extends ConfigValidatable {
def accessExpTime: FiniteDuration
def refreshExpTime: FiniteDuration
def keyRotationTime: Option[FiniteDuration]
def keyPair(): KeyPair
def keyPair(): (KeyPair, Option[KeyPair])
def throwErrors(): Unit

final def jwtAlgorithmToCryptoAlgorithm : String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,36 @@ class TokenController @Autowired()(jwtService: JWTService) {
@ResponseStatus(HttpStatus.OK)
def getPublicKey(): CompletableFuture[PublicKey] = {
val publicKey = jwtService.publicKey
val publicKeyBase64 = Base64.getEncoder.encodeToString(publicKey.getEncoded)

val publicKeyBase64 = Base64.getEncoder.encodeToString(publicKey._1.getEncoded)
Future.successful(PublicKey(publicKeyBase64))
}

@Tags(Array(new Tag(name = "token")))
@Operation(
summary = "Gives payload with the current and previously rotated RSA256 public key",
description = """Alternative to /public-key - exposes current and previous public keys allowing users to verify a JWT after rotation.""",
responses = Array(
new ApiResponse(responseCode = "200", description = "Payload containing current and previous public keys is returned",
content = Array(new Content(
schema = new Schema(implementation = classOf[PublicKey]),
examples = Array(new ExampleObject(value = "{\n \"key\": \"ABCDEFGH1234\"\n}")))
)
)
)
)
@GetMapping(
path = Array("/public-key/all"),
produces = Array(MediaType.APPLICATION_JSON_VALUE)
)
@ResponseStatus(HttpStatus.OK)
def getAllPublicKeys(): CompletableFuture[PublicKey] = {
val publicKey = jwtService.publicKey
val currentPublicKeyBase64 = Base64.getEncoder.encodeToString(publicKey._1.getEncoded)
val previousKeyBase64 = publicKey._2.map(pk =>
Base64.getEncoder.encodeToString(pk.getEncoded))
Future.successful(PublicKey(currentPublicKeyBase64, previousKeyBase64))
}

@Tags(Array(new Tag(name = "token")))
@Operation(
summary = "Gives payload with the RSA256 public key in JWKS format",
Expand All @@ -163,10 +188,10 @@ class TokenController @Autowired()(jwtService: JWTService) {
)
@ResponseStatus(HttpStatus.OK)
def getPublicKeyJwks(): CompletableFuture[Map[String, AnyRef]] = {
val jwks = jwtService.jwks
val jwk = jwtService.jwks

import scala.collection.JavaConverters._
Future.successful(jwks.toJSONObject(true).asScala.toMap)
Future.successful(jwk.toJSONObject(true).asScala.toMap)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ import io.swagger.v3.oas.annotations.media.Schema.RequiredMode

case class PublicKey(
@JsonProperty("key")
@Schema(example = "ABCDEFGH1234", requiredMode = RequiredMode.REQUIRED)
key: String
) extends AnyVal
@Schema(example = "ABCDEFGH1234", requiredMode = RequiredMode.REQUIRED,
description = "The public key currently signing JWTs")
key: String,

@JsonProperty("previousKey")
@Schema(
example = "ZYXWVUT9876", requiredMode = RequiredMode.NOT_REQUIRED,
description = "The previous public key, if available."
)
previousKey: Option[String] = None
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package za.co.absa.loginsvc.rest.service.jwt

import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.{JWKSet, KeyUse, RSAKey}
import com.nimbusds.jose.jwk.{JWK, JWKSet, KeyUse, RSAKey}
import io.jsonwebtoken._
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
Expand Down Expand Up @@ -48,7 +48,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
})

private val jwtConfig = jwtConfigProvider.getJwtKeyConfig
@volatile private var keyPair: KeyPair = jwtConfig.keyPair()
@volatile private var keyPair: (KeyPair, Option[KeyPair]) = jwtConfig.keyPair()

if(jwtConfig.keyRotationTime.nonEmpty)
{
Expand Down Expand Up @@ -82,7 +82,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
}.asJava
)
.claim("type", Token.TokenType.Access.toString)
.signWith(keyPair.getPrivate)
.signWith(keyPair._1.getPrivate)
.compact()

AccessToken(tokenContent)
Expand All @@ -101,7 +101,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
.setExpiration(expiration)
.setIssuedAt(issuedAt)
.claim("type", Token.TokenType.Refresh.toString)
.signWith(keyPair.getPrivate)
.signWith(keyPair._1.getPrivate)
.compact()

RefreshToken(tokenContent)
Expand All @@ -110,7 +110,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
def refreshTokens(accessToken: AccessToken, refreshToken: RefreshToken): (AccessToken, RefreshToken) = {
val oldAccessJws: Jws[Claims] = Jwts.parserBuilder()
.require("type", Token.TokenType.Access.toString)
.setSigningKey(keyPair.getPublic)
.setSigningKey(keyPair._1.getPublic)
.setClock(() => Date.from(Instant.now().minus(jwtConfig.refreshExpTime.toJava))) // allowing expired access token - up to refresh token validity window
.build()
.parseClaimsJws(accessToken.token) // checks requirements: type=access, signature, custom validity window
Expand All @@ -120,7 +120,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
Jwts.parserBuilder()
.require("type", Token.TokenType.Refresh.toString)
.requireSubject(userFromOldAccessToken.name)
.setSigningKey(keyPair.getPublic)
.setSigningKey(keyPair._1.getPublic)
.build()
.parseClaimsJws(refreshToken.token) // checks username, validity, and signature.

Expand All @@ -140,13 +140,24 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
(refreshedAccessToken, refreshToken)
}

def publicKey: PublicKey = keyPair.getPublic
def publicKey: (PublicKey, Option[PublicKey]) = {
val currentPublicKey = keyPair._1.getPublic
val previousPublicKey = keyPair._2.map(_.getPublic)
(currentPublicKey, previousPublicKey)
}

def publicKeyThumbprint: String = rsaPublicKey.getKeyID
def publicKeyThumbprint: String = rsaPublicKey(publicKey._1).getKeyID

def jwks: JWKSet = {
val jwk = rsaPublicKey
new JWKSet(jwk).toPublicJWKSet
val currentJwk = rsaPublicKey(publicKey._1)
val previousJwk = publicKey._2.map(rsaPublicKey)

val jwkList = previousJwk match {
case Some(previousJwk) => List[JWK](currentJwk, previousJwk)
case None => List[JWK](currentJwk)
}

new JWKSet(jwkList.asJava)
}

def close() : Unit = {
Expand All @@ -165,8 +176,8 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
}
}

private def rsaPublicKey: RSAKey = {
publicKey match {
private def rsaPublicKey(key: PublicKey): RSAKey = {
key match {
case rsaKey: RSAPublicKey => new RSAKey.Builder(rsaKey)
.keyUse(KeyUse.SIGNATURE)
.algorithm(JWSAlgorithm.parse(jwtConfig.algName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ import software.amazon.awssdk.services.secretsmanager.model.{GetSecretValueReque
object AwsSecretsUtils {

private val logger = LoggerFactory.getLogger(getClass)
def fetchSecret(secretName: String, region: String, secretFields: Array[String]): Map[String, String] = {
def fetchSecret(
secretName: String,
region: String,
secretFields: Array[String],
versionStage: Option[String] = None
): Option[Map[String, String]] = {

val default = DefaultCredentialsProvider.create

Expand All @@ -35,7 +40,9 @@ object AwsSecretsUtils {
.credentialsProvider(default)
.build

val getSecretValueRequest = GetSecretValueRequest.builder.secretId(secretName).build
val getSecretValueRequestBuilder = GetSecretValueRequest.builder.secretId(secretName)
versionStage.foreach(getSecretValueRequestBuilder.versionStage)
val getSecretValueRequest = getSecretValueRequestBuilder.build()

try {
logger.info("Attempting to fetch secret from AWS Secrets Manager")
Expand All @@ -44,14 +51,14 @@ object AwsSecretsUtils {
logger.info("secret retrieved. Attempting to Parse data")
val rootNode: JsonNode = new ObjectMapper().readTree(secret)

secretFields.map(field => {
Some(secretFields.map(field => {
field -> rootNode.get(field).asText()
}).toMap
}).toMap)
}
catch {
case e: Throwable =>
logger.error(s"Error occurred retrieving and parsing secrets from AWS Secrets Manager", e)
throw e
logger.warn(s"Error occurred retrieving and parsing secrets from AWS Secrets Manager", e)
None
}
}
}
Loading
Loading