Skip to content

Commit 7f48afa

Browse files
committed
Remove special-casing for System root certs in SslContextProviderSupplier and handle it in
1 parent 139805e commit 7f48afa

File tree

5 files changed

+49
-124
lines changed

5 files changed

+49
-124
lines changed

xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,12 @@
2020

2121
import com.google.common.annotations.VisibleForTesting;
2222
import com.google.common.base.MoreObjects;
23-
import io.grpc.netty.GrpcSslContexts;
2423
import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext;
2524
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
2625
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
2726
import io.grpc.xds.TlsContextManager;
2827
import io.netty.handler.ssl.SslContext;
2928
import java.util.Objects;
30-
import javax.net.ssl.SSLException;
3129

3230
/**
3331
* Enables Client or server side to initialize this object with the received {@link BaseTlsContext}
@@ -64,36 +62,21 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call
6462
}
6563
// we want to increment the ref-count so call findOrCreate again...
6664
final SslContextProvider toRelease = getSslContextProvider();
67-
// When using system root certs on client side, SslContext updates via CertificateProvider is
68-
// only required if Mtls is also enabled, i.e. tlsContext has a cert provider instance.
69-
if (tlsContext instanceof UpstreamTlsContext
70-
&& !CommonTlsContextUtil.hasCertProviderInstance(tlsContext.getCommonTlsContext())
71-
&& CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) {
72-
callback.getExecutor().execute(() -> {
73-
try {
74-
callback.updateSslContext(GrpcSslContexts.forClient().build());
75-
releaseSslContextProvider(toRelease);
76-
} catch (SSLException e) {
77-
callback.onException(e);
78-
}
79-
});
80-
} else {
81-
toRelease.addCallback(
82-
new SslContextProvider.Callback(callback.getExecutor()) {
83-
84-
@Override
85-
public void updateSslContext(SslContext sslContext) {
86-
callback.updateSslContext(sslContext);
87-
releaseSslContextProvider(toRelease);
88-
}
89-
90-
@Override
91-
public void onException(Throwable throwable) {
92-
callback.onException(throwable);
93-
releaseSslContextProvider(toRelease);
94-
}
95-
});
96-
}
65+
toRelease.addCallback(
66+
new SslContextProvider.Callback(callback.getExecutor()) {
67+
68+
@Override
69+
public void updateSslContext(SslContext sslContext) {
70+
callback.updateSslContext(sslContext);
71+
releaseSslContextProvider(toRelease);
72+
}
73+
74+
@Override
75+
public void onException(Throwable throwable) {
76+
callback.onException(throwable);
77+
releaseSslContextProvider(toRelease);
78+
}
79+
});
9780
} catch (final Throwable throwable) {
9881
callback.getExecutor().execute(new Runnable() {
9982
@Override

xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.security.cert.X509Certificate;
2929
import java.util.Map;
3030
import javax.annotation.Nullable;
31+
import javax.net.ssl.SSLException;
3132

3233
/** A client SslContext provider using CertificateProviderInstance to fetch secrets. */
3334
final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider {
@@ -48,15 +49,24 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP
4849
staticCertValidationContext,
4950
upstreamTlsContext,
5051
certificateProviderStore);
52+
// Null rootCertInstance implies hasSystemRootCerts because of the check in
53+
// CertProviderClientSslContextProviderFactory.
54+
if (rootCertInstance == null && !isMtls()) {
55+
try {
56+
// Instantiate sslContext so that addCallback will immediately update the callback with
57+
// the SslContext.
58+
sslContext = getSslContextBuilder(staticCertificateValidationContext).build();
59+
} catch (SSLException | CertStoreException e) {
60+
throw new RuntimeException(e);
61+
}
62+
}
5163
}
5264

