diff --git a/clientLibrary/src/main/scala/za/co/absa/loginclient/authorization/JwtDecoderProvider.scala b/clientLibrary/src/main/scala/za/co/absa/loginclient/authorization/JwtDecoderProvider.scala index 9a389922..994aed97 100644 --- a/clientLibrary/src/main/scala/za/co/absa/loginclient/authorization/JwtDecoderProvider.scala +++ b/clientLibrary/src/main/scala/za/co/absa/loginclient/authorization/JwtDecoderProvider.scala @@ -16,6 +16,8 @@ package za.co.absa.loginclient.authorization +import com.google.common.cache.CacheBuilder +import org.springframework.cache.concurrent.ConcurrentMapCache import org.springframework.security.oauth2.jwt.{JwtDecoder, NimbusJwtDecoder} import za.co.absa.loginclient.publicKeyRetrieval.model.PublicKey @@ -23,6 +25,7 @@ import java.security.KeyFactory import java.security.interfaces.RSAPublicKey import java.security.spec.X509EncodedKeySpec import java.util.Base64 +import java.util.concurrent.TimeUnit object JwtDecoderProvider { @@ -59,9 +62,16 @@ object JwtDecoderProvider { * Currently implemented by NimbusJwtDecoder.withJwkSetUri(JWKS_PATH). * * @param host The URL from which the public key will be fetched. + * @param refreshPeriod Optional value in seconds for caching keys from host. Default cache expiry is 5 minutes. * @return A JwtDecoder instance initialized with the public key fetched from the URL. */ - def getDecoderFromURL(host: String): JwtDecoder = { - NimbusJwtDecoder.withJwkSetUri(s"$host/token/public-key-jwks").build() + def getDecoderFromURL(host: String, refreshPeriod: Option[Int] = None): JwtDecoder = { + val decoderBuilder = NimbusJwtDecoder.withJwkSetUri(s"$host/token/public-key-jwks") + refreshPeriod.foreach(rp => decoderBuilder.cache( + new ConcurrentMapCache("jwkSetCache", CacheBuilder.newBuilder() + .expireAfterWrite(rp, TimeUnit.SECONDS) + .build().asMap(), false))) + + decoderBuilder.build() } } diff --git a/clientLibrary/src/main/scala/za/co/absa/loginclient/publicKeyRetrieval/client/PublicKeyRetrievalClient.scala b/clientLibrary/src/main/scala/za/co/absa/loginclient/publicKeyRetrieval/client/PublicKeyRetrievalClient.scala index e8e89436..0aa83f4b 100644 --- a/clientLibrary/src/main/scala/za/co/absa/loginclient/publicKeyRetrieval/client/PublicKeyRetrievalClient.scala +++ b/clientLibrary/src/main/scala/za/co/absa/loginclient/publicKeyRetrieval/client/PublicKeyRetrievalClient.scala @@ -17,11 +17,13 @@ package za.co.absa.loginclient.publicKeyRetrieval.client import com.google.gson.JsonParser -import com.nimbusds.jose.jwk.JWK +import com.nimbusds.jose.jwk.JWKSet import org.slf4j.{Logger, LoggerFactory} import org.springframework.web.client.RestTemplate import za.co.absa.loginclient.publicKeyRetrieval.model.PublicKey +import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` + /** * This class is used to retrieve the public key from the issuer. * public key is available in either a string or a JWK format. @@ -33,7 +35,7 @@ case class PublicKeyRetrievalClient(host: String) { private val logger: Logger = LoggerFactory.getLogger(this.getClass) /** - * Retrieves the public key from the login service as a PublicKey object. + * Retrieves the current public key from the login service as a PublicKey object. * This method fetches the public key used for JWT verification and returns it as a PublicKey object. * Key is available as a string within the object * @@ -47,19 +49,37 @@ case class PublicKeyRetrievalClient(host: String) { } /** - * Retrieves the public key from the login service in JWK (JSON Web Key) format. - * This method fetches the public key used for JWT verification and returns it in JWK format. + * Retrieves all available public keys from the login service as a set of PublicKey objects. + * This method fetches the public keys used for JWT verification and returns it as a set of PublicKey objects. + * Keys are available as a string within the objects. + * + * @return A set of PublicKey objects representing the public keys retrieved from the login service. + */ + def getPublicKeys: Set[PublicKey] = { + val issuerUri = s"$host/token/public-keys" + val jsonString = fetchToken(issuerUri) + val publicKeyList = JsonParser.parseString(jsonString).getAsJsonObject.getAsJsonArray("keys").asList() + val publicKeyStrings = publicKeyList.map {jsonElement => + val obj = jsonElement.getAsJsonObject + obj.entrySet().head.getValue.getAsString + }.toSet + + publicKeyStrings.map { publicKeyString => PublicKey(publicKeyString) } + } + + /** + * Retrieves the public key from the login service as a JWK (JSON Web Key) set. + * This method fetches the public key(s) used for JWT verification and returns it as a JWK Set format. * - * @return A String containing the public key in JWK (JSON Web Key) format retrieved from the login service. + * @return A String containing the public key(s) in JWK (JSON Web Key) format retrieved from the login service. */ - def getPublicKeyJwk: JWK = { + def getPublicKeyJwk: JWKSet = { val issuerUri = s"$host/token/public-key-jwks" val jsonString = fetchToken(issuerUri) - val jwkString = JsonParser.parseString(jsonString).getAsJsonObject.get("key").getAsString - JWK.parse(jwkString) + JWKSet.parse(jsonString) } - private def fetchToken(issuerUri: String): String = { + private[client] def fetchToken(issuerUri: String): String = { logger.info(s"Fetching token from $issuerUri") diff --git a/clientLibrary/src/main/scala/za/co/absa/loginclient/tokenRetrieval/client/TokenRetrievalClient.scala b/clientLibrary/src/main/scala/za/co/absa/loginclient/tokenRetrieval/client/TokenRetrievalClient.scala index d3f7c706..a64a9720 100644 --- a/clientLibrary/src/main/scala/za/co/absa/loginclient/tokenRetrieval/client/TokenRetrievalClient.scala +++ b/clientLibrary/src/main/scala/za/co/absa/loginclient/tokenRetrieval/client/TokenRetrievalClient.scala @@ -189,7 +189,7 @@ case class TokenRetrievalClient(host: String) { System.setProperties(properties) } - private def fetchToken(issuerUri: String, username: String, password: String): String = { + private[client] def fetchToken(issuerUri: String, username: String, password: String): String = { logger.info(s"Fetching token from $issuerUri for user $username") @@ -218,7 +218,7 @@ case class TokenRetrievalClient(host: String) { } } - private def fetchToken(issuerUri: String, keyTabLocation: Option[String], userPrincipal: Option[String]): String = { + private[client] def fetchToken(issuerUri: String, keyTabLocation: Option[String], userPrincipal: Option[String]): String = { val restTemplate: KerberosRestTemplate = (keyTabLocation, userPrincipal) match { case (Some(_), Some(_)) => diff --git a/clientLibrary/src/test/scala/za/co/absa/loginclient/authorization/JwtDecoderProviderTest.scala b/clientLibrary/src/test/scala/za/co/absa/loginclient/authorization/JwtDecoderProviderTest.scala new file mode 100644 index 00000000..a285f7e2 --- /dev/null +++ b/clientLibrary/src/test/scala/za/co/absa/loginclient/authorization/JwtDecoderProviderTest.scala @@ -0,0 +1,25 @@ +package za.co.absa.loginclient.authorization + +import io.jsonwebtoken.SignatureAlgorithm +import io.jsonwebtoken.security.Keys +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.springframework.security.oauth2.jwt.JwtDecoder + +import java.security.interfaces.RSAPublicKey +import java.util.Base64 + +class JwtDecoderProviderTest extends AnyFlatSpec with Matchers{ + + "getDecoderFromPublicKeyString" should "correctly generate a JwtDecoder from a valid RSA public key string" in { + val keyPair = Keys.keyPairFor(SignatureAlgorithm.RS256) + + val publicKey: RSAPublicKey = keyPair.getPublic.asInstanceOf[RSAPublicKey] + val publicKeyString = Base64.getEncoder.encodeToString(publicKey.getEncoded) + + val decoder: JwtDecoder = JwtDecoderProvider.getDecoderFromPublicKeyString(publicKeyString) + + decoder should not be null + } + +} diff --git a/clientLibrary/src/test/scala/za/co/absa/loginclient/publicKeyRetrieval/client/PublicKeyRetrievalClientTest.scala b/clientLibrary/src/test/scala/za/co/absa/loginclient/publicKeyRetrieval/client/PublicKeyRetrievalClientTest.scala new file mode 100644 index 00000000..c8274292 --- /dev/null +++ b/clientLibrary/src/test/scala/za/co/absa/loginclient/publicKeyRetrieval/client/PublicKeyRetrievalClientTest.scala @@ -0,0 +1,71 @@ +package za.co.absa.loginclient.publicKeyRetrieval.client + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import za.co.absa.loginclient.publicKeyRetrieval.model.PublicKey + +class PublicKeyRetrievalClientTest extends AnyFlatSpec with Matchers{ + + private val dummyUri = "https://example.com" + + class testPublicKeyRetrievalClient extends PublicKeyRetrievalClient(dummyUri) { + private val publicKeyUri = s"$dummyUri/token/public-key" + private val publicKeyJson = """{"key": "mocked-public-key"}""" + private val publicKeysUri = s"$dummyUri/token/public-keys" + private val publicKeysJson = + """{ + | "keys": [ + | { "key1": "public-key-1" }, + | { "key2": "public-key-2" } + | ] + |}""".stripMargin + private val publicKeyJwkUri = s"$dummyUri/token/public-key-jwks" + private val publicKeyJwkJson = + """{ + | "keys": [ + | { + | "kty": "RSA", + | "kid": "test-key-id", + | "n": "test-modulus", + | "e": "test-exponent" + | } + | ] + |}""".stripMargin + + override private[client] def fetchToken(issuerUri: String): String = { + Map( + publicKeyUri -> publicKeyJson, + publicKeysUri -> publicKeysJson, + publicKeyJwkUri -> publicKeyJwkJson + ).getOrElse(issuerUri, throw new IllegalArgumentException(s"Unexpected URI: $issuerUri")) + } + } + + "getPublicKey" should "return the expected PublicKey object" in { + val testClient = new testPublicKeyRetrievalClient + + val result = testClient.getPublicKey + + result shouldBe PublicKey("mocked-public-key") + } + + "getPublicKeys" should "return the expected set of PublicKey objects" in { + val testClient = new testPublicKeyRetrievalClient + + val result = testClient.getPublicKeys + + result shouldBe Set( + PublicKey("public-key-1"), + PublicKey("public-key-2") + ) + } + + "getPublicKeyJwk" should "return the expected PublicKey JWKSet object" in { + val testClient = new testPublicKeyRetrievalClient + + val result = testClient.getPublicKeyJwk + result.getKeyByKeyId("test-key-id").toString shouldBe + "{\"kty\":\"RSA\",\"e\":\"test-exponent\",\"kid\":\"test-key-id\",\"n\":\"test-modulus\"}" + } + +} diff --git a/clientLibrary/src/test/scala/za/co/absa/loginclient/tokenRetrieval/client/TokenRetrievalClientTest.scala b/clientLibrary/src/test/scala/za/co/absa/loginclient/tokenRetrieval/client/TokenRetrievalClientTest.scala new file mode 100644 index 00000000..1e2c253b --- /dev/null +++ b/clientLibrary/src/test/scala/za/co/absa/loginclient/tokenRetrieval/client/TokenRetrievalClientTest.scala @@ -0,0 +1,30 @@ +package za.co.absa.loginclient.tokenRetrieval.client + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import za.co.absa.loginclient.tokenRetrieval.model.{AccessToken, RefreshToken} + +class TokenRetrievalClientTest extends AnyFlatSpec with Matchers{ + + private val dummyUri = "https://example.com" + private val dummyUser = "exampleUser" + private val dummyPassword = "examplePassword" + private val dummyGroups = List() + + class testTokenRetrievalClient extends TokenRetrievalClient(dummyUri) { + override private[client] def fetchToken(issuerUri: String, username: String, password: String) = + """{ + | "token": "mock-access-token", + | "refresh": "mock-refresh-token" + |}""".stripMargin + } + + "fetchAccessAndRefreshToken" should "return expected tokens" in { + + val testClient = new testTokenRetrievalClient + + val (accessResult, refreshResult) = testClient.fetchAccessAndRefreshToken(dummyUser, dummyPassword, dummyGroups) + accessResult shouldBe AccessToken("mock-access-token") + refreshResult shouldBe RefreshToken("mock-refresh-token") + } +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 91da1df9..4a74ce3a 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -38,6 +38,7 @@ object Dependencies { lazy val jacksonModuleScala = "com.fasterxml.jackson.module" %% "jackson-module-scala" % Versions.jacksonModuleScala lazy val jacksonDatabind = "com.fasterxml.jackson.core" % "jackson-databind" % Versions.jacksonDatabind lazy val javaCompat = "org.scala-lang.modules" %% "scala-java8-compat" % Versions.javaCompat + lazy val javaCollectionConverter = "org.scala-lang.modules" %% "scala-collection-compat" % "2.11.0" lazy val springBootWeb = "org.springframework.boot" % "spring-boot-starter-web" % Versions.springBoot lazy val springBootTomcat = "org.springframework.boot" % "spring-boot-starter-tomcat" % Versions.springBoot % Provided @@ -52,6 +53,7 @@ object Dependencies { lazy val jjwtJackson = "io.jsonwebtoken" % "jjwt-jackson" % Versions.jjwt % Runtime lazy val jsonParser = "com.google.code.gson" % "gson" % "2.10.1" + lazy val cacheBuilder = "com.google.guava" % "guava" % "33.0.0-jre" lazy val jwtDecoder = "org.springframework.security" % "spring-security-oauth2-jose" % Versions.spring lazy val nimbusJoseJwt = "com.nimbusds" % "nimbus-jose-jwt" % Versions.nimbusJoseJwt @@ -116,6 +118,7 @@ object Dependencies { def clientLibDependencies: Seq[ModuleID] = Seq( javaCompat, + javaCollectionConverter, nimbusJoseJwt, jwtDecoder, @@ -125,13 +128,15 @@ object Dependencies { jjwtJackson, jsonParser, + cacheBuilder, springBootWeb, springBootSecurity, springSecurityKerberosClient, - scalaTest + scalaTest, + springBootTest ) def exampleDependencies: Seq[ModuleID] = Seq(