Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

package io.athenz.server.aws.common.store.impl;

import com.yahoo.athenz.auth.util.Crypto;
import io.athenz.server.aws.common.utils.Utils;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.s3.model.*;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -31,6 +34,10 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.TrustManagerFactory;
import java.net.URI;
import java.security.KeyStore;
import java.security.cert.X509Certificate;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
Expand All @@ -53,6 +60,8 @@ public class S3ChangeLogStore implements ChangeLogStore {

private static final String NUMBER_OF_THREADS = "athenz.zts.bucket.threads";
private static final String DEFAULT_TIMEOUT_SECONDS = "athenz.zts.bucket.threads.timeout";
private static final String ZTS_PROP_AWS_S3_ENDPOINT = "athenz.zts.aws_s3_endpoint";
private static final String ZTS_PROP_AWS_S3_CA_CERT = "athenz.zts.aws_s3_ca_cert";
private final int nThreads = Integer.parseInt(System.getProperty(NUMBER_OF_THREADS, "10"));
private final int defaultTimeoutSeconds = Integer.parseInt(System.getProperty(DEFAULT_TIMEOUT_SECONDS, "1800"));
protected Map<String, SignedDomain> tempSignedDomainMap = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -497,7 +506,38 @@ S3Client getS3Client() {
throw new RuntimeException("S3ChangeLogStore: Couldn't detect AWS region");
}

return S3Client.builder().region(Region.of(awsRegion)).build();
S3ClientBuilder s3ClientBuilder = S3Client.builder().region(Region.of(awsRegion));

// check if we have a custom endpoint
String s3Endpoint = System.getProperty(ZTS_PROP_AWS_S3_ENDPOINT);
if (!StringUtil.isEmpty(s3Endpoint)) {
s3ClientBuilder.endpointOverride(URI.create(s3Endpoint));
}

// check if we have a custom ca cert
String s3CaCert = System.getProperty(ZTS_PROP_AWS_S3_CA_CERT);
if (!StringUtil.isEmpty(s3CaCert)) {
try {
ApacheHttpClient.Builder httpClientBuilder = ApacheHttpClient.builder();
X509Certificate[] certs = Crypto.loadX509Certificates(s3CaCert);
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null); // Initialize empty keystore
int i = 0;
for (X509Certificate cert : certs) {
keyStore.setCertificateEntry("custom-ca-" + i++, cert);
}
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(keyStore);

httpClientBuilder.tlsTrustManagersProvider(tmf::getTrustManagers);
s3ClientBuilder.httpClient(httpClientBuilder.build());
} catch (Exception ex) {
LOGGER.error("S3ChangeLogStore: unable to load custom ca cert: {}", s3CaCert, ex);
throw new RuntimeException("S3ChangeLogStore: unable to load custom ca cert");
}
}
Comment on lines +519 to +538
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code for handling custom CA certificates is nearly identical to the one in S3ClientFactory.java. To improve maintainability and reduce code duplication, consider extracting this logic into a shared utility method. This method could take the caCertPath and return a configured TlsTrustManagersProvider or directly configure an ApacheHttpClient.Builder.


return s3ClientBuilder.build();
}

public ExecutorService getExecutorService() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,26 @@

package io.athenz.syncer.aws.common.impl;

import com.yahoo.athenz.auth.util.Crypto;
import io.athenz.syncer.common.zms.Config;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.core.checksums.RequestChecksumCalculation;
import software.amazon.awssdk.core.checksums.ResponseChecksumValidation;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import software.amazon.awssdk.services.s3.model.HeadBucketRequest;

import javax.net.ssl.TrustManagerFactory;
import java.net.URI;
import java.security.KeyStore;
import java.security.cert.X509Certificate;
import java.time.Duration;

public class S3ClientFactory {
Expand Down Expand Up @@ -71,31 +79,56 @@ public static S3Client getS3Client() throws Exception {
}
}

SdkHttpClient apacheHttpClient = ApacheHttpClient.builder()
ApacheHttpClient.Builder httpClientBuilder = ApacheHttpClient.builder()
.connectionTimeout(Duration.ofMillis(connectionTimeout))
.socketTimeout(Duration.ofMillis(requestTimeout))
.build();
.socketTimeout(Duration.ofMillis(requestTimeout));

final String caCertPath = Config.getInstance().getConfigParam(Config.SYNC_CFG_PARAM_AWS_S3_CA_CERT);
if (!Config.isEmpty(caCertPath)) {
X509Certificate[] certs = Crypto.loadX509Certificates(caCertPath);
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null); // Initialize empty keystore
int i = 0;
for (X509Certificate cert : certs) {
keyStore.setCertificateEntry("custom-ca-" + i++, cert);
}
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(keyStore);

httpClientBuilder.tlsTrustManagersProvider(tmf::getTrustManagers);
}
Comment on lines +87 to +99
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for handling custom CA certificates is duplicated in S3ChangeLogStore.java. To avoid code duplication and improve maintainability, this logic should be extracted into a common utility method.


SdkHttpClient apacheHttpClient = httpClientBuilder.build();

S3ClientBuilder s3ClientBuilder = S3Client.builder()
.httpClient(apacheHttpClient)
.region(getRegion());

