diff --git a/api/src/main/java/io/grpc/TlsChannelCredentials.java b/api/src/main/java/io/grpc/TlsChannelCredentials.java index b58048be0b2..83dc36ff325 100644 --- a/api/src/main/java/io/grpc/TlsChannelCredentials.java +++ b/api/src/main/java/io/grpc/TlsChannelCredentials.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.EnumSet; import java.util.List; +import java.util.Map; import java.util.Set; import javax.net.ssl.KeyManager; import javax.net.ssl.TrustManager; @@ -49,6 +50,7 @@ public static ChannelCredentials create() { private final List keyManagers; private final byte[] rootCertificates; private final List trustManagers; + private final Map customCertificatesConfig; TlsChannelCredentials(Builder builder) { fakeFeature = builder.fakeFeature; @@ -58,6 +60,7 @@ public static ChannelCredentials create() { keyManagers = builder.keyManagers; rootCertificates = builder.rootCertificates; trustManagers = builder.trustManagers; + customCertificatesConfig = builder.customCertificatesConfig; } /** @@ -121,6 +124,21 @@ public List getTrustManagers() { return trustManagers; } + /** + * Returns custom certificates config. It contains following entries: + * + *
    + *
  • {@code "ca_certificate_file"} key containing the path to the root certificate file
  • + *
  • {@code "certificate_file"} key containing the path to the identity certificate file
  • + *
  • {@code "private_key_file"} key containing the path to the private key
  • + *
  • {@code "refresh_interval"} key specifying the frequency of updates to the above + * files
  • + *
+ */ + public Map getCustomCertificatesConfig() { + return customCertificatesConfig; + } + /** * Returns an empty set if this credential can be adequately understood via * the features listed, otherwise returns a hint of features that are lacking @@ -228,6 +246,7 @@ public static final class Builder { private List keyManagers; private byte[] rootCertificates; private List trustManagers; + private Map customCertificatesConfig; private Builder() {} @@ -355,6 +374,11 @@ public Builder trustManager(TrustManager... trustManagers) { return this; } + public Builder customCertificatesConfig(Map customCertificatesConfig) { + this.customCertificatesConfig = customCertificatesConfig; + return this; + } + private void clearTrustManagers() { this.rootCertificates = null; this.trustManagers = null; diff --git a/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java b/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java index c00685f1781..e7c4527a3a2 100644 --- a/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java +++ b/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableMap; import io.grpc.Internal; import io.grpc.InternalLogId; +import io.grpc.TlsChannelCredentials; import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil.GrpcBuildVersion; import io.grpc.internal.JsonParser; @@ -162,9 +163,9 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map rawData) builder.node(nodeBuilder.build()); Map certProvidersBlob = JsonUtil.getObject(rawData, "certificate_providers"); + Map certProviders = new HashMap<>(); if (certProvidersBlob != null) { logger.log(XdsLogLevel.INFO, "Configured with {0} cert providers", certProvidersBlob.size()); - Map certProviders = new HashMap<>(certProvidersBlob.size()); for (String name : certProvidersBlob.keySet()) { Map valueMap = JsonUtil.getObject(certProvidersBlob, name); String pluginName = @@ -175,6 +176,21 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map rawData) CertificateProviderInfo.create(pluginName, config); certProviders.put(name, certificateProviderInfo); } + } + + for (ServerInfo serverInfo : servers) { + Object creds = serverInfo.implSpecificConfig(); + if (creds instanceof TlsChannelCredentials) { + Map config = ((TlsChannelCredentials)creds).getCustomCertificatesConfig(); + if (config != null) { + CertificateProviderInfo certificateProviderInfo = + CertificateProviderInfo.create("file_watcher", config); + certProviders.put("mtls_channel_creds_identity_certs", certificateProviderInfo); + } + } + } + + if (!certProviders.isEmpty()) { builder.certProviders(certProviders); } diff --git a/xds/src/main/java/io/grpc/xds/internal/TlsXdsCredentialsProvider.java b/xds/src/main/java/io/grpc/xds/internal/TlsXdsCredentialsProvider.java index f4d26a83795..797495670e5 100644 --- a/xds/src/main/java/io/grpc/xds/internal/TlsXdsCredentialsProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/TlsXdsCredentialsProvider.java @@ -18,7 +18,10 @@ import io.grpc.ChannelCredentials; import io.grpc.TlsChannelCredentials; +import io.grpc.internal.JsonUtil; import io.grpc.xds.XdsCredentialsProvider; +import java.io.File; +import java.io.IOException; import java.util.Map; /** @@ -30,7 +33,42 @@ public final class TlsXdsCredentialsProvider extends XdsCredentialsProvider { @Override protected ChannelCredentials newChannelCredentials(Map jsonConfig) { - return TlsChannelCredentials.create(); + TlsChannelCredentials.Builder builder = TlsChannelCredentials.newBuilder(); + + if (jsonConfig == null) { + return builder.build(); + } + + // use trust certificate file path from bootstrap config if provided; else use system default + String rootCertPath = JsonUtil.getString(jsonConfig, "ca_certificate_file"); + if (rootCertPath != null) { + try { + builder.trustManager(new File(rootCertPath)); + } catch (IOException e) { + return null; + } + } + + // use certificate chain and private key file paths from bootstrap config if provided. Mind that + // both JSON values must be either set (mTLS case) or both unset (TLS case) + String certChainPath = JsonUtil.getString(jsonConfig, "certificate_file"); + String privateKeyPath = JsonUtil.getString(jsonConfig, "private_key_file"); + if (certChainPath != null && privateKeyPath != null) { + try { + builder.keyManager(new File(certChainPath), new File(privateKeyPath)); + } catch (IOException e) { + return null; + } + } else if (certChainPath != null || privateKeyPath != null) { + return null; + } + + // save json config when custom certificate paths were provided in a bootstrap + if (rootCertPath != null || certChainPath != null) { + builder.customCertificatesConfig(jsonConfig); + } + + return builder.build(); } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java index 304124cc7f2..747c76e260a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java @@ -73,10 +73,12 @@ final class FileWatcherCertificateProvider extends CertificateProvider implement this.scheduledExecutorService = checkNotNull(scheduledExecutorService, "scheduledExecutorService"); this.timeProvider = checkNotNull(timeProvider, "timeProvider"); - this.certFile = Paths.get(checkNotNull(certFile, "certFile")); - this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile")); - checkArgument((trustFile != null || spiffeTrustMapFile != null), - "either trustFile or spiffeTrustMapFile must be present"); + checkArgument(((certFile != null) == (keyFile != null)), + "certFile and keyFile must be both set or both unset"); + this.certFile = certFile == null ? null : Paths.get(certFile); + this.keyFile = keyFile == null ? null : Paths.get(keyFile); + checkArgument((trustFile != null || spiffeTrustMapFile != null || keyFile != null), + "must be watching either root or identity certificates"); if (spiffeTrustMapFile != null) { this.spiffeTrustMapFile = Paths.get(spiffeTrustMapFile); this.trustFile = null; @@ -113,23 +115,26 @@ private synchronized void scheduleNextRefreshCertificate(long delayInSeconds) { void checkAndReloadCertificates() { try { try { - FileTime currentCertTime = Files.getLastModifiedTime(certFile); - FileTime currentKeyTime = Files.getLastModifiedTime(keyFile); - if (!currentCertTime.equals(lastModifiedTimeCert) - && !currentKeyTime.equals(lastModifiedTimeKey)) { - byte[] certFileContents = Files.readAllBytes(certFile); - byte[] keyFileContents = Files.readAllBytes(keyFile); - FileTime currentCertTime2 = Files.getLastModifiedTime(certFile); - FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile); - if (currentCertTime2.equals(currentCertTime) && currentKeyTime2.equals(currentKeyTime)) { - try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents); - ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) { - PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream); - X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream); - getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); + if (certFile != null && keyFile != null) { + FileTime currentCertTime = Files.getLastModifiedTime(certFile); + FileTime currentKeyTime = Files.getLastModifiedTime(keyFile); + if (!currentCertTime.equals(lastModifiedTimeCert) + && !currentKeyTime.equals(lastModifiedTimeKey)) { + byte[] certFileContents = Files.readAllBytes(certFile); + byte[] keyFileContents = Files.readAllBytes(keyFile); + FileTime currentCertTime2 = Files.getLastModifiedTime(certFile); + FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile); + if (currentCertTime2.equals(currentCertTime) + && currentKeyTime2.equals(currentKeyTime)) { + try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents); + ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) { + PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream); + X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream); + getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); + } + lastModifiedTimeCert = currentCertTime; + lastModifiedTimeKey = currentKeyTime; } - lastModifiedTimeCert = currentCertTime; - lastModifiedTimeKey = currentKeyTime; } } } catch (Throwable t) { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java index e4871dc4c84..b191aac023b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java @@ -17,7 +17,6 @@ package io.grpc.xds.internal.security.certprovider; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -94,30 +93,27 @@ public CertificateProvider createCertificateProvider( timeProvider); } - private static String checkForNullAndGet(Map map, String key) { - return checkNotNull(JsonUtil.getString(map, key), "'" + key + "' is required in the config"); - } - private static Config validateAndTranslateConfig(Object config) { checkArgument(config instanceof Map, "Only Map supported for config"); @SuppressWarnings("unchecked") Map map = (Map)config; Config configObj = new Config(); - configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY); - configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY); - if (enableSpiffe) { - if (!map.containsKey(ROOT_FILE_KEY) && !map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) { - throw new NullPointerException( - String.format("either '%s' or '%s' is required in the config", - ROOT_FILE_KEY, SPIFFE_TRUST_MAP_FILE_KEY)); - } - if (map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) { - configObj.spiffeTrustMapFile = JsonUtil.getString(map, SPIFFE_TRUST_MAP_FILE_KEY); - } else { - configObj.rootFile = JsonUtil.getString(map, ROOT_FILE_KEY); - } + configObj.certFile = JsonUtil.getString(map, CERT_FILE_KEY); + configObj.keyFile = JsonUtil.getString(map, KEY_FILE_KEY); + if ((configObj.certFile != null) != (configObj.keyFile != null)) { + throw new NullPointerException( + String.format("'%s' and '%s' must be both set or both unset", + CERT_FILE_KEY, KEY_FILE_KEY)); + } + if (!map.containsKey(ROOT_FILE_KEY) + && !map.containsKey(CERT_FILE_KEY) + && (!enableSpiffe || !map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY))) { + throw new NullPointerException("must be watching either root or identity certificates"); + } + if (enableSpiffe && map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) { + configObj.spiffeTrustMapFile = JsonUtil.getString(map, SPIFFE_TRUST_MAP_FILE_KEY); } else { - configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY); + configObj.rootFile = JsonUtil.getString(map, ROOT_FILE_KEY); } String refreshIntervalString = JsonUtil.getString(map, REFRESH_INTERVAL_KEY); if (refreshIntervalString != null) { diff --git a/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java b/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java index 3f93cc6f191..ff781b4eddd 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java @@ -17,6 +17,9 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; @@ -28,9 +31,11 @@ import io.grpc.TlsChannelCredentials; import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil.GrpcBuildVersion; +import io.grpc.internal.testing.TestUtils; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.Bootstrapper.AuthorityInfo; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.BootstrapperImpl; import io.grpc.xds.client.CommonBootstrapperTestUtils; @@ -898,6 +903,54 @@ public void badFederationConfig() { } } + @Test + public void parseTlsChannelCredentialsWithCustomCertificatesConfig() + throws XdsInitializationException, IOException { + String rootCertPath = TestUtils.loadCert(CA_PEM_FILE).getAbsolutePath(); + String certChainPath = TestUtils.loadCert(CLIENT_PEM_FILE).getAbsolutePath(); + String privateKeyPath = TestUtils.loadCert(CLIENT_KEY_FILE).getAbsolutePath(); + + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\n" + + " \"type\": \"tls\"," + + " \"config\": {\n" + + " \"ca_certificate_file\": \"" + rootCertPath + "\",\n" + + " \"certificate_file\": \"" + certChainPath + "\",\n" + + " \"private_key_file\": \"" + privateKeyPath + "\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + assertThat(info.servers()).hasSize(1); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + assertThat(serverInfo.target()).isEqualTo(SERVER_URI); + assertThat(serverInfo.implSpecificConfig()).isInstanceOf(TlsChannelCredentials.class); + assertThat(info.node()).isEqualTo(getNodeBuilder().build()); + + assertThat(info.certProviders()).hasSize(1); + ImmutableMap certProviderInfo = info.certProviders(); + assertThat(certProviderInfo.keySet()).containsExactly("mtls_channel_creds_identity_certs"); + CertificateProviderInfo mtlsChannelCredCertProviderInfo = + certProviderInfo.get("mtls_channel_creds_identity_certs"); + assertThat(mtlsChannelCredCertProviderInfo.config().keySet()) + .containsExactly("ca_certificate_file", "certificate_file", "private_key_file"); + assertThat(mtlsChannelCredCertProviderInfo.config().get("ca_certificate_file")) + .isEqualTo(rootCertPath); + assertThat(mtlsChannelCredCertProviderInfo.config().get("certificate_file")) + .isEqualTo(certChainPath); + assertThat(mtlsChannelCredCertProviderInfo.config().get("private_key_file")) + .isEqualTo(privateKeyPath); + } + private static BootstrapperImpl.FileReader createFileReader( final String expectedPath, final String rawData) { return new BootstrapperImpl.FileReader() { diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 23068d665bf..dd3ba9c8a5f 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -51,8 +51,10 @@ import io.grpc.Status; import io.grpc.StatusOr; import io.grpc.StatusRuntimeException; +import io.grpc.TlsServerCredentials; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.TlsTesting; import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; @@ -513,6 +515,36 @@ public void mtlsClientServer_changeServerContext_expectException() } } + @Test + public void mtlsClientServer_withClientAuthentication_withTlsChannelCredsFromBootstrap() + throws Exception { + final String mtlsCertProviderInstanceName = "mtls_channel_creds_identity_certs"; + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoWithMTlsChannelCredsAndBuildUpstreamTlsContext( + mtlsCertProviderInstanceName, CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoWithMTlsChannelCredsAndBuildDownstreamTlsContext( + mtlsCertProviderInstanceName, SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); + + ServerCredentials serverCreds = TlsServerCredentials.newBuilder() + .keyManager(TlsTesting.loadCert(SERVER_1_PEM_FILE), TlsTesting.loadCert(SERVER_1_KEY_FILE)) + .trustManager(TlsTesting.loadCert(CA_PEM_FILE)) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + + buildServer( + XdsServerBuilder.forPort(0, serverCreds) + .xdsClientPoolFactory(fakePoolFactory) + .addService(new SimpleServiceImpl()), + downstreamTlsContext); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, OVERRIDE_AUTHORITY); + assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); + } + private void performMtlsTestAndGetListenerWatcher( UpstreamTlsContext upstreamTlsContext, String certInstanceName2, String privateKey2, String cert2, String trustCa2) @@ -573,6 +605,22 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSys .build()); } + private UpstreamTlsContext setBootstrapInfoWithMTlsChannelCredsAndBuildUpstreamTlsContext( + String instanceName, String clientKeyFile, String clientPemFile, String caCertFile) { + bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfoForMTlsChannelCredentialServerInfo( + instanceName, clientKeyFile, clientPemFile, caCertFile); + return CommonTlsContextTestsUtil.buildUpstreamTlsContext(instanceName, true); + } + + private DownstreamTlsContext setBootstrapInfoWithMTlsChannelCredsAndBuildDownstreamTlsContext( + String instanceName, String serverKeyFile, String serverPemFile, String caCertFile) { + bootstrapInfoForServer = CommonBootstrapperTestUtils + .buildBootstrapInfoForMTlsChannelCredentialServerInfo( + instanceName, serverKeyFile, serverPemFile, caCertFile); + return CommonTlsContextTestsUtil.buildDownstreamTlsContext(instanceName, true, true); + } + private void buildServerWithTlsContext(DownstreamTlsContext downstreamTlsContext) throws Exception { buildServerWithTlsContext(downstreamTlsContext, InsecureServerCredentials.create()); diff --git a/xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java b/xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java index 485970741c1..3eb783eed3f 100644 --- a/xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java +++ b/xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java @@ -20,12 +20,14 @@ import com.google.common.collect.ImmutableMap; import io.grpc.ChannelCredentials; import io.grpc.InsecureChannelCredentials; +import io.grpc.TlsChannelCredentials; import io.grpc.internal.BackoffPolicy; import io.grpc.internal.FakeClock; import io.grpc.internal.JsonParser; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import io.grpc.xds.internal.security.TlsContextManagerImpl; +import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -160,6 +162,42 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( .build(); } + public static Bootstrapper.BootstrapInfo buildBootstrapInfoForMTlsChannelCredentialServerInfo( + String instanceName, String privateKey, String cert, String trustCa) { + try { + privateKey = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(privateKey); + cert = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(cert); + trustCa = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(trustCa); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + + HashMap config = new HashMap<>(); + config.put("certificate_file", cert); + config.put("private_key_file", privateKey); + config.put("ca_certificate_file", trustCa); + + ChannelCredentials creds; + try { + creds = TlsChannelCredentials.newBuilder() + .customCertificatesConfig(config) + .keyManager(new File(cert), new File(privateKey)) + .trustManager(new File(trustCa)) + .build(); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + + // config for tls channel credentials and for certificate provider are the same + return Bootstrapper.BootstrapInfo.builder() + .servers(ImmutableList.of(ServerInfo.create(SERVER_URI, creds))) + .node(EnvoyProtoData.Node.newBuilder().build()) + .certProviders(ImmutableMap.of( + instanceName, + Bootstrapper.CertificateProviderInfo.create("file_watcher", config))) + .build(); + } + public static boolean setEnableXdsFallback(boolean target) { boolean oldValue = BootstrapperImpl.enableXdsFallback; BootstrapperImpl.enableXdsFallback = target; diff --git a/xds/src/test/java/io/grpc/xds/internal/TlsXdsCredentialsProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/TlsXdsCredentialsProviderTest.java index 3ba26bdb281..ef4322d71e2 100644 --- a/xds/src/test/java/io/grpc/xds/internal/TlsXdsCredentialsProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/TlsXdsCredentialsProviderTest.java @@ -16,14 +16,21 @@ package io.grpc.xds.internal; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.collect.ImmutableMap; +import io.grpc.ChannelCredentials; import io.grpc.InternalServiceProviders; import io.grpc.TlsChannelCredentials; import io.grpc.xds.XdsCredentialsProvider; +import java.io.File; +import java.util.Map; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -32,6 +39,9 @@ public class TlsXdsCredentialsProviderTest { private TlsXdsCredentialsProvider provider = new TlsXdsCredentialsProvider(); + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + @Test public void provided() { for (XdsCredentialsProvider current @@ -50,8 +60,50 @@ public void isAvailable() { } @Test - public void channelCredentials() { + public void channelCredentialsWhenNullConfig() { assertSame(TlsChannelCredentials.class, provider.newChannelCredentials(null).getClass()); } + + @Test + public void channelCredentialsWhenNotExistingTrustFileConfig() { + Map jsonConfig = ImmutableMap.of( + "ca_certificate_file", "/tmp/not-exisiting-file.txt"); + assertNull(provider.newChannelCredentials(jsonConfig)); + } + + @Test + public void channelCredentialsWhenNotExistingCertificateFileConfig() { + Map jsonConfig = ImmutableMap.of( + "certificate_file", "/tmp/not-exisiting-file.txt", + "private_key_file", "/tmp/not-exisiting-file-2.txt"); + assertNull(provider.newChannelCredentials(jsonConfig)); + } + + @Test + public void channelCredentialsWhenInvalidConfig() throws Exception { + File certFile = tempFolder.newFile(new String("identity.cert")); + Map jsonConfig = ImmutableMap.of("certificate_file", certFile.toString()); + assertNull(provider.newChannelCredentials(jsonConfig)); + } + + @Test + public void channelCredentialsWhenValidConfig() throws Exception { + File trustFile = tempFolder.newFile(new String("root.cert")); + File certFile = tempFolder.newFile(new String("identity.cert")); + File keyFile = tempFolder.newFile(new String("private.key")); + + Map jsonConfig = ImmutableMap.of( + "ca_certificate_file", trustFile.toString(), + "certificate_file", certFile.toString(), + "private_key_file", keyFile.toString()); + + ChannelCredentials creds = provider.newChannelCredentials(jsonConfig); + assertSame(TlsChannelCredentials.class, creds.getClass()); + assertSame(((TlsChannelCredentials) creds).getCustomCertificatesConfig(), jsonConfig); + + trustFile.delete(); + certFile.delete(); + keyFile.delete(); + } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java index 304a2dd5441..4b987b69a54 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java @@ -206,7 +206,9 @@ public void createProvider_missingCert_expectException() throws IOException { provider.createCertificateProvider(map, distWatcher, true); fail("exception expected"); } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'certificate_file' is required in the config"); + assertThat(npe) + .hasMessageThat() + .isEqualTo("'certificate_file' and 'private_key_file' must be both set or both unset"); } } @@ -220,24 +222,69 @@ public void createProvider_missingKey_expectException() throws IOException { provider.createCertificateProvider(map, distWatcher, true); fail("exception expected"); } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'private_key_file' is required in the config"); + assertThat(npe) + .hasMessageThat() + .isEqualTo("'certificate_file' and 'private_key_file' must be both set or both unset"); } } @Test - public void createProvider_missingRoot_expectException() throws IOException { - String expectedMessage = enableSpiffe ? "either 'ca_certificate_file' or " - + "'spiffe_trust_bundle_map_file' is required in the config" - : "'ca_certificate_file' is required in the config"; + public void createProvider_missingRootAndSpiffeConfig() throws IOException { CertificateProvider.DistributorWatcher distWatcher = new CertificateProvider.DistributorWatcher(); @SuppressWarnings("unchecked") Map map = (Map) JsonParser.parse(MISSING_ROOT_AND_SPIFFE_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, true); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq("/var/run/gke-spiffe/certs/certificates.pem"), + eq("/var/run/gke-spiffe/certs/private_key.pem"), + eq(null), + eq(null), + eq(600L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_missingCertAndKeyConfig() throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(MISSING_CERT_AND_KEY_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, true); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq(null), + eq(null), + eq("/var/run/gke-spiffe/certs/ca_certificates.pem"), + eq(null), + eq(600L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_emptyConfig_expectException() throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(EMPTY_CONFIG); try { provider.createCertificateProvider(map, distWatcher, true); fail("exception expected"); } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo(expectedMessage); + assertThat(npe) + .hasMessageThat() + .isEqualTo("must be watching either root or identity certificates"); } } @@ -292,6 +339,13 @@ public void createProvider_missingRoot_expectException() throws IOException { + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"" + " }"; + private static final String MISSING_CERT_AND_KEY_CONFIG = + "{\n" + + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + + " }"; + + private static final String EMPTY_CONFIG = "{\n}"; + private static final String ZERO_REFRESH_INTERVAL = "{\n" + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\"," diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java index 620ee0a7ff7..d566bee2f28 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java @@ -25,6 +25,8 @@ import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SPIFFE_TRUST_MAP_1_FILE; import static java.nio.file.StandardCopyOption.REPLACE_EXISTING; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doReturn; @@ -357,6 +359,33 @@ public void getCertificate_missingRootFile() throws IOException, InterruptedExce verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem"); } + @Test + public void illegalConstructorArguments_MissingPrivateKeyWhileCertChainPresent() + throws IllegalArgumentException { + Exception ex = assertThrows( + IllegalArgumentException.class, + () -> new FileWatcherCertificateProvider( + watcher, true, null, keyFile, rootFile, null, 600L, timeService, timeProvider)); + + String expectedMsg = "certFile and keyFile must be both set or both unset"; + String actualMsg = ex.getMessage(); + + assertEquals(expectedMsg, actualMsg); + } + + @Test + public void illegalConstructorArguments_NoFilesToWatch() throws IllegalArgumentException { + Exception ex = assertThrows( + IllegalArgumentException.class, + () -> new FileWatcherCertificateProvider( + watcher, true, null, null, null, null, 600L, timeService, timeProvider)); + + String expectedMsg = "must be watching either root or identity certificates"; + String actualMsg = ex.getMessage(); + + assertEquals(expectedMsg, actualMsg); + } + private void commonErrorTest( String certFile, String keyFile,