Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/src/main/resources/example.application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ loginsvc:
access-exp-time: 15min
refresh-exp-time: 9h
key-rotation-time: 9h
key-phase-out-time: 30min
alg-name: "RS256"
#Instead of generating the key in memory
#The Below Config allows for the application to fetch keys from AWS Secrets Manager.
Expand All @@ -18,6 +19,7 @@ loginsvc:
#access-exp-time: 15min
#refresh-exp-time: 9h
#poll-time: 5min
#key-phase-out-time: 30min
#alg-name: "RS256"
config:
# Generates git.properties file for use on info endpoint.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class SecurityConfig @Autowired()(authConfigsProvider: AuthConfigProvider) {
"/actuator/**",
"/token/refresh", // access+refresh JWT in payload, no auth
"/token/public-key-jwks",
"/token/public-key").permitAll()
"/token/public-key",
"/token/public-keys").permitAll()
.anyRequest().authenticated()
.and()
.sessionManagement()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,15 @@ case class AwsSecretsLdapUserConfig(private val secretName: String,

private def getUsernameAndPasswordFromSecret: (String, String) = {
try {
val secrets = AwsSecretsUtils.fetchSecret(secretName, region, Array(usernameFieldName, passwordFieldName))
(secrets(usernameFieldName), secrets(passwordFieldName))
}
catch {
val secretsOption = AwsSecretsUtils.fetchSecret(secretName, region, Array(usernameFieldName, passwordFieldName))

secretsOption.fold(
throw new Exception("Error retrieving username and password from from AWS Secrets Manager")
) { secrets =>
(secrets.secretValue(usernameFieldName), secrets.secretValue(passwordFieldName))
}

} catch {
case e: Throwable =>
logger.error(s"Error occurred retrieving account data from AWS Secrets Manager", e)
throw e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import za.co.absa.loginsvc.utils.AwsSecretsUtils

import java.security.{KeyFactory, KeyPair}
import java.security.spec.{PKCS8EncodedKeySpec, X509EncodedKeySpec}
import java.time.Instant
import java.util.Base64
import scala.concurrent.duration.FiniteDuration

Expand All @@ -34,38 +35,57 @@ case class AwsSecretsManagerKeyConfig(
algName: String,
accessExpTime: FiniteDuration,
refreshExpTime: FiniteDuration,
pollTime: Option[FiniteDuration]
pollTime: Option[FiniteDuration],
keyPhaseOutTime: Option[FiniteDuration]
) extends KeyConfig {

private val logger = LoggerFactory.getLogger(classOf[AwsSecretsManagerKeyConfig])

override def keyRotationTime : Option[FiniteDuration] = pollTime
override def keyPair(): KeyPair = {
override def keyPair(): (KeyPair, Option[KeyPair]) = {
try {
val secrets = AwsSecretsUtils.fetchSecret(secretName, region, Array(privateKeyFieldName, publicKeyFieldName))

val publicKeySpec: X509EncodedKeySpec = new X509EncodedKeySpec(
Base64.getDecoder.decode(
secrets(publicKeyFieldName)
)
)
val privateKeySpec: PKCS8EncodedKeySpec = new PKCS8EncodedKeySpec(
Base64.getDecoder.decode(
secrets(privateKeyFieldName)
)
val currentSecretsOption = AwsSecretsUtils.fetchSecret(
secretName,
region,
Array(privateKeyFieldName, publicKeyFieldName)
)

logger.info("Key Data successfully retrieved and parsed from AWS Secrets Manager")
if(currentSecretsOption.isEmpty)
throw new Exception("Error retrieving AWSCURRENT key from from AWS Secrets Manager")

val keyFactory: KeyFactory = KeyFactory.getInstance(jwtAlgorithmToCryptoAlgorithm)
new KeyPair(keyFactory.generatePublic(publicKeySpec), keyFactory.generatePrivate(privateKeySpec))
}
catch {
val currentKeyPair = createKeyPair(currentSecretsOption.get.secretValue)
logger.info("AWSCURRENT Key Data successfully retrieved and parsed from AWS Secrets Manager")

val previousSecretsOption =
AwsSecretsUtils.fetchSecret(
secretName,
region,
Array(privateKeyFieldName, publicKeyFieldName),
Some("AWSPREVIOUS")
)

val previousKeyPair = previousSecretsOption.flatMap { previousSecrets =>
try {
val keys = createKeyPair(previousSecrets.secretValue)
logger.info("AWSPREVIOUS Key Data successfully retrieved and parsed from AWS Secrets Manager")
val exp = keyPhaseOutTime.exists(isExpired(previousSecrets.createTime, _))
if(exp) { None }
else { Some(keys) }
} catch {
case e: Throwable =>
logger.warn(s"Error occurred decoding AWSPREVIOUSKEYS, skipping previous keys.", e)
None
}
}

(currentKeyPair, previousKeyPair)
} catch {
case e: Throwable =>
logger.error(s"Error occurred retrieving and decoding keys from AWS Secrets Manager", e)
throw e
}
}

override def throwErrors(): Unit = this.validate().throwOnErrors()

override def validate(): ConfigValidationResult = {
Expand All @@ -92,4 +112,26 @@ case class AwsSecretsManagerKeyConfig(

super.validate().merge(awsSecretsResultsMerge)
}

private def createKeyPair(secretKeys: Map[String, String]): KeyPair = {

val publicKeySpec: X509EncodedKeySpec = new X509EncodedKeySpec(
Base64.getDecoder.decode(
secretKeys(publicKeyFieldName)
)
)
val privateKeySpec: PKCS8EncodedKeySpec = new PKCS8EncodedKeySpec(
Base64.getDecoder.decode(
secretKeys(privateKeyFieldName)
)
)

val keyFactory: KeyFactory = KeyFactory.getInstance(jwtAlgorithmToCryptoAlgorithm)
new KeyPair(keyFactory.generatePublic(publicKeySpec), keyFactory.generatePrivate(privateKeySpec))
}

private def isExpired(creationTime: Instant, finiteDuration: FiniteDuration): Boolean = {
val expirationTime = creationTime.plus(finiteDuration.toMillis, java.time.temporal.ChronoUnit.MILLIS)
Instant.now().isAfter(expirationTime)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package za.co.absa.loginsvc.rest.config.jwt
import io.jsonwebtoken.SignatureAlgorithm
import io.jsonwebtoken.security.Keys
import org.slf4j.LoggerFactory
import za.co.absa.loginsvc.rest.config.validation.{ConfigValidationException, ConfigValidationResult}
import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess}

import java.security.KeyPair
import scala.concurrent.duration.FiniteDuration
Expand All @@ -27,17 +29,30 @@ case class InMemoryKeyConfig(
algName: String,
accessExpTime: FiniteDuration,
refreshExpTime: FiniteDuration,
keyRotationTime: Option[FiniteDuration]
keyRotationTime: Option[FiniteDuration],
keyPhaseOutTime: Option[FiniteDuration]
) extends KeyConfig {

private var oldKeyPair: Option[KeyPair] = None
private val logger = LoggerFactory.getLogger(classOf[InMemoryKeyConfig])

override def keyPair(): KeyPair = {
override def keyPair(): (KeyPair, Option[KeyPair]) = {
logger.info(s"Generating new keys - every ${keyRotationTime.getOrElse("?")}")
Keys.keyPairFor(SignatureAlgorithm.valueOf(algName))
val newKeyPair = Keys.keyPairFor(SignatureAlgorithm.valueOf(algName))
val result = (newKeyPair, oldKeyPair)
oldKeyPair = Some(newKeyPair)
result
}

override def throwErrors(): Unit = this.validate().throwOnErrors()
override def validate(): ConfigValidationResult = {
val keyPhaseOutTimeResult = if(keyPhaseOutTime.nonEmpty && keyRotationTime.nonEmpty
&& keyPhaseOutTime.get > keyRotationTime.get) {
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be lower than keyRotationTime!"))
} else ConfigValidationSuccess

super.validate().merge(keyPhaseOutTimeResult)
}

override def throwErrors(): Unit = this.validate().throwOnErrors()
}

Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ import za.co.absa.loginsvc.rest.config.validation.{ConfigValidatable, ConfigVali
import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess}

import java.security.KeyPair
import java.util.concurrent.TimeUnit
import scala.concurrent.duration.FiniteDuration
import scala.concurrent.duration._
import scala.util.{Failure, Success, Try}

trait KeyConfig extends ConfigValidatable {
def algName: String
def accessExpTime: FiniteDuration
def refreshExpTime: FiniteDuration
def keyRotationTime: Option[FiniteDuration]
def keyPair(): KeyPair
def keyPhaseOutTime: Option[FiniteDuration]
def keyPair(): (KeyPair, Option[KeyPair])
def throwErrors(): Unit

final def jwtAlgorithmToCryptoAlgorithm : String = {
Expand Down Expand Up @@ -71,16 +71,34 @@ trait KeyConfig extends ConfigValidatable {
ConfigValidationError(ConfigValidationException(s"keyRotationTime must be at least ${KeyConfig.minKeyRotationTime}"))
} else ConfigValidationSuccess

val keyPhaseOutTimeResult = if (keyPhaseOutTime.nonEmpty && keyPhaseOutTime.get < KeyConfig.minKeyPhaseOutTime) {
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be at least ${KeyConfig.minKeyPhaseOutTime}"))
} else ConfigValidationSuccess

val keyPhaseOutWithRotationResult = if (keyPhaseOutTime.nonEmpty && keyRotationTime.isEmpty) {
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!"))
} else ConfigValidationSuccess

if (keyRotationTime.isEmpty) {
logger.warn("keyRotationTime is not set in config, key-pair will not be rotated!")
}

algValidation.merge(accessExpTimeResult).merge(refreshExpTimeResult).merge(keyRotationTimeResult)
if(keyPhaseOutTime.isEmpty) {
logger.warn("keyPhaseOutTime is not set in config, the previously used public key will be viewable till rotation!")
}

algValidation
.merge(accessExpTimeResult)
.merge(refreshExpTimeResult)
.merge(keyRotationTimeResult)
.merge(keyPhaseOutTimeResult)
.merge(keyPhaseOutWithRotationResult)
}
}

object KeyConfig {
val minAccessExpTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS)
val minRefreshExpTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS)
val minKeyRotationTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS)
val minAccessExpTime: FiniteDuration = 10.milliseconds
val minRefreshExpTime: FiniteDuration = 10.milliseconds
val minKeyRotationTime: FiniteDuration = 10.milliseconds
val minKeyPhaseOutTime: FiniteDuration = 10.milliseconds
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.springframework.security.core.Authentication
import org.springframework.web.bind.annotation._
import org.springframework.web.server.ResponseStatusException
import za.co.absa.loginsvc.model.User
import za.co.absa.loginsvc.rest.model.{KerberosUserDetails, PublicKey, TokensWrapper}
import za.co.absa.loginsvc.rest.model.{KerberosUserDetails, PublicKey, PublicKeySet, TokensWrapper}
import za.co.absa.loginsvc.rest.service.jwt.JWTService
import za.co.absa.loginsvc.utils.OptionUtils.ImplicitBuilderExt

Expand Down Expand Up @@ -144,16 +144,41 @@ class TokenController @Autowired()(jwtService: JWTService) {
)
@ResponseStatus(HttpStatus.OK)
def getPublicKey(): CompletableFuture[PublicKey] = {
val publicKey = jwtService.publicKey
val (publicKey, _) = jwtService.publicKeys
val publicKeyBase64 = Base64.getEncoder.encodeToString(publicKey.getEncoded)

Future.successful(PublicKey(publicKeyBase64))
}

@Tags(Array(new Tag(name = "token")))
@Operation(
summary = "Gives payload with the current and previously rotated RSA256 public key",
description = """Alternative to /public-key - exposes current and previous public keys allowing users to verify a JWT after rotation.""",
responses = Array(
new ApiResponse(responseCode = "200", description = "Payload containing current and previous public keys is returned",
content = Array(new Content(
schema = new Schema(implementation = classOf[PublicKey]),
examples = Array(new ExampleObject(value = "{\n \"keys\": [\n {\n \"key\": \"ABCDEFGH1234\"\n}\n]\n}")))
)
)
)
)
@GetMapping(
path = Array("/public-keys"),
produces = Array(MediaType.APPLICATION_JSON_VALUE)
)
@ResponseStatus(HttpStatus.OK)
def getAllPublicKeys(): CompletableFuture[PublicKeySet] = {
val (primaryPublicKey, optionalPublicKey) = jwtService.publicKeys
val currentPublicKey = PublicKey(Base64.getEncoder.encodeToString(primaryPublicKey.getEncoded))
val previousPublicKey = optionalPublicKey.map(pk =>
PublicKey(Base64.getEncoder.encodeToString(pk.getEncoded)))
Future.successful(PublicKeySet(keys = currentPublicKey :: previousPublicKey.toList))
}

@Tags(Array(new Tag(name = "token")))
@Operation(
summary = "Gives payload with the RSA256 public key in JWKS format",
description = "Returns the same information as /token/public-key, but as a JSON Web Key Set",
description = "Returns the same information as /token/public-keys, but as a JSON Web Key Set",
responses = Array(
new ApiResponse(responseCode = "200", description = "Success", content = Array(new Content(examples = Array(new ExampleObject(value = """{"keys":[{"kty": "EC","crv": "P-256","x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4","y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM","use": "enc","kid": "1"},{"kty": "RSA","n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","e": "AQAB","alg": "RS256","kid": "2011-04-29"}]}"""))))),
))
Expand All @@ -163,10 +188,10 @@ class TokenController @Autowired()(jwtService: JWTService) {
)
@ResponseStatus(HttpStatus.OK)
def getPublicKeyJwks(): CompletableFuture[Map[String, AnyRef]] = {
val jwks = jwtService.jwks
val jwk = jwtService.jwks

import scala.collection.JavaConverters._
Future.successful(jwks.toJSONObject(true).asScala.toMap)
Future.successful(jwk.toJSONObject(true).asScala.toMap)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2023 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.loginsvc.rest.model

import java.time.Instant

case class AwsSecret(secretValue: Map[String, String], createTime: Instant)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import io.swagger.v3.oas.annotations.media.Schema.RequiredMode

case class PublicKey(
@JsonProperty("key")
@Schema(example = "ABCDEFGH1234", requiredMode = RequiredMode.REQUIRED)
@Schema(example = "ABCDEFGH1234", requiredMode = RequiredMode.REQUIRED,
description = "The public key currently signing JWTs")
key: String
) extends AnyVal
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright 2023 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.loginsvc.rest.model

import com.fasterxml.jackson.annotation.JsonProperty
import io.swagger.v3.oas.annotations.media.Schema
import io.swagger.v3.oas.annotations.media.Schema.RequiredMode

case class PublicKeySet(
@JsonProperty("keys")
@Schema(requiredMode = RequiredMode.REQUIRED,
description = "The full set public keys including the one currently signing JWTs",
example = """[
{"key": "ABCDEFGH1234"},
{"key": "ZYXWVUT9876"},
]""")
keys: List[PublicKey]
) extends AnyVal
Loading
Loading