Skip to content

Commit 37cd044

Browse files
committed
Handle Sslcontext updates for System root certs with and without Mtls.
1 parent 199cc69 commit 37cd044

File tree

3 files changed

+161
-38
lines changed

3 files changed

+161
-38
lines changed

xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020

2121
import com.google.common.annotations.VisibleForTesting;
2222
import com.google.common.base.MoreObjects;
23+
import io.grpc.netty.GrpcSslContexts;
2324
import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext;
2425
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
2526
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
2627
import io.grpc.xds.TlsContextManager;
2728
import io.netty.handler.ssl.SslContext;
2829
import java.util.Objects;
30+
import javax.net.ssl.SSLException;
2931

3032
/**
3133
* Enables Client or server side to initialize this object with the received {@link BaseTlsContext}
@@ -62,21 +64,36 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call
6264
}
6365
// we want to increment the ref-count so call findOrCreate again...
6466
final SslContextProvider toRelease = getSslContextProvider();
65-
toRelease.addCallback(
66-
new SslContextProvider.Callback(callback.getExecutor()) {
67-
68-
@Override
69-
public void updateSslContext(SslContext sslContext) {
70-
callback.updateSslContext(sslContext);
71-
releaseSslContextProvider(toRelease);
72-
}
73-
74-
@Override
75-
public void onException(Throwable throwable) {
76-
callback.onException(throwable);
77-
releaseSslContextProvider(toRelease);
78-
}
79-
});
67+
// When using system root certs on client side, SslContext updates via CertificateProvider is
68+
// only required if Mtls is also enabled, i.e. tlsContext has a cert provider instance.
69+
if (tlsContext instanceof UpstreamTlsContext
70+
&& !CommonTlsContextUtil.hasCertProviderInstance(tlsContext.getCommonTlsContext())
71+
&& CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) {
72+
callback.getExecutor().execute(() -> {
73+
try {
74+
callback.updateSslContext(GrpcSslContexts.forClient().build());
75+
releaseSslContextProvider(toRelease);
76+
} catch (SSLException e) {
77+
callback.onException(e);
78+
}
79+
});
80+
} else {
81+
toRelease.addCallback(
82+
new SslContextProvider.Callback(callback.getExecutor()) {
83+
84+
@Override
85+
public void updateSslContext(SslContext sslContext) {
86+
callback.updateSslContext(sslContext);
87+
releaseSslContextProvider(toRelease);
88+
}
89+
90+
@Override
91+
public void onException(Throwable throwable) {
92+
callback.onException(throwable);
93+
releaseSslContextProvider(toRelease);
94+
}
95+
});
96+
}
8097
} catch (final Throwable throwable) {
8198
callback.getExecutor().execute(new Runnable() {
8299
@Override

xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,18 @@
1717
package io.grpc.xds.internal.security;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20+
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.buildUpstreamTlsContext;
21+
import static org.mockito.ArgumentMatchers.any;
2022
import static org.mockito.ArgumentMatchers.eq;
21-
import static org.mockito.Mockito.any;
2223
import static org.mockito.Mockito.doReturn;
2324
import static org.mockito.Mockito.doThrow;
2425
import static org.mockito.Mockito.mock;
2526
import static org.mockito.Mockito.never;
27+
import static org.mockito.Mockito.reset;
2628
import static org.mockito.Mockito.times;
2729
import static org.mockito.Mockito.verify;
2830

31+
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
2932
import io.grpc.xds.EnvoyServerProtoData;
3033
import io.grpc.xds.TlsContextManager;
3134
import io.netty.handler.ssl.SslContext;
@@ -47,14 +50,17 @@ public class SslContextProviderSupplierTest {
4750
@Rule public final MockitoRule mocks = MockitoJUnit.rule();
4851

4952
@Mock private TlsContextManager mockTlsContextManager;
53+
@Mock private Executor mockExecutor;
5054
private SslContextProviderSupplier supplier;
5155
private SslContextProvider mockSslContextProvider;
5256
private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext;
5357
private SslContextProvider.Callback mockCallback;
5458

55-
private void prepareSupplier() {
56-
upstreamTlsContext =
57-
CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true);
59+
private void prepareSupplier(boolean createUpstreamTlsContext) {
60+
if (createUpstreamTlsContext) {
61+
upstreamTlsContext =
62+
buildUpstreamTlsContext("google_cloud_private_spiffe", true);
63+
}
5864
mockSslContextProvider = mock(SslContextProvider.class);
5965
doReturn(mockSslContextProvider)
6066
.when(mockTlsContextManager)
@@ -64,14 +70,13 @@ private void prepareSupplier() {
6470

6571
private void callUpdateSslContext() {
6672
mockCallback = mock(SslContextProvider.Callback.class);
67-
Executor mockExecutor = mock(Executor.class);
6873
doReturn(mockExecutor).when(mockCallback).getExecutor();
6974
supplier.updateSslContext(mockCallback);
7075
}
7176

7277
@Test
7378
public void get_updateSecret() {
74-
prepareSupplier();
79+
prepareSupplier(true);
7580
callUpdateSslContext();
7681
verify(mockTlsContextManager, times(2))
7782
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
@@ -95,11 +100,12 @@ public void get_updateSecret() {
95100

96101
@Test
97102
public void get_onException() {
98-
prepareSupplier();
103+
prepareSupplier(true);
99104
callUpdateSslContext();
100105
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
101106
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
102-
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
107+
verify(mockSslContextProvider, times(1))
108+
.addCallback(callbackCaptor.capture());
103109
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
104110
assertThat(capturedCallback).isNotNull();
105111
Exception exception = new Exception("test");
@@ -109,9 +115,71 @@ public void get_onException() {
109115
.releaseClientSslContextProvider(eq(mockSslContextProvider));
110116
}
111117

118+
@Test
119+
public void systemRootCertsWithMtls_callbackExecutedFromProvider() {
120+
upstreamTlsContext =
121+
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
122+
"gcp_id",
123+
"cert-default",
124+
null,
125+
"root-default",
126+
null,
127+
CertificateValidationContext.newBuilder()
128+
.setSystemRootCerts(
129+
CertificateValidationContext.SystemRootCerts.getDefaultInstance())
130+
.build());
131+
prepareSupplier(false);
132+
133+
callUpdateSslContext();
134+
135+
verify(mockTlsContextManager, times(2))
136+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
137+
verify(mockTlsContextManager, times(0))
138+
.releaseClientSslContextProvider(any(SslContextProvider.class));
139+
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
140+
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
141+
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
142+
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
143+
assertThat(capturedCallback).isNotNull();
144+
SslContext mockSslContext = mock(SslContext.class);
145+
capturedCallback.updateSslContext(mockSslContext);
146+
verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext));
147+
verify(mockTlsContextManager, times(1))
148+
.releaseClientSslContextProvider(eq(mockSslContextProvider));
149+
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
150+
supplier.updateSslContext(mockCallback);
151+
verify(mockTlsContextManager, times(3))
152+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
153+
}
154+
155+
@Test
156+
public void systemRootCertsWithRegularTls_callbackExecutedFromSupplier() {
157+
upstreamTlsContext =
158+
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
159+
null,
160+
null,
161+
null,
162+
"root-default",
163+
null,
164+
CertificateValidationContext.newBuilder()
165+
.setSystemRootCerts(
166+
CertificateValidationContext.SystemRootCerts.getDefaultInstance())
167+
.build());
168+
supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager);
169+
reset(mockTlsContextManager);
170+
171+
callUpdateSslContext();
172+
ArgumentCaptor<Runnable> runnableArgumentCaptor = ArgumentCaptor.forClass(Runnable.class);
173+
verify(mockExecutor).execute(runnableArgumentCaptor.capture());
174+
runnableArgumentCaptor.getValue().run();
175+
verify(mockCallback, times(1)).updateSslContext(any(SslContext.class));
176+
verify(mockTlsContextManager, times(1))
177+
.releaseClientSslContextProvider(eq(mockSslContextProvider));
178+
}
179+
112180
@Test
113181
public void testClose() {
114-
prepareSupplier();
182+
prepareSupplier(true);
115183
callUpdateSslContext();
116184
supplier.close();
117185
verify(mockTlsContextManager, times(1))
@@ -125,7 +193,7 @@ public void testClose() {
125193

126194
@Test
127195
public void testClose_nullSslContextProvider() {
128-
prepareSupplier();
196+
prepareSupplier(true);
129197
doThrow(new NullPointerException()).when(mockTlsContextManager)
130198
.releaseClientSslContextProvider(null);
131199
supplier.close();

xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,19 @@ public void testProviderForClient_mtls() throws Exception {
187187
}
188188

189189
@Test
190-
public void testProviderForClient_systemRootCerts() throws Exception {
190+
/**
191+
* Note this route will not really be invoked since {@link SslContextProviderSupplier} will
192+
* shortcircuit creating the certificate provider and directly invoke the callback with the
193+
* SslContext in this case.
194+
*/
195+
public void testProviderForClient_systemRootCerts_regularTls() throws Exception {
191196
final CertificateProvider.DistributorWatcher[] watcherCaptor =
192197
new CertificateProvider.DistributorWatcher[1];
193198
TestCertificateProvider.createAndRegisterProviderProvider(
194199
certificateProviderRegistry, watcherCaptor, "testca", 0);
195200
CertProviderClientSslContextProvider provider =
196201
getSslContextProvider(
197-
"gcp_id",
202+
null,
198203
null,
199204
CommonBootstrapperTestUtils.getTestBootstrapInfo(),
200205
/* alpnProtocols= */ null,
@@ -209,36 +214,69 @@ public void testProviderForClient_systemRootCerts() throws Exception {
209214
assertThat(provider.savedTrustedRoots).isNull();
210215
assertThat(provider.getSslContext()).isNull();
211216

212-
// now generate cert update, updates SslContext
217+
assertThat(watcherCaptor[0]).isNull();
218+
}
219+
220+
@Test
221+
public void testProviderForClient_systemRootCerts_mtls() throws Exception {
222+
final CertificateProvider.DistributorWatcher[] watcherCaptor =
223+
new CertificateProvider.DistributorWatcher[1];
224+
TestCertificateProvider.createAndRegisterProviderProvider(
225+
certificateProviderRegistry, watcherCaptor, "testca", 0);
226+
CertProviderClientSslContextProvider provider =
227+
getSslContextProvider(
228+
"gcp_id",
229+
null,
230+
CommonBootstrapperTestUtils.getTestBootstrapInfo(),
231+
/* alpnProtocols= */ null,
232+
CertificateValidationContext.newBuilder()
233+
.setSystemRootCerts(
234+
CertificateValidationContext.SystemRootCerts.getDefaultInstance())
235+
.build(),
236+
true);
237+
238+
assertThat(provider.savedKey).isNull();
239+
assertThat(provider.savedCertChain).isNull();
240+
assertThat(provider.savedTrustedRoots).isNull();
241+
assertThat(provider.getSslContext()).isNull();
242+
243+
// now generate root cert update, will get ignored because of systemRootCerts config
244+
watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE)));
245+
assertThat(provider.getSslContext()).isNull();
246+
assertThat(provider.savedKey).isNull();
247+
assertThat(provider.savedCertChain).isNull();
248+
assertThat(provider.savedTrustedRoots).isNull();
249+
250+
// now generate cert update
213251
watcherCaptor[0].updateCertificate(
214-
CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE),
215-
ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE)));
252+
CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE),
253+
ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE)));
216254
assertThat(provider.savedKey).isNull();
217255
assertThat(provider.savedCertChain).isNull();
218256
assertThat(provider.getSslContext()).isNotNull();
219257

