diff --git a/src/main/java/com/uid2/core/service/JWTTokenProvider.java b/src/main/java/com/uid2/core/service/JWTTokenProvider.java index 81b607e..c8d05ea 100644 --- a/src/main/java/com/uid2/core/service/JWTTokenProvider.java +++ b/src/main/java/com/uid2/core/service/JWTTokenProvider.java @@ -13,6 +13,7 @@ import java.util.Base64; import java.util.Map; import java.util.Optional; +import java.util.function.Supplier; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.WebIdentityTokenFileCredentialsProvider; @@ -32,13 +33,12 @@ public class JWTTokenProvider { private static final Logger LOGGER = LoggerFactory.getLogger(JWTTokenProvider.class); private static final Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding(); - + private final Supplier kmsClientBuilderSupplier; private final JsonObject config; - private final KmsClientBuilder kmsClientBuilder; - public JWTTokenProvider(JsonObject config, KmsClientBuilder clientBuilder) { + public JWTTokenProvider(JsonObject config, Supplier kmsClientBuilderSupplier) { this.config = config; - this.kmsClientBuilder = clientBuilder; + this.kmsClientBuilderSupplier = kmsClientBuilderSupplier; } public String getJWT(Instant expiresAt, Instant issuedAt, Map customClaims) throws JwtSigningException { @@ -64,7 +64,7 @@ public String getJWT(Instant expiresAt, Instant issuedAt, Map he KmsClient client = null; try { - client = getKmsClient(this.kmsClientBuilder, this.config); + client = getKmsClient(this.kmsClientBuilderSupplier.get(), this.config); } catch (URISyntaxException e) { throw new JwtSigningException(Optional.of("Unable to get KMS Client"), e); } diff --git a/src/main/java/com/uid2/core/service/OperatorJWTTokenProvider.java b/src/main/java/com/uid2/core/service/OperatorJWTTokenProvider.java index 4c09e41..39f4d39 100644 --- a/src/main/java/com/uid2/core/service/OperatorJWTTokenProvider.java +++ b/src/main/java/com/uid2/core/service/OperatorJWTTokenProvider.java @@ -24,7 +24,7 @@ public class OperatorJWTTokenProvider { private final Clock clock; public OperatorJWTTokenProvider(JsonObject config) { - this(config, new JWTTokenProvider(config, KmsClient.builder()), Clock.systemUTC()); + this(config, new JWTTokenProvider(config, KmsClient::builder), Clock.systemUTC()); } public OperatorJWTTokenProvider(JsonObject config, JWTTokenProvider jwtTokenProvider, Clock clock) { diff --git a/src/test/java/com/uid2/core/service/JWTTokenProviderTest.java b/src/test/java/com/uid2/core/service/JWTTokenProviderTest.java index 0b84b37..8a00aeb 100644 --- a/src/test/java/com/uid2/core/service/JWTTokenProviderTest.java +++ b/src/test/java/com/uid2/core/service/JWTTokenProviderTest.java @@ -57,7 +57,7 @@ void getJwtReturnsValidToken() throws JWTTokenProvider.JwtSigningException { content.put("iss", "issuer"); var builder = getBuilder(true, "TestSignature"); - JWTTokenProvider provider = new JWTTokenProvider(config, builder); + JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder); Instant i = Clock.systemUTC().instant(); @@ -84,7 +84,7 @@ void getJwtReturnsValidToken() throws JWTTokenProvider.JwtSigningException { void getJwtEmptySignatureThrowsException() { var builder = getBuilder(false, ""); - JWTTokenProvider provider = new JWTTokenProvider(config, builder); + JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder); JWTTokenProvider.JwtSigningException e = assertThrows( JWTTokenProvider.JwtSigningException.class, @@ -97,7 +97,7 @@ void getJwtEmptySignatureThrowsException() { void getJwtEmptySignatureEmptyResponseText() { var builder = getBuilder(false, "", Optional.empty()); - JWTTokenProvider provider = new JWTTokenProvider(config, builder); + JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder); JWTTokenProvider.JwtSigningException e = assertThrows( JWTTokenProvider.JwtSigningException.class, @@ -110,7 +110,7 @@ void getJwtEmptySignatureEmptyResponseText() { void getJwtEmptySignatureNullResponseText() { var builder = getBuilder(false, "", null); - JWTTokenProvider provider = new JWTTokenProvider(config, builder); + JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder); JWTTokenProvider.JwtSigningException e = assertThrows( JWTTokenProvider.JwtSigningException.class, @@ -123,7 +123,7 @@ void getJwtEmptySignatureNullResponseText() { void getJwtSignatureThrowsKmsException() { var builder = getBuilder(false, "", Optional.empty()); - JWTTokenProvider provider = new JWTTokenProvider(config, builder); + JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder); var ex = KmsException.builder().message("Test Error").build(); when(mockClient.sign(capturedSignRequest.capture())).thenThrow(ex); @@ -144,7 +144,7 @@ void getJwtMissingKeyInConfig() throws IOException { var builder = getBuilder(false, "", Optional.empty()); - JWTTokenProvider provider = new JWTTokenProvider(config, builder); + JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder); JWTTokenProvider.JwtSigningException e = assertThrows( JWTTokenProvider.JwtSigningException.class,