Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

package za.co.absa.loginclient.authorization

import com.google.common.cache.CacheBuilder
import org.springframework.cache.concurrent.ConcurrentMapCache
import org.springframework.security.oauth2.jwt.{JwtDecoder, NimbusJwtDecoder}
import za.co.absa.loginclient.publicKeyRetrieval.model.PublicKey

import java.security.KeyFactory
import java.security.interfaces.RSAPublicKey
import java.security.spec.X509EncodedKeySpec
import java.util.Base64
import java.util.concurrent.TimeUnit

object JwtDecoderProvider {

Expand Down Expand Up @@ -59,9 +62,16 @@ object JwtDecoderProvider {
* Currently implemented by NimbusJwtDecoder.withJwkSetUri(JWKS_PATH).
*
* @param host The URL from which the public key will be fetched.
* @param refreshPeriod Optional value in seconds for caching keys from host. Default cache expiry is 5 minutes.
* @return A JwtDecoder instance initialized with the public key fetched from the URL.
*/
def getDecoderFromURL(host: String): JwtDecoder = {
NimbusJwtDecoder.withJwkSetUri(s"$host/token/public-key-jwks").build()
def getDecoderFromURL(host: String, refreshPeriod: Option[Int] = None): JwtDecoder = {
val decoderBuilder = NimbusJwtDecoder.withJwkSetUri(s"$host/token/public-key-jwks")
refreshPeriod.foreach(rp => decoderBuilder.cache(
new ConcurrentMapCache("jwkSetCache", CacheBuilder.newBuilder()
.expireAfterWrite(rp, TimeUnit.SECONDS)
.build().asMap(), false)))

decoderBuilder.build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
package za.co.absa.loginclient.publicKeyRetrieval.client

import com.google.gson.JsonParser
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.jwk.JWKSet
import org.slf4j.{Logger, LoggerFactory}
import org.springframework.web.client.RestTemplate
import za.co.absa.loginclient.publicKeyRetrieval.model.PublicKey

import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`

/**
* This class is used to retrieve the public key from the issuer.
* public key is available in either a string or a JWK format.
Expand All @@ -33,7 +35,7 @@ case class PublicKeyRetrievalClient(host: String) {
private val logger: Logger = LoggerFactory.getLogger(this.getClass)

/**
* Retrieves the public key from the login service as a PublicKey object.
* Retrieves the current public key from the login service as a PublicKey object.
* This method fetches the public key used for JWT verification and returns it as a PublicKey object.
* Key is available as a string within the object
*
Expand All @@ -47,19 +49,37 @@ case class PublicKeyRetrievalClient(host: String) {
}

/**
* Retrieves the public key from the login service in JWK (JSON Web Key) format.
* This method fetches the public key used for JWT verification and returns it in JWK format.
* Retrieves all available public keys from the login service as a set of PublicKey objects.
* This method fetches the public keys used for JWT verification and returns it as a set of PublicKey objects.
* Keys are available as a string within the objects.
*
* @return A set of PublicKey objects representing the public keys retrieved from the login service.
*/
def getPublicKeys: Set[PublicKey] = {
val issuerUri = s"$host/token/public-keys"
val jsonString = fetchToken(issuerUri)
val publicKeyList = JsonParser.parseString(jsonString).getAsJsonObject.getAsJsonArray("keys").asList()
val publicKeyStrings = publicKeyList.map {jsonElement =>
val obj = jsonElement.getAsJsonObject
obj.entrySet().head.getValue.getAsString
}.toSet

publicKeyStrings.map { publicKeyString => PublicKey(publicKeyString) }
}

/**
* Retrieves the public key from the login service as a JWK (JSON Web Key) set.
* This method fetches the public key(s) used for JWT verification and returns it as a JWK Set format.
*
* @return A String containing the public key in JWK (JSON Web Key) format retrieved from the login service.
* @return A String containing the public key(s) in JWK (JSON Web Key) format retrieved from the login service.
*/
def getPublicKeyJwk: JWK = {
def getPublicKeyJwk: JWKSet = {
val issuerUri = s"$host/token/public-key-jwks"
val jsonString = fetchToken(issuerUri)
val jwkString = JsonParser.parseString(jsonString).getAsJsonObject.get("key").getAsString
JWK.parse(jwkString)
JWKSet.parse(jsonString)
}

private def fetchToken(issuerUri: String): String = {
private[client] def fetchToken(issuerUri: String): String = {

logger.info(s"Fetching token from $issuerUri")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ case class TokenRetrievalClient(host: String) {
System.setProperties(properties)
}

private def fetchToken(issuerUri: String, username: String, password: String): String = {
private[client] def fetchToken(issuerUri: String, username: String, password: String): String = {

logger.info(s"Fetching token from $issuerUri for user $username")

Expand Down Expand Up @@ -218,7 +218,7 @@ case class TokenRetrievalClient(host: String) {
}
}

private def fetchToken(issuerUri: String, keyTabLocation: Option[String], userPrincipal: Option[String]): String = {
private[client] def fetchToken(issuerUri: String, keyTabLocation: Option[String], userPrincipal: Option[String]): String = {

val restTemplate: KerberosRestTemplate = (keyTabLocation, userPrincipal) match {
case (Some(_), Some(_)) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package za.co.absa.loginclient.authorization

import io.jsonwebtoken.SignatureAlgorithm
import io.jsonwebtoken.security.Keys
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.springframework.security.oauth2.jwt.JwtDecoder

import java.security.interfaces.RSAPublicKey
import java.util.Base64

class JwtDecoderProviderTest extends AnyFlatSpec with Matchers{

"getDecoderFromPublicKeyString" should "correctly generate a JwtDecoder from a valid RSA public key string" in {
val keyPair = Keys.keyPairFor(SignatureAlgorithm.RS256)

val publicKey: RSAPublicKey = keyPair.getPublic.asInstanceOf[RSAPublicKey]
val publicKeyString = Base64.getEncoder.encodeToString(publicKey.getEncoded)

val decoder: JwtDecoder = JwtDecoderProvider.getDecoderFromPublicKeyString(publicKeyString)

decoder should not be null
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package za.co.absa.loginclient.publicKeyRetrieval.client

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import za.co.absa.loginclient.publicKeyRetrieval.model.PublicKey

class PublicKeyRetrievalClientTest extends AnyFlatSpec with Matchers{

private val dummyUri = "https://example.com"

class testPublicKeyRetrievalClient extends PublicKeyRetrievalClient(dummyUri) {
private val publicKeyUri = s"$dummyUri/token/public-key"
private val publicKeyJson = """{"key": "mocked-public-key"}"""
private val publicKeysUri = s"$dummyUri/token/public-keys"
private val publicKeysJson =
"""{
| "keys": [
| { "key1": "public-key-1" },
| { "key2": "public-key-2" }
| ]
|}""".stripMargin
private val publicKeyJwkUri = s"$dummyUri/token/public-key-jwks"
private val publicKeyJwkJson =
"""{
| "keys": [
| {
| "kty": "RSA",
| "kid": "test-key-id",
| "n": "test-modulus",
| "e": "test-exponent"
| }
| ]
|}""".stripMargin

override private[client] def fetchToken(issuerUri: String): String = {
Map(
publicKeyUri -> publicKeyJson,
publicKeysUri -> publicKeysJson,
publicKeyJwkUri -> publicKeyJwkJson
).getOrElse(issuerUri, throw new IllegalArgumentException(s"Unexpected URI: $issuerUri"))
}
}

"getPublicKey" should "return the expected PublicKey object" in {
val testClient = new testPublicKeyRetrievalClient

val result = testClient.getPublicKey

result shouldBe PublicKey("mocked-public-key")
}

"getPublicKeys" should "return the expected set of PublicKey objects" in {
val testClient = new testPublicKeyRetrievalClient

val result = testClient.getPublicKeys

result shouldBe Set(
PublicKey("public-key-1"),
PublicKey("public-key-2")
)
}

"getPublicKeyJwk" should "return the expected PublicKey JWKSet object" in {
val testClient = new testPublicKeyRetrievalClient

val result = testClient.getPublicKeyJwk
result.getKeyByKeyId("test-key-id").toString shouldBe
"{\"kty\":\"RSA\",\"e\":\"test-exponent\",\"kid\":\"test-key-id\",\"n\":\"test-modulus\"}"
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package za.co.absa.loginclient.tokenRetrieval.client

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import za.co.absa.loginclient.tokenRetrieval.model.{AccessToken, RefreshToken}

class TokenRetrievalClientTest extends AnyFlatSpec with Matchers{

private val dummyUri = "https://example.com"
private val dummyUser = "exampleUser"
private val dummyPassword = "examplePassword"
private val dummyGroups = List()

class testTokenRetrievalClient extends TokenRetrievalClient(dummyUri) {
override private[client] def fetchToken(issuerUri: String, username: String, password: String) =
"""{
| "token": "mock-access-token",
| "refresh": "mock-refresh-token"
|}""".stripMargin
}

"fetchAccessAndRefreshToken" should "return expected tokens" in {

val testClient = new testTokenRetrievalClient

val (accessResult, refreshResult) = testClient.fetchAccessAndRefreshToken(dummyUser, dummyPassword, dummyGroups)
accessResult shouldBe AccessToken("mock-access-token")
refreshResult shouldBe RefreshToken("mock-refresh-token")
}
}
7 changes: 6 additions & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ object Dependencies {
lazy val jacksonModuleScala = "com.fasterxml.jackson.module" %% "jackson-module-scala" % Versions.jacksonModuleScala
lazy val jacksonDatabind = "com.fasterxml.jackson.core" % "jackson-databind" % Versions.jacksonDatabind
lazy val javaCompat = "org.scala-lang.modules" %% "scala-java8-compat" % Versions.javaCompat
lazy val javaCollectionConverter = "org.scala-lang.modules" %% "scala-collection-compat" % "2.11.0"

lazy val springBootWeb = "org.springframework.boot" % "spring-boot-starter-web" % Versions.springBoot
lazy val springBootTomcat = "org.springframework.boot" % "spring-boot-starter-tomcat" % Versions.springBoot % Provided
Expand All @@ -52,6 +53,7 @@ object Dependencies {
lazy val jjwtJackson = "io.jsonwebtoken" % "jjwt-jackson" % Versions.jjwt % Runtime

lazy val jsonParser = "com.google.code.gson" % "gson" % "2.10.1"
lazy val cacheBuilder = "com.google.guava" % "guava" % "33.0.0-jre"

lazy val jwtDecoder = "org.springframework.security" % "spring-security-oauth2-jose" % Versions.spring
lazy val nimbusJoseJwt = "com.nimbusds" % "nimbus-jose-jwt" % Versions.nimbusJoseJwt
Expand Down Expand Up @@ -116,6 +118,7 @@ object Dependencies {

def clientLibDependencies: Seq[ModuleID] = Seq(
javaCompat,
javaCollectionConverter,

nimbusJoseJwt,
jwtDecoder,
Expand All @@ -125,13 +128,15 @@ object Dependencies {
jjwtJackson,

jsonParser,
cacheBuilder,

springBootWeb,
springBootSecurity,

springSecurityKerberosClient,

scalaTest
scalaTest,
springBootTest
)

def exampleDependencies: Seq[ModuleID] = Seq(
Expand Down
Loading