Skip to content

Commit a371065

Browse files
committed
Save changes.
1 parent dd8fa02 commit a371065

File tree

3 files changed

+40
-122
lines changed

3 files changed

+40
-122
lines changed

xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,10 @@ protected Callback(Executor executor) {
5454
this.hostname = null;
5555
}
5656

57-
// Only for client SslContextProvider.
58-
protected Callback(Executor executor, String hostname) {
59-
this.executor = executor;
60-
this.hostname = hostname;
61-
}
62-
6357
@VisibleForTesting public Executor getExecutor() {
6458
return executor;
6559
}
6660

67-
protected String getHostname() {
68-
return hostname;
69-
}
70-
7161
/** Informs callee of new/updated SslContext. */
7262
@VisibleForTesting public abstract void updateSslContext(SslContext sslContext);
7363

xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ public BaseTlsContext getTlsContext() {
6060
public synchronized void updateSslContext(final SslContextProvider.Callback callback, String sni) {
6161
checkNotNull(callback, "callback");
6262
try {
63+
if (!shutdown) {
64+
if (sslContextProvider == null) {
65+
sslContextProvider = getSslContextProvider(sni);
66+
}
67+
}
68+
6369
// we want to increment the ref-count so call findOrCreate again...
6470
final SslContextProvider toRelease = getSslContextProvider(sni);
6571
toRelease.addCallback(

xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java

Lines changed: 34 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import static org.mockito.Mockito.never;
2626
import static org.mockito.Mockito.times;
2727
import static org.mockito.Mockito.verify;
28-
import static org.mockito.Mockito.when;
2928

3029
import io.grpc.xds.EnvoyServerProtoData;
3130
import io.grpc.xds.TlsContextManager;
@@ -47,41 +46,39 @@
4746
public class SslContextProviderSupplierTest {
4847
@Rule public final MockitoRule mocks = MockitoJUnit.rule();
4948

50-
private static final String ENDPOINT_HOSTNAME_FROM_ATTR = "endpoint-hostname-from-attribute";
51-
private static final String SNI_IN_UTC = "sni-in-upstream-tls-context";
49+
private static final String SNI = "sni";
5250

5351
@Mock private TlsContextManager mockTlsContextManager;
5452
private SslContextProviderSupplier supplier;
5553
private SslContextProvider mockSslContextProvider;
5654
private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext;
5755
private SslContextProvider.Callback mockCallback;
5856

59-
private void prepareSupplier(boolean autoHostSni, String sniInUTC, String sniSentByClient) {
57+
private void prepareSupplier() {
6058
upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext(
61-
"google_cloud_private_spiffe", true, sniInUTC, autoHostSni);
59+
"google_cloud_private_spiffe", true, SNI, false);
6260
mockSslContextProvider = mock(SslContextProvider.class);
6361
doReturn(mockSslContextProvider)
6462
.when(mockTlsContextManager)
65-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(sniSentByClient));
63+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI));
6664
supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager);
6765
}
6866

69-
private void callUpdateSslContext(String endpointHostname) {
67+
private void callUpdateSslContext() {
7068
mockCallback = mock(SslContextProvider.Callback.class);
71-
when(mockCallback.getHostname()).thenReturn(endpointHostname);
7269
Executor mockExecutor = mock(Executor.class);
7370
doReturn(mockExecutor).when(mockCallback).getExecutor();
74-
supplier.updateSslContext(mockCallback, null);
71+
supplier.updateSslContext(mockCallback, SNI);
7572
}
7673

7774
@Test
7875
public void get_updateSecret() {
79-
prepareSupplier(false, null, "");
80-
callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR);
76+
prepareSupplier();
77+
callUpdateSslContext();
8178
verify(mockTlsContextManager, times(2))
82-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(""));
79+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI));
8380
verify(mockTlsContextManager, times(0))
84-
.releaseClientSslContextProvider(any(SslContextProvider.class), eq(""));
81+
.releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI));
8582
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
8683
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
8784
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
@@ -91,21 +88,21 @@ public void get_updateSecret() {
9188
capturedCallback.updateSslContext(mockSslContext);
9289
verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext));
9390
verify(mockTlsContextManager, times(1))
94-
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(""));
91+
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI));
9592
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
96-
supplier.updateSslContext(mockCallback, null);
93+
supplier.updateSslContext(mockCallback, SNI);
9794
verify(mockTlsContextManager, times(3))
98-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(""));
95+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI));
9996
}
10097

