Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/main/java/com/uid2/core/service/JWTTokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.WebIdentityTokenFileCredentialsProvider;
Expand All @@ -32,13 +33,12 @@
public class JWTTokenProvider {
private static final Logger LOGGER = LoggerFactory.getLogger(JWTTokenProvider.class);
private static final Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();

private final Supplier<KmsClientBuilder> kmsClientBuilderSupplier;
private final JsonObject config;
private final KmsClientBuilder kmsClientBuilder;

public JWTTokenProvider(JsonObject config, KmsClientBuilder clientBuilder) {
public JWTTokenProvider(JsonObject config, Supplier<KmsClientBuilder> kmsClientBuilderSupplier) {
this.config = config;
this.kmsClientBuilder = clientBuilder;
this.kmsClientBuilderSupplier = kmsClientBuilderSupplier;
}

public String getJWT(Instant expiresAt, Instant issuedAt, Map<String, String> customClaims) throws JwtSigningException {
Expand All @@ -64,7 +64,7 @@ public String getJWT(Instant expiresAt, Instant issuedAt, Map<String, String> he

KmsClient client = null;
try {
client = getKmsClient(this.kmsClientBuilder, this.config);
client = getKmsClient(this.kmsClientBuilderSupplier.get(), this.config);
} catch (URISyntaxException e) {
throw new JwtSigningException(Optional.of("Unable to get KMS Client"), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class OperatorJWTTokenProvider {
private final Clock clock;

public OperatorJWTTokenProvider(JsonObject config) {
this(config, new JWTTokenProvider(config, KmsClient.builder()), Clock.systemUTC());
this(config, new JWTTokenProvider(config, KmsClient::builder), Clock.systemUTC());
}

public OperatorJWTTokenProvider(JsonObject config, JWTTokenProvider jwtTokenProvider, Clock clock) {
Expand Down
12 changes: 6 additions & 6 deletions src/test/java/com/uid2/core/service/JWTTokenProviderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void getJwtReturnsValidToken() throws JWTTokenProvider.JwtSigningException {
content.put("iss", "issuer");

var builder = getBuilder(true, "TestSignature");
JWTTokenProvider provider = new JWTTokenProvider(config, builder);
JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);

Instant i = Clock.systemUTC().instant();

Expand All @@ -84,7 +84,7 @@ void getJwtReturnsValidToken() throws JWTTokenProvider.JwtSigningException {
void getJwtEmptySignatureThrowsException() {
var builder = getBuilder(false, "");

JWTTokenProvider provider = new JWTTokenProvider(config, builder);
JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);

JWTTokenProvider.JwtSigningException e = assertThrows(
JWTTokenProvider.JwtSigningException.class,
Expand All @@ -97,7 +97,7 @@ void getJwtEmptySignatureThrowsException() {
void getJwtEmptySignatureEmptyResponseText() {
var builder = getBuilder(false, "", Optional.empty());

JWTTokenProvider provider = new JWTTokenProvider(config, builder);
JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);

JWTTokenProvider.JwtSigningException e = assertThrows(
JWTTokenProvider.JwtSigningException.class,
Expand All @@ -110,7 +110,7 @@ void getJwtEmptySignatureEmptyResponseText() {
void getJwtEmptySignatureNullResponseText() {
var builder = getBuilder(false, "", null);

JWTTokenProvider provider = new JWTTokenProvider(config, builder);
JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);

JWTTokenProvider.JwtSigningException e = assertThrows(
JWTTokenProvider.JwtSigningException.class,
Expand All @@ -123,7 +123,7 @@ void getJwtEmptySignatureNullResponseText() {
void getJwtSignatureThrowsKmsException() {
var builder = getBuilder(false, "", Optional.empty());

JWTTokenProvider provider = new JWTTokenProvider(config, builder);
JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);
var ex = KmsException.builder().message("Test Error").build();
when(mockClient.sign(capturedSignRequest.capture())).thenThrow(ex);

Expand All @@ -144,7 +144,7 @@ void getJwtMissingKeyInConfig() throws IOException {

var builder = getBuilder(false, "", Optional.empty());

JWTTokenProvider provider = new JWTTokenProvider(config, builder);
JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);

JWTTokenProvider.JwtSigningException e = assertThrows(
JWTTokenProvider.JwtSigningException.class,
Expand Down