// Enable checksum calculation and validation if configured
final String checksumValidation = Config.getInstance().getConfigParam(Config.SYNC_CFG_PARAM_AWS_S3_CHECKSUM_VALIDATION);
if (!Config.isEmpty(checksumValidation) && Boolean.parseBoolean(checksumValidation)) {
s3ClientBuilder
.requestChecksumCalculation(RequestChecksumCalculation.WHEN_REQUIRED)
.responseChecksumValidation(ResponseChecksumValidation.WHEN_REQUIRED);
LOGGER.debug("S3 checksum calculation and validation enabled");
}

final String awsS3Endpoint = Config.getInstance().getConfigParam(Config.SYNC_CFG_PARAM_AWS_S3_ENDPOINT);
if (!Config.isEmpty(awsS3Endpoint)) {
s3ClientBuilder.endpointOverride(URI.create(awsS3Endpoint));
}

S3Client s3client;
final String awsKeyId = Config.getInstance().getConfigParam(Config.SYNC_CFG_PARAM_AWS_KEY_ID);
final String awsAccKey = Config.getInstance().getConfigParam(Config.SYNC_CFG_PARAM_AWS_ACCESS_KEY);
if (!Config.isEmpty(awsKeyId) && !Config.isEmpty(awsAccKey)) {
AwsBasicCredentials awsCreds = AwsBasicCredentials.builder()
.accessKeyId(awsKeyId).secretAccessKey(awsAccKey).build();
StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(awsCreds);

s3client = S3Client.builder()
.credentialsProvider(credentialsProvider)
.httpClient(apacheHttpClient)
.region(getRegion())
.build();
} else {
s3client = S3Client.builder()
.httpClient(apacheHttpClient)
.region(getRegion())
.build();
s3ClientBuilder.credentialsProvider(credentialsProvider);
}

S3Client s3client = s3ClientBuilder.build();

verifyBucketExist(s3client, bucket);

LOGGER.debug("success: using bucket: {}", bucket);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,27 @@
import static org.testng.Assert.*;

import java.io.*;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.*;
import java.util.concurrent.TimeUnit;

import com.yahoo.athenz.auth.util.Crypto;
import com.yahoo.athenz.zms.JWSDomain;
import com.yahoo.rdl.Timestamp;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.TlsTrustManagersProvider;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.s3.model.*;
import com.yahoo.athenz.zms.DomainData;
import com.yahoo.athenz.zms.SignedDomain;
Expand Down Expand Up @@ -641,6 +651,56 @@ public void testGetS3Client() {
assertNotNull(s3Client);
}

@Test
public void testGetS3ClientWithCustomEndpointAndCaCert() throws Exception {
System.setProperty(ZTS_PROP_AWS_BUCKET_NAME, "test-bucket");
System.setProperty(ZTS_PROP_AWS_REGION_NAME, "us-west-2");
System.setProperty("athenz.zts.aws_s3_endpoint", "https://custom.s3.endpoint");
System.setProperty("athenz.zts.aws_s3_ca_cert", "src/test/resources/dummy_ca.pem");

// Mocks
try (MockedStatic<ApacheHttpClient> mockHttpClientStatic = Mockito.mockStatic(ApacheHttpClient.class);
MockedStatic<S3Client> mockS3ClientStatic = Mockito.mockStatic(S3Client.class);
MockedStatic<Crypto> mockCryptoStatic = Mockito.mockStatic(Crypto.class)) {

// Mock Crypto
mockCryptoStatic.when(() -> Crypto.loadX509Certificates(any(String.class))).thenReturn(new X509Certificate[]{mock(X509Certificate.class)});

// Mock ApacheHttpClient builder
ApacheHttpClient.Builder mockHttpBuilder = mock(ApacheHttpClient.Builder.class);
SdkHttpClient mockHttpClient = mock(SdkHttpClient.class);

mockHttpClientStatic.when(ApacheHttpClient::builder).thenReturn(mockHttpBuilder);
when(mockHttpBuilder.tlsTrustManagersProvider(any(TlsTrustManagersProvider.class))).thenReturn(mockHttpBuilder);
when(mockHttpBuilder.build()).thenReturn(mockHttpClient);

// Mock S3Client builder
S3ClientBuilder mockS3Builder = mock(S3ClientBuilder.class);
S3Client mockS3Client = mock(S3Client.class);

mockS3ClientStatic.when(S3Client::builder).thenReturn(mockS3Builder);
when(mockS3Builder.region(any(Region.class))).thenReturn(mockS3Builder);
when(mockS3Builder.endpointOverride(any(URI.class))).thenReturn(mockS3Builder);
when(mockS3Builder.httpClient(any(SdkHttpClient.class))).thenReturn(mockS3Builder);
when(mockS3Builder.build()).thenReturn(mockS3Client);

S3ChangeLogStore store = new S3ChangeLogStore();
S3Client client = store.getS3Client();
assertNotNull(client);

// Verify ApacheHttpClient configured with TrustManager
Mockito.verify(mockHttpBuilder).tlsTrustManagersProvider(any(TlsTrustManagersProvider.class));

// Verify S3Client configured with Endpoint Override
ArgumentCaptor<URI> uriCaptor = ArgumentCaptor.forClass(URI.class);
Mockito.verify(mockS3Builder).endpointOverride(uriCaptor.capture());
assertEquals(uriCaptor.getValue().toString(), "https://custom.s3.endpoint");
} finally {
System.clearProperty("athenz.zts.aws_s3_endpoint");
System.clearProperty("athenz.zts.aws_s3_ca_cert");
}
}

@Test
public void initNoRegionException() {
System.clearProperty(ZTS_PROP_AWS_REGION_NAME);
Expand Down
Loading
Loading