10198
@Test
10299
public void autoHostSniFalse_usesSniFromUpstreamTlsContext() {
103-
prepareSupplier(false, SNI_IN_UTC, SNI_IN_UTC);
104-
callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR);
100+
prepareSupplier();
101+
callUpdateSslContext();
105102
verify(mockTlsContextManager, times(2))
106-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC));
103+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI));
107104
verify(mockTlsContextManager, times(0))
108-
.releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI_IN_UTC));
105+
.releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI));
109106
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
110107
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
111108
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
@@ -115,92 +112,17 @@ public void autoHostSniFalse_usesSniFromUpstreamTlsContext() {
115112
capturedCallback.updateSslContext(mockSslContext);
116113
verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext));
117114
verify(mockTlsContextManager, times(1))
118-
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI_IN_UTC));
115+
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI));
119116
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
120-
supplier.updateSslContext(mockCallback, null);
117+
supplier.updateSslContext(mockCallback, SNI);
121118
verify(mockTlsContextManager, times(3))
122-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC));
123-
}
124-
125-
@Test
126-
public void autoHostSniTrue_usesSniFromEndpointHostname() {
127-
prepareSupplier(true, SNI_IN_UTC, ENDPOINT_HOSTNAME_FROM_ATTR);
128-
callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR);
129-
verify(mockTlsContextManager, times(2))
130-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(ENDPOINT_HOSTNAME_FROM_ATTR));
131-
verify(mockTlsContextManager, times(0))
132-
.releaseClientSslContextProvider(any(SslContextProvider.class), eq(ENDPOINT_HOSTNAME_FROM_ATTR));
133-
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
134-
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
135-
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
136-
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
137-
assertThat(capturedCallback).isNotNull();
138-
SslContext mockSslContext = mock(SslContext.class);
139-
capturedCallback.updateSslContext(mockSslContext);
140-
verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext));
141-
verify(mockTlsContextManager, times(1))
142-
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(ENDPOINT_HOSTNAME_FROM_ATTR));
143-
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
144-
when(mockCallback.getHostname()).thenReturn(ENDPOINT_HOSTNAME_FROM_ATTR);
145-
supplier.updateSslContext(mockCallback, null);
146-
verify(mockTlsContextManager, times(3))
147-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(ENDPOINT_HOSTNAME_FROM_ATTR));
148-
}
149-
150-
@Test
151-
public void autoHostSniTrue_endpointHostNameIsNull_usesSniFromUpstreamTlsContext() {
152-
prepareSupplier(true, SNI_IN_UTC, SNI_IN_UTC);
153-
callUpdateSslContext(null);
154-
verify(mockTlsContextManager, times(2))
155-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC));
156-
verify(mockTlsContextManager, times(0))
157-
.releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI_IN_UTC));
158-
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
159-
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
160-
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
161-
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
162-
assertThat(capturedCallback).isNotNull();
163-
SslContext mockSslContext = mock(SslContext.class);
164-
capturedCallback.updateSslContext(mockSslContext);
165-
verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext));
166-
verify(mockTlsContextManager, times(1))
167-
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI_IN_UTC));
168-
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
169-
when(mockCallback.getHostname()).thenReturn(null);
170-
supplier.updateSslContext(mockCallback, null);
171-
verify(mockTlsContextManager, times(3))
172-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC));
173-
}
174-
175-
@Test
176-
public void autoHostSniTrue_endpointHostNameIsEmpty_usesSniFromUpstreamTlsContext() {
177-
prepareSupplier(true, SNI_IN_UTC, SNI_IN_UTC);
178-
callUpdateSslContext("");
179-
verify(mockTlsContextManager, times(2))
180-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC));
181-
verify(mockTlsContextManager, times(0))
182-
.releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI_IN_UTC));
183-
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
184-
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
185-
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
186-
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
187-
assertThat(capturedCallback).isNotNull();
188-
SslContext mockSslContext = mock(SslContext.class);
189-
capturedCallback.updateSslContext(mockSslContext);
190-
verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext));
191-
verify(mockTlsContextManager, times(1))
192-
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI_IN_UTC));
193-
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
194-
when(mockCallback.getHostname()).thenReturn("");
195-
supplier.updateSslContext(mockCallback, null);
196-
verify(mockTlsContextManager, times(3))
197-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC));
119+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI));
198120
}
199121

200122
@Test
201123
public void get_onException() {
202-
prepareSupplier(false, null, "");
203-
callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR);
124+
prepareSupplier();
125+
callUpdateSslContext();
204126
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor =
205127
ArgumentCaptor.forClass(SslContextProvider.Callback.class);
206128
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
@@ -210,33 +132,33 @@ public void get_onException() {
210132
capturedCallback.onException(exception);
211133
verify(mockCallback, times(1)).onException(eq(exception));
212134
verify(mockTlsContextManager, times(1))
213-
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(""));
135+
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI));
214136
}
215137

216138
@Test
217139
public void testClose() {
218-
prepareSupplier(false, null, "");
219-
callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR);
140+
prepareSupplier();
141+
callUpdateSslContext();
220142
supplier.close();
221143
verify(mockTlsContextManager, times(1))
222-
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(""));
223-
supplier.updateSslContext(mockCallback, null);
144+
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI));
145+
supplier.updateSslContext(mockCallback, SNI);
224146
verify(mockTlsContextManager, times(3))
225-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(""));
147+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI));
226148
verify(mockTlsContextManager, times(1))
227-
.releaseClientSslContextProvider(any(SslContextProvider.class), eq(""));
149+
.releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI));
228150
}
229151

230152
@Test
231153
public void testClose_nullSslContextProvider() {
232-
prepareSupplier(false, null, "");
154+
prepareSupplier();
233155
doThrow(new NullPointerException()).when(mockTlsContextManager)
234-
.releaseClientSslContextProvider(null, "");
156+
.releaseClientSslContextProvider(null, SNI);
235157
supplier.close();
236158
verify(mockTlsContextManager, never())
237-
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(""));
238-
callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR);
159+
.releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI));
160+
callUpdateSslContext();
239161
verify(mockTlsContextManager, times(1))
240-
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(""));
162+
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI));
241163
}
242164
}

0 commit comments

Comments
 (0)