5365
@Override
5466
protected final SslContextBuilder getSslContextBuilder(
5567
CertificateValidationContext certificateValidationContextdationContext)
5668
throws CertStoreException {
5769
SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
58-
// Null rootCertInstance implies hasSystemRootCerts because of the check in
59-
// CertProviderClientSslContextProviderFactory.
6070
if (rootCertInstance != null) {
6171
if (savedSpiffeTrustMap != null) {
6272
sslContextBuilder = sslContextBuilder.trustManager(

xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,12 @@ private void updateSslContextWhenReady() {
158158
updateSslContext();
159159
clearKeysAndCerts();
160160
}
161-
} else if (isNormalTlsAndClientSide()) {
161+
} else if (isRegularTlsAndClientSide()) {
162162
if (savedTrustedRoots != null || savedSpiffeTrustMap != null) {
163163
updateSslContext();
164164
clearKeysAndCerts();
165165
}
166-
} else if (isNormalTlsAndServerSide()) {
166+
} else if (isRegularTlsAndServerSide()) {
167167
if (savedKey != null) {
168168
updateSslContext();
169169
clearKeysAndCerts();
@@ -182,14 +182,14 @@ protected final boolean isMtls() {
182182
return certInstance != null && (rootCertInstance != null || isUsingSystemRootCerts);
183183
}
184184

185-
protected final boolean isNormalTlsAndClientSide() {
185+
protected final boolean isRegularTlsAndClientSide() {
186186
// We don't do (rootCertInstance != null || isUsingSystemRootCerts) here because of how this
187187
// method is used. With the rootCertInstance being null when using system root certs, there
188188
// is nothing to update in the SslContext
189189
return rootCertInstance != null && certInstance == null;
190190
}
191191

192-
protected final boolean isNormalTlsAndServerSide() {
192+
protected final boolean isRegularTlsAndServerSide() {
193193
return certInstance != null && rootCertInstance == null;
194194
}
195195

xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java

Lines changed: 13 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,15 @@
1717
package io.grpc.xds.internal.security;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20-
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.buildUpstreamTlsContext;
21-
import static org.mockito.ArgumentMatchers.any;
2220
import static org.mockito.ArgumentMatchers.eq;
21+
import static org.mockito.Mockito.any;
2322
import static org.mockito.Mockito.doReturn;
2423
import static org.mockito.Mockito.doThrow;
2524
import static org.mockito.Mockito.mock;
2625
import static org.mockito.Mockito.never;
27-
import static org.mockito.Mockito.reset;
2826
import static org.mockito.Mockito.times;
2927
import static org.mockito.Mockito.verify;
3028

31-
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
3229
import io.grpc.xds.EnvoyServerProtoData;
3330
import io.grpc.xds.TlsContextManager;
3431
import io.netty.handler.ssl.SslContext;
@@ -50,33 +47,31 @@ public class SslContextProviderSupplierTest {
5047
@Rule public final MockitoRule mocks = MockitoJUnit.rule();
5148

5249
@Mock private TlsContextManager mockTlsContextManager;
53-
@Mock private Executor mockExecutor;
5450
private SslContextProviderSupplier supplier;
5551
private SslContextProvider mockSslContextProvider;
5652
private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext;
5753
private SslContextProvider.Callback mockCallback;
5854

59-
private void prepareSupplier(boolean createUpstreamTlsContext) {
60-
if (createUpstreamTlsContext) {
61-
upstreamTlsContext =
62-
buildUpstreamTlsContext("google_cloud_private_spiffe", true);
63-
}
55+
private void prepareSupplier() {
56+
upstreamTlsContext =
57+
CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true);
6458
mockSslContextProvider = mock(SslContextProvider.class);
6559
doReturn(mockSslContextProvider)
66-
.when(mockTlsContextManager)
67-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
60+
.when(mockTlsContextManager)
61+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
6862
supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager);
6963
}
7064

7165
private void callUpdateSslContext() {
7266
mockCallback = mock(SslContextProvider.Callback.class);
67+
Executor mockExecutor = mock(Executor.class);
7368
doReturn(mockExecutor).when(mockCallback).getExecutor();
7469
supplier.updateSslContext(mockCallback);
7570
}
7671

7772
@Test
7873
public void get_updateSecret() {
79-
prepareSupplier(true);
74+
prepareSupplier();
8075
callUpdateSslContext();
8176
verify(mockTlsContextManager, times(2))
8277
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
@@ -100,12 +95,11 @@ public void get_updateSecret() {
10095

10196
@Test
10297
public void get_onException() {
103-
prepareSupplier(true);
98+
prepareSupplier();
10499
callUpdateSslContext();
105100
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
106101
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
107-
verify(mockSslContextProvider, times(1))
108-
.addCallback(callbackCaptor.capture());
102+
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
109103
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
110104
assertThat(capturedCallback).isNotNull();
111105
Exception exception = new Exception("test");
@@ -115,71 +109,9 @@ public void get_onException() {
115109
.releaseClientSslContextProvider(eq(mockSslContextProvider));
116110
}
117111

118-
@Test
119-
public void systemRootCertsWithMtls_callbackExecutedFromProvider() {
120-
upstreamTlsContext =
121-
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
122-
"gcp_id",
123-
"cert-default",
124-
null,
125-
"root-default",
126-
null,
127-
CertificateValidationContext.newBuilder()
128-
.setSystemRootCerts(
129-
CertificateValidationContext.SystemRootCerts.getDefaultInstance())
130-
.build());
131-
prepareSupplier(false);
132-
133-
callUpdateSslContext();
134-
135-
verify(mockTlsContextManager, times(2))
136-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
137-
verify(mockTlsContextManager, times(0))
138-
.releaseClientSslContextProvider(any(SslContextProvider.class));
139-
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
140-
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
141-
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
142-
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
143-
assertThat(capturedCallback).isNotNull();
144-
SslContext mockSslContext = mock(SslContext.class);
145-
capturedCallback.updateSslContext(mockSslContext);
146-
verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext));
147-
verify(mockTlsContextManager, times(1))
148-
.releaseClientSslContextProvider(eq(mockSslContextProvider));
149-
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
150-
supplier.updateSslContext(mockCallback);
151-
verify(mockTlsContextManager, times(3))
152-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
153-
}
154-
155-
@Test
156-
public void systemRootCertsWithRegularTls_callbackExecutedFromSupplier() {
157-
upstreamTlsContext =
158-
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
159-
null,
160-
null,
161-
null,
162-
"root-default",
163-
null,
164-
CertificateValidationContext.newBuilder()
165-
.setSystemRootCerts(
166-
CertificateValidationContext.SystemRootCerts.getDefaultInstance())
167-
.build());
168-
supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager);
169-
reset(mockTlsContextManager);
170-
171-
callUpdateSslContext();
172-
ArgumentCaptor<Runnable> runnableArgumentCaptor = ArgumentCaptor.forClass(Runnable.class);
173-
verify(mockExecutor).execute(runnableArgumentCaptor.capture());
174-
runnableArgumentCaptor.getValue().run();
175-
verify(mockCallback, times(1)).updateSslContext(any(SslContext.class));
176-
verify(mockTlsContextManager, times(1))
177-
.releaseClientSslContextProvider(eq(mockSslContextProvider));
178-
}
179-
180112
@Test
181113
public void testClose() {
182-
prepareSupplier(true);
114+
prepareSupplier();
183115
callUpdateSslContext();
184116
supplier.close();
185117
verify(mockTlsContextManager, times(1))
@@ -193,7 +125,7 @@ public void testClose() {
193125

194126
@Test
195127
public void testClose_nullSslContextProvider() {
196-
prepareSupplier(true);
128+
prepareSupplier();
197129
doThrow(new NullPointerException()).when(mockTlsContextManager)
198130
.releaseClientSslContextProvider(null);
199131
supplier.close();
@@ -203,4 +135,4 @@ public void testClose_nullSslContextProvider() {
203135
verify(mockTlsContextManager, times(1))
204136
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
205137
}
206-
}
138+
}

xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,7 @@ public void testProviderForClient_mtls() throws Exception {
187187
}
188188

189189
@Test
190-
// Note: This code flow will not really be invoked since {@link SslContextProviderSupplier} will
191-
// shortcircuit creating the certificate provider and directly invoke the callback with the
192-
// SslContext in this case.
193-
public void testProviderForClient_systemRootCerts_regularTls() throws Exception {
190+
public void testProviderForClient_systemRootCerts_regularTls() {
194191
final CertificateProvider.DistributorWatcher[] watcherCaptor =
195192
new CertificateProvider.DistributorWatcher[1];
196193
TestCertificateProvider.createAndRegisterProviderProvider(
@@ -210,7 +207,10 @@ public void testProviderForClient_systemRootCerts_regularTls() throws Exception
210207
assertThat(provider.savedKey).isNull();
211208
assertThat(provider.savedCertChain).isNull();
212209
assertThat(provider.savedTrustedRoots).isNull();
213-
assertThat(provider.getSslContext()).isNull();
210+
assertThat(provider.getSslContext()).isNotNull();
211+
TestCallback testCallback =
212+
CommonTlsContextTestsUtil.getValueThruCallback(provider);
213+
assertThat(testCallback.updatedSslContext).isEqualTo(provider.getSslContext());
214214

215215
assertThat(watcherCaptor[0]).isNull();
216216
}

0 commit comments

Comments
 (0)