220258
TestCallback testCallback =
221-
CommonTlsContextTestsUtil.getValueThruCallback(provider);
259+
CommonTlsContextTestsUtil.getValueThruCallback(provider);
222260

223261
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
224262
TestCallback testCallback1 =
225-
CommonTlsContextTestsUtil.getValueThruCallback(provider);
263+
CommonTlsContextTestsUtil.getValueThruCallback(provider);
226264
assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext);
227265

228-
// just do root cert update: trusted roots is not updated (because of system root certs config)
229-
// and sslContext should still be the same
266+
// just do root cert update: sslContext should still be the same, will get ignored because of
267+
// systemRootCerts config
230268
watcherCaptor[0].updateTrustedRoots(
231-
ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE)));
269+
ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE)));
232270
assertThat(provider.savedKey).isNull();
233271
assertThat(provider.savedCertChain).isNull();
234272
assertThat(provider.savedTrustedRoots).isNull();
235273
testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider);
236274
assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext);
237275

238-
// now update id cert: sslContext should be updated i.e.different from the previous one
276+
// now update id cert: sslContext should be updated i.e. different from the previous one
239277
watcherCaptor[0].updateCertificate(
240-
CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE),
241-
ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE)));
278+
CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE),
279+
ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE)));
242280
assertThat(provider.savedKey).isNull();
243281
assertThat(provider.savedCertChain).isNull();
244282
assertThat(provider.savedTrustedRoots).isNull();

0 commit comments

Comments
 (0)