Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/src/main/resources/example.application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ loginsvc:
refresh-exp-time: 9h
key-rotation-time: 9h
key-phase-out-time: 30min
key-lay-over-time: 15min
alg-name: "RS256"
#Instead of generating the key in memory
#The Below Config allows for the application to fetch keys from AWS Secrets Manager.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ case class AwsSecretsManagerKeyConfig(
accessExpTime: FiniteDuration,
refreshExpTime: FiniteDuration,
pollTime: Option[FiniteDuration],
keyPhaseOutTime: Option[FiniteDuration]
keyPhaseOutTime: Option[FiniteDuration],
keyLayOverTime: Option[FiniteDuration]
) extends KeyConfig {

private val logger = LoggerFactory.getLogger(classOf[AwsSecretsManagerKeyConfig])
Expand Down Expand Up @@ -79,7 +80,15 @@ case class AwsSecretsManagerKeyConfig(
}
}

(currentKeyPair, previousKeyPair)
previousKeyPair.fold {(currentKeyPair, previousKeyPair)} { pk =>
val exp = keyLayOverTime.exists(!isExpired(currentSecrets.createTime, _))
if (!exp) {
(currentKeyPair, previousKeyPair)
}
else {
(pk, Some(currentKeyPair))
}
}
} catch {
case e: Throwable =>
logger.error(s"Error occurred retrieving and decoding keys from AWS Secrets Manager", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ case class InMemoryKeyConfig(
accessExpTime: FiniteDuration,
refreshExpTime: FiniteDuration,
keyRotationTime: Option[FiniteDuration],
keyPhaseOutTime: Option[FiniteDuration]
keyPhaseOutTime: Option[FiniteDuration],
keyLayOverTime: Option[FiniteDuration]
) extends KeyConfig {

private var oldKeyPair: Option[KeyPair] = None
Expand All @@ -50,7 +51,14 @@ case class InMemoryKeyConfig(
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be lower than keyRotationTime!"))
} else ConfigValidationSuccess

super.validate().merge(keyPhaseOutTimeResult)
val keyLayOverTimeResult = if(keyLayOverTime.nonEmpty && keyRotationTime.nonEmpty
&& keyLayOverTime.get > keyRotationTime.get) {
ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be lower than keyRotationTime!"))
} else ConfigValidationSuccess

super.validate()
.merge(keyPhaseOutTimeResult)
.merge(keyLayOverTimeResult)
}

override def throwErrors(): Unit = this.validate().throwOnErrors()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ trait KeyConfig extends ConfigValidatable {
def refreshExpTime: FiniteDuration
def keyRotationTime: Option[FiniteDuration]
def keyPhaseOutTime: Option[FiniteDuration]
def keyLayOverTime: Option[FiniteDuration]
def keyPair(): (KeyPair, Option[KeyPair])
def throwErrors(): Unit

Expand Down Expand Up @@ -79,6 +80,19 @@ trait KeyConfig extends ConfigValidatable {
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!"))
} else ConfigValidationSuccess

val keyLayoverTimeResult = if (keyLayOverTime.nonEmpty && keyLayOverTime.get < KeyConfig.minKeyLayOverTime) {
ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be at least ${KeyConfig.minKeyLayOverTime}"))
} else ConfigValidationSuccess

val keyLayOverWithRotationResult = if (keyLayOverTime.nonEmpty && keyRotationTime.isEmpty) {
ConfigValidationError(ConfigValidationException(s"keyLayOverTime can only be enable if keyRotationTime is enable!"))
} else ConfigValidationSuccess

val keyLayOverWithPhaseResult = if(keyLayOverTime.nonEmpty && keyPhaseOutTime.nonEmpty
&& keyLayOverTime.get > keyPhaseOutTime.get) {
ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be lower than keyPhaseOutTime!"))
} else ConfigValidationSuccess

if (keyRotationTime.isEmpty) {
logger.warn("keyRotationTime is not set in config, key-pair will not be rotated!")
}
Expand All @@ -93,6 +107,9 @@ trait KeyConfig extends ConfigValidatable {
.merge(keyRotationTimeResult)
.merge(keyPhaseOutTimeResult)
.merge(keyPhaseOutWithRotationResult)
.merge(keyLayoverTimeResult)
.merge(keyLayOverWithRotationResult)
.merge(keyLayOverWithPhaseResult)
}
}

Expand All @@ -101,4 +118,5 @@ object KeyConfig {
val minRefreshExpTime: FiniteDuration = 10.milliseconds
val minKeyRotationTime: FiniteDuration = 10.milliseconds
val minKeyPhaseOutTime: FiniteDuration = 10.milliseconds
val minKeyLayOverTime: FiniteDuration = 10.milliseconds
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import za.co.absa.loginsvc.model.User
import za.co.absa.loginsvc.rest.config.jwt.InMemoryKeyConfig
import za.co.absa.loginsvc.rest.config.provider.JwtConfigProvider
import za.co.absa.loginsvc.rest.model.{AccessToken, RefreshToken, Token}
import za.co.absa.loginsvc.rest.service.jwt.JWTService.extractUserFrom
import za.co.absa.loginsvc.rest.service.jwt.JWTService.{extractUserFrom, parseWithKeys}
import za.co.absa.loginsvc.rest.service.search.UserSearchService

import java.security.interfaces.RSAPublicKey
Expand All @@ -42,7 +42,7 @@ import scala.concurrent.duration.FiniteDuration
class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchService: UserSearchService) {

private val logger = LoggerFactory.getLogger(classOf[JWTService])
private val scheduler = new ScheduledThreadPoolExecutor(2, new ThreadFactory {
private val scheduler = new ScheduledThreadPoolExecutor(3, new ThreadFactory {
override def newThread(r: Runnable): Thread = {
val t = new Thread(r)
t.setDaemon(true)
Expand Down Expand Up @@ -107,21 +107,29 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
}

def refreshTokens(accessToken: AccessToken, refreshToken: RefreshToken): (AccessToken, RefreshToken) = {
val oldAccessJws: Jws[Claims] = Jwts.parserBuilder()
.require("type", Token.TokenType.Access.toString)
.setSigningKey(primaryKeyPair.getPublic)
.setClock(() => Date.from(Instant.now().minus(jwtConfig.refreshExpTime.toJava))) // allowing expired access token - up to refresh token validity window
.build()
.parseClaimsJws(accessToken.token) // checks requirements: type=access, signature, custom validity window

val userFromOldAccessToken: User = extractUserFrom(oldAccessJws.getBody)

Jwts.parserBuilder()
.require("type", Token.TokenType.Refresh.toString)
.requireSubject(userFromOldAccessToken.name)
.setSigningKey(primaryKeyPair.getPublic)
.build()
.parseClaimsJws(refreshToken.token) // checks username, validity, and signature.

val keyList: List[PublicKey] = List(primaryKeyPair.getPublic) ++ optionalKeyPair.map(_.getPublic).toList

val oldAccessJws: Option[Jws[Claims]] = parseWithKeys(
accessToken,
keyList,
Token.TokenType.Access.toString,
Some(jwtConfig.refreshExpTime)
) // checks requirements: type=access, signature, custom validity window

if(oldAccessJws.isEmpty)
throw new JwtException("Tokens are incompatible with current keys. Please request new Tokens!")

val userFromOldAccessToken: User = extractUserFrom(oldAccessJws.get.getBody)

val refreshClaims = parseWithKeys(
refreshToken,
keyList,
Token.TokenType.Refresh.toString
)

if(refreshClaims.isEmpty)
throw new JwtException("Tokens are incompatible with current keys. Please request new Tokens!")

val userUpdatedDetails = {
try {
Expand Down Expand Up @@ -190,7 +198,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
val scheduledFuture = scheduler.scheduleAtFixedRate(() => {
logger.info("Attempting to Refresh for new Keys")
try {
val (newPrimaryKeyPair, newOptionalKeyPair) = jwtConfig.keyPair()
var (newPrimaryKeyPair, newOptionalKeyPair) = jwtConfig.keyPair()
logger.info("Keys have been Refreshed")
jwtConfig.keyPhaseOutTime.foreach { kp => {
jwtConfig match {
Expand All @@ -199,6 +207,18 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
case _ =>
}
}}
jwtConfig.keyLayOverTime.foreach { kl => {
jwtConfig match {
case _: InMemoryKeyConfig =>
newOptionalKeyPair.foreach { tok =>
scheduleKeyLayOver(kl)
val temp = tok
newOptionalKeyPair = Some(newPrimaryKeyPair)
newPrimaryKeyPair = temp
}
case _ =>
}
}}
primaryKeyPair = newPrimaryKeyPair
optionalKeyPair = newOptionalKeyPair
}
Expand All @@ -211,7 +231,6 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
refreshTime.toMillis,
TimeUnit.MILLISECONDS
)

Runtime.getRuntime.addShutdownHook(new Thread(() => {
scheduledFuture.cancel(false)
this.close()
Expand All @@ -225,7 +244,23 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe
optionalKeyPair = None
}
}, phaseOutTime.toMillis, TimeUnit.MILLISECONDS)
Runtime.getRuntime.addShutdownHook(new Thread(() => {
scheduledFuture.cancel(false)
this.close()
}))
}

private def scheduleKeyLayOver(layOverTime: FiniteDuration): Unit = {
val scheduledFuture = scheduler.schedule(new Runnable {
override def run(): Unit = {
logger.info("Switching Signing key")
optionalKeyPair.foreach { okp =>
val temp = okp
optionalKeyPair = Some(primaryKeyPair)
primaryKeyPair = temp
}
}
}, layOverTime.toMillis, TimeUnit.MILLISECONDS)
Runtime.getRuntime.addShutdownHook(new Thread(() => {
scheduledFuture.cancel(false)
this.close()
Expand Down Expand Up @@ -254,4 +289,29 @@ object JWTService {

User(name, groups, optionalAttributes)
}

def parseWithKeys(
token: Token,
keys: List[PublicKey],
accessType: String,
clock: Option[FiniteDuration] = None
): Option[Jws[Claims]] = {
keys.flatMap { key =>
try {
val builder = Jwts.parserBuilder()
.require("type", accessType)
.setSigningKey(key)

clock.foreach(time => builder.setClock(() => Date.from(Instant.now().minus(time.toJava))))

Some(builder.build().parseClaimsJws(token.token))
} catch {
case e: MalformedJwtException =>
throw e
case e: ExpiredJwtException =>
throw e
case _: JwtException => None
}
}.headOption
}
}
5 changes: 3 additions & 2 deletions api/src/test/resources/application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ loginsvc:
generate-in-memory:
access-exp-time: 15min
refresh-exp-time: 10h
key-rotation-time: 5sec
key-phase-out-time: 3sec
key-rotation-time: 20sec
key-phase-out-time: 10sec
key-lay-over-time: 5sec
alg-name: "RS256"

# Rest Auth Config (AD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class AwsSecretsManagerKeyConfigTest extends AnyFlatSpec with Matchers {
15.minutes,
9.minutes,
Option(30.minutes),
Option(15.minutes))
Option(15.minutes),
Option(5.minutes))

"awsSecretsManagerKeyConfig" should "validate expected content" in {
awsSecretsManagerKeyConfig.validate() shouldBe ConfigValidationSuccess
Expand All @@ -60,15 +61,30 @@ class AwsSecretsManagerKeyConfigTest extends AnyFlatSpec with Matchers {
}

it should "fail on non-negative keyPhaseOutTime" in {
awsSecretsManagerKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds)).validate() shouldBe
awsSecretsManagerKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds), keyLayOverTime = None).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be at least ${KeyConfig.minKeyPhaseOutTime}"))
}

it should "fail on keyPhaseOutTime being configured without keyRotationTime" in {
awsSecretsManagerKeyConfig.copy(pollTime = None).validate() shouldBe
awsSecretsManagerKeyConfig.copy(pollTime = None, keyLayOverTime = None).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!"))
}

it should "fail on non-negative keyLayOverTime" in {
awsSecretsManagerKeyConfig.copy(keyLayOverTime = Option(5.milliseconds)).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be at least ${KeyConfig.minKeyLayOverTime}"))
}

it should "fail on keyLayOverTime being configured without keyRotationTime" in {
awsSecretsManagerKeyConfig.copy(pollTime = None, keyPhaseOutTime = None).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyLayOverTime can only be enable if keyRotationTime is enable!"))
}

it should "fail on keyLayOverTime being larger than keyPhaseOutTime" in {
awsSecretsManagerKeyConfig.copy(keyPhaseOutTime = Option(4.minutes)).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be lower than keyPhaseOutTime!"))
}

it should "fail on missing value" in {
awsSecretsManagerKeyConfig.copy(secretName = null).validate() shouldBe
ConfigValidationError(ConfigValidationException("secretName is empty"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class InMemoryKeyConfigTest extends AnyFlatSpec with Matchers {
15.minutes,
2.hours,
Option(30.minutes),
Option(15.minutes))
Option(15.minutes),
Option(5.minutes))

"inMemoryKeyConfig" should "validate expected content" in {
inMemoryKeyConfig.validate() shouldBe ConfigValidationSuccess
Expand All @@ -51,22 +52,42 @@ class InMemoryKeyConfigTest extends AnyFlatSpec with Matchers {
}

it should "fail on non-negative keyRotationTime" in {
inMemoryKeyConfig.copy(keyRotationTime = Option(5.milliseconds), keyPhaseOutTime = None).validate() shouldBe
inMemoryKeyConfig.copy(keyRotationTime = Option(5.milliseconds), keyPhaseOutTime = None, keyLayOverTime = None).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyRotationTime must be at least ${KeyConfig.minKeyRotationTime}"))
}

it should "fail on non-negative keyPhaseOutTime" in {
inMemoryKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds)).validate() shouldBe
inMemoryKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds), keyLayOverTime = None).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be at least ${KeyConfig.minKeyPhaseOutTime}"))
}

it should "fail on keyPhaseOutTime being configured without keyRotationTime" in {
inMemoryKeyConfig.copy(keyRotationTime = None).validate() shouldBe
inMemoryKeyConfig.copy(keyRotationTime = None, keyLayOverTime = None).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!"))
}

it should "fail on keyPhaseOutTime being larger than keyRotationTime" in {
inMemoryKeyConfig.copy(keyRotationTime = Option(10.minutes)).validate() shouldBe
inMemoryKeyConfig.copy(keyRotationTime = Option(10.minutes), keyLayOverTime = None).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be lower than keyRotationTime!"))
}

it should "fail on non-negative keyLayOverTime" in {
inMemoryKeyConfig.copy(keyLayOverTime = Option(5.milliseconds)).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be at least ${KeyConfig.minKeyLayOverTime}"))
}

it should "fail on keyLayOverTime being configured without keyRotationTime" in {
inMemoryKeyConfig.copy(keyRotationTime = None, keyPhaseOutTime = None).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyLayOverTime can only be enable if keyRotationTime is enable!"))
}

it should "fail on keyLayOverTime being larger than keyRotationTime" in {
inMemoryKeyConfig.copy(keyRotationTime = Option(4.minutes), keyPhaseOutTime = None).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be lower than keyRotationTime!"))
}

it should "fail on keyLayOverTime being larger than keyPhaseOutTime" in {
inMemoryKeyConfig.copy(keyPhaseOutTime = Option(4.minutes)).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be lower than keyPhaseOutTime!"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class ConfigProviderTest extends AnyFlatSpec with Matchers {
keyConfig.algName shouldBe "RS256"
keyConfig.accessExpTime shouldBe FiniteDuration(15, TimeUnit.MINUTES)
keyConfig.refreshExpTime shouldBe FiniteDuration(10, TimeUnit.HOURS)
keyConfig.keyRotationTime.get shouldBe FiniteDuration(5, TimeUnit.SECONDS)
keyConfig.keyRotationTime.get shouldBe FiniteDuration(20, TimeUnit.SECONDS)
keyConfig.keyPhaseOutTime.get shouldBe FiniteDuration(10, TimeUnit.SECONDS)
keyConfig.keyLayOverTime.get shouldBe FiniteDuration(5, TimeUnit.SECONDS)
}

"The ldapConfig properties" should "Match" in {
Expand Down
Loading
Loading