Skip to content

[xDS] A65 mTLS credentials in bootstrap #12255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions api/src/main/java/io/grpc/TlsChannelCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.net.ssl.KeyManager;
import javax.net.ssl.TrustManager;
Expand All @@ -49,6 +50,7 @@ public static ChannelCredentials create() {
private final List<KeyManager> keyManagers;
private final byte[] rootCertificates;
private final List<TrustManager> trustManagers;
private final Map<String, ?> customCertificatesConfig;

TlsChannelCredentials(Builder builder) {
fakeFeature = builder.fakeFeature;
Expand All @@ -58,6 +60,7 @@ public static ChannelCredentials create() {
keyManagers = builder.keyManagers;
rootCertificates = builder.rootCertificates;
trustManagers = builder.trustManagers;
customCertificatesConfig = builder.customCertificatesConfig;
}

/**
Expand Down Expand Up @@ -121,6 +124,21 @@ public List<TrustManager> getTrustManagers() {
return trustManagers;
}

/**
* Returns custom certificates config. It contains following entries:
*
* <ul>
* <li>{@code "ca_certificate_file"} key containing the path to the root certificate file</li>
* <li>{@code "certificate_file"} key containing the path to the identity certificate file</li>
* <li>{@code "private_key_file"} key containing the path to the private key</li>
* <li>{@code "refresh_interval"} key specifying the frequency of updates to the above
* files</li>
* </ul>
*/
public Map<String, ?> getCustomCertificatesConfig() {
return customCertificatesConfig;
}

/**
* Returns an empty set if this credential can be adequately understood via
* the features listed, otherwise returns a hint of features that are lacking
Expand Down Expand Up @@ -228,6 +246,7 @@ public static final class Builder {
private List<KeyManager> keyManagers;
private byte[] rootCertificates;
private List<TrustManager> trustManagers;
private Map<String, ?> customCertificatesConfig;

private Builder() {}

Expand Down Expand Up @@ -355,6 +374,11 @@ public Builder trustManager(TrustManager... trustManagers) {
return this;
}

public Builder customCertificatesConfig(Map<String, ?> customCertificatesConfig) {
this.customCertificatesConfig = customCertificatesConfig;
return this;
}

private void clearTrustManagers() {
this.rootCertificates = null;
this.trustManagers = null;
Expand Down
18 changes: 17 additions & 1 deletion xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.common.collect.ImmutableMap;
import io.grpc.Internal;
import io.grpc.InternalLogId;
import io.grpc.TlsChannelCredentials;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.GrpcUtil.GrpcBuildVersion;
import io.grpc.internal.JsonParser;
Expand Down Expand Up @@ -162,9 +163,9 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map<String, ?> rawData)
builder.node(nodeBuilder.build());

Map<String, ?> certProvidersBlob = JsonUtil.getObject(rawData, "certificate_providers");
Map<String, CertificateProviderInfo> certProviders = new HashMap<>();
if (certProvidersBlob != null) {
logger.log(XdsLogLevel.INFO, "Configured with {0} cert providers", certProvidersBlob.size());
Map<String, CertificateProviderInfo> certProviders = new HashMap<>(certProvidersBlob.size());
for (String name : certProvidersBlob.keySet()) {
Map<String, ?> valueMap = JsonUtil.getObject(certProvidersBlob, name);
String pluginName =
Expand All @@ -175,6 +176,21 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map<String, ?> rawData)
CertificateProviderInfo.create(pluginName, config);
certProviders.put(name, certificateProviderInfo);
}
}

for (ServerInfo serverInfo : servers) {
Object creds = serverInfo.implSpecificConfig();
if (creds instanceof TlsChannelCredentials) {
Map<String, ?> config = ((TlsChannelCredentials)creds).getCustomCertificatesConfig();
if (config != null) {
CertificateProviderInfo certificateProviderInfo =
CertificateProviderInfo.create("file_watcher", config);
certProviders.put("mtls_channel_creds_identity_certs", certificateProviderInfo);
}
}
}

if (!certProviders.isEmpty()) {
builder.certProviders(certProviders);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

import io.grpc.ChannelCredentials;
import io.grpc.TlsChannelCredentials;
import io.grpc.internal.JsonUtil;
import io.grpc.xds.XdsCredentialsProvider;
import java.io.File;
import java.io.IOException;
import java.util.Map;

/**
Expand All @@ -30,7 +33,42 @@ public final class TlsXdsCredentialsProvider extends XdsCredentialsProvider {

@Override
protected ChannelCredentials newChannelCredentials(Map<String, ?> jsonConfig) {
return TlsChannelCredentials.create();
TlsChannelCredentials.Builder builder = TlsChannelCredentials.newBuilder();

if (jsonConfig == null) {
return builder.build();
}

// use trust certificate file path from bootstrap config if provided; else use system default
String rootCertPath = JsonUtil.getString(jsonConfig, "ca_certificate_file");
if (rootCertPath != null) {
try {
builder.trustManager(new File(rootCertPath));
} catch (IOException e) {
return null;
}
}

// use certificate chain and private key file paths from bootstrap config if provided. Mind that
// both JSON values must be either set (mTLS case) or both unset (TLS case)
String certChainPath = JsonUtil.getString(jsonConfig, "certificate_file");
String privateKeyPath = JsonUtil.getString(jsonConfig, "private_key_file");
if (certChainPath != null && privateKeyPath != null) {
try {
builder.keyManager(new File(certChainPath), new File(privateKeyPath));
} catch (IOException e) {
return null;
}
} else if (certChainPath != null || privateKeyPath != null) {
return null;
}

// save json config when custom certificate paths were provided in a bootstrap
if (rootCertPath != null || certChainPath != null) {
builder.customCertificatesConfig(jsonConfig);
}

return builder.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ final class FileWatcherCertificateProvider extends CertificateProvider implement
this.scheduledExecutorService =
checkNotNull(scheduledExecutorService, "scheduledExecutorService");
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
this.certFile = Paths.get(checkNotNull(certFile, "certFile"));
this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile"));
checkArgument((trustFile != null || spiffeTrustMapFile != null),
"either trustFile or spiffeTrustMapFile must be present");
checkArgument(((certFile != null) == (keyFile != null)),
"certFile and keyFile must be both set or both unset");
this.certFile = certFile == null ? null : Paths.get(certFile);
this.keyFile = keyFile == null ? null : Paths.get(keyFile);
checkArgument((trustFile != null || spiffeTrustMapFile != null || keyFile != null),
"must be watching either root or identity certificates");
if (spiffeTrustMapFile != null) {
this.spiffeTrustMapFile = Paths.get(spiffeTrustMapFile);
this.trustFile = null;
Expand Down Expand Up @@ -113,23 +115,26 @@ private synchronized void scheduleNextRefreshCertificate(long delayInSeconds) {
void checkAndReloadCertificates() {
try {
try {
FileTime currentCertTime = Files.getLastModifiedTime(certFile);
FileTime currentKeyTime = Files.getLastModifiedTime(keyFile);
if (!currentCertTime.equals(lastModifiedTimeCert)
&& !currentKeyTime.equals(lastModifiedTimeKey)) {
byte[] certFileContents = Files.readAllBytes(certFile);
byte[] keyFileContents = Files.readAllBytes(keyFile);
FileTime currentCertTime2 = Files.getLastModifiedTime(certFile);
FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile);
if (currentCertTime2.equals(currentCertTime) && currentKeyTime2.equals(currentKeyTime)) {
try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents);
ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) {
PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream);
X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream);
getWatcher().updateCertificate(privateKey, Arrays.asList(certs));
if (certFile != null && keyFile != null) {
FileTime currentCertTime = Files.getLastModifiedTime(certFile);
FileTime currentKeyTime = Files.getLastModifiedTime(keyFile);
if (!currentCertTime.equals(lastModifiedTimeCert)
&& !currentKeyTime.equals(lastModifiedTimeKey)) {
byte[] certFileContents = Files.readAllBytes(certFile);
byte[] keyFileContents = Files.readAllBytes(keyFile);
FileTime currentCertTime2 = Files.getLastModifiedTime(certFile);
FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile);
if (currentCertTime2.equals(currentCertTime)
&& currentKeyTime2.equals(currentKeyTime)) {
try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents);
ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) {
PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream);
X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream);
getWatcher().updateCertificate(privateKey, Arrays.asList(certs));
}
lastModifiedTimeCert = currentCertTime;
lastModifiedTimeKey = currentKeyTime;
}
lastModifiedTimeCert = currentCertTime;
lastModifiedTimeKey = currentKeyTime;
}
}
} catch (Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package io.grpc.xds.internal.security.certprovider;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
Expand Down Expand Up @@ -94,30 +93,27 @@ public CertificateProvider createCertificateProvider(
timeProvider);
}

private static String checkForNullAndGet(Map<String, ?> map, String key) {
return checkNotNull(JsonUtil.getString(map, key), "'" + key + "' is required in the config");
}

private static Config validateAndTranslateConfig(Object config) {
checkArgument(config instanceof Map, "Only Map supported for config");
@SuppressWarnings("unchecked") Map<String, ?> map = (Map<String, ?>)config;

Config configObj = new Config();
configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY);
configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY);
if (enableSpiffe) {
if (!map.containsKey(ROOT_FILE_KEY) && !map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) {
throw new NullPointerException(
String.format("either '%s' or '%s' is required in the config",
ROOT_FILE_KEY, SPIFFE_TRUST_MAP_FILE_KEY));
}
if (map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) {
configObj.spiffeTrustMapFile = JsonUtil.getString(map, SPIFFE_TRUST_MAP_FILE_KEY);
} else {
configObj.rootFile = JsonUtil.getString(map, ROOT_FILE_KEY);
}
configObj.certFile = JsonUtil.getString(map, CERT_FILE_KEY);
configObj.keyFile = JsonUtil.getString(map, KEY_FILE_KEY);
if ((configObj.certFile != null) != (configObj.keyFile != null)) {
throw new NullPointerException(
String.format("'%s' and '%s' must be both set or both unset",
CERT_FILE_KEY, KEY_FILE_KEY));
}
if (!map.containsKey(ROOT_FILE_KEY)
&& !map.containsKey(CERT_FILE_KEY)
&& (!enableSpiffe || !map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY))) {
throw new NullPointerException("must be watching either root or identity certificates");
}
if (enableSpiffe && map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) {
configObj.spiffeTrustMapFile = JsonUtil.getString(map, SPIFFE_TRUST_MAP_FILE_KEY);
} else {
configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY);
configObj.rootFile = JsonUtil.getString(map, ROOT_FILE_KEY);
}
String refreshIntervalString = JsonUtil.getString(map, REFRESH_INTERVAL_KEY);
if (refreshIntervalString != null) {
Expand Down
53 changes: 53 additions & 0 deletions xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package io.grpc.xds;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
Expand All @@ -28,9 +31,11 @@
import io.grpc.TlsChannelCredentials;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.GrpcUtil.GrpcBuildVersion;
import io.grpc.internal.testing.TestUtils;
import io.grpc.xds.client.Bootstrapper;
import io.grpc.xds.client.Bootstrapper.AuthorityInfo;
import io.grpc.xds.client.Bootstrapper.BootstrapInfo;
import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo;
import io.grpc.xds.client.Bootstrapper.ServerInfo;
import io.grpc.xds.client.BootstrapperImpl;
import io.grpc.xds.client.CommonBootstrapperTestUtils;
Expand Down Expand Up @@ -898,6 +903,54 @@ public void badFederationConfig() {
}
}

@Test
public void parseTlsChannelCredentialsWithCustomCertificatesConfig()
throws XdsInitializationException, IOException {
String rootCertPath = TestUtils.loadCert(CA_PEM_FILE).getAbsolutePath();
String certChainPath = TestUtils.loadCert(CLIENT_PEM_FILE).getAbsolutePath();
String privateKeyPath = TestUtils.loadCert(CLIENT_KEY_FILE).getAbsolutePath();

String rawData = "{\n"
+ " \"xds_servers\": [\n"
+ " {\n"
+ " \"server_uri\": \"" + SERVER_URI + "\",\n"
+ " \"channel_creds\": [\n"
+ " {\n"
+ " \"type\": \"tls\","
+ " \"config\": {\n"
+ " \"ca_certificate_file\": \"" + rootCertPath + "\",\n"
+ " \"certificate_file\": \"" + certChainPath + "\",\n"
+ " \"private_key_file\": \"" + privateKeyPath + "\"\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ " }\n"
+ " ]\n"
+ "}";

bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData));
BootstrapInfo info = bootstrapper.bootstrap();
assertThat(info.servers()).hasSize(1);
ServerInfo serverInfo = Iterables.getOnlyElement(info.servers());
assertThat(serverInfo.target()).isEqualTo(SERVER_URI);
assertThat(serverInfo.implSpecificConfig()).isInstanceOf(TlsChannelCredentials.class);
assertThat(info.node()).isEqualTo(getNodeBuilder().build());

assertThat(info.certProviders()).hasSize(1);
ImmutableMap<String, CertificateProviderInfo> certProviderInfo = info.certProviders();
assertThat(certProviderInfo.keySet()).containsExactly("mtls_channel_creds_identity_certs");
CertificateProviderInfo mtlsChannelCredCertProviderInfo =
certProviderInfo.get("mtls_channel_creds_identity_certs");
assertThat(mtlsChannelCredCertProviderInfo.config().keySet())
.containsExactly("ca_certificate_file", "certificate_file", "private_key_file");
assertThat(mtlsChannelCredCertProviderInfo.config().get("ca_certificate_file"))
.isEqualTo(rootCertPath);
assertThat(mtlsChannelCredCertProviderInfo.config().get("certificate_file"))
.isEqualTo(certChainPath);
assertThat(mtlsChannelCredCertProviderInfo.config().get("private_key_file"))
.isEqualTo(privateKeyPath);
}

private static BootstrapperImpl.FileReader createFileReader(
final String expectedPath, final String rawData) {
return new BootstrapperImpl.FileReader() {
Expand Down
Loading