|
28 | 28 | import static org.junit.Assert.fail; |
29 | 29 | import static org.mockito.ArgumentMatchers.any; |
30 | 30 | import static org.mockito.ArgumentMatchers.anyString; |
31 | | -import static org.mockito.Mockito.doNothing; |
32 | | -import static org.mockito.Mockito.mock; |
33 | | -import static org.mockito.Mockito.when; |
| 31 | +import static org.mockito.Mockito.*; |
34 | 32 |
|
35 | 33 | import com.google.common.util.concurrent.MoreExecutors; |
36 | 34 | import com.google.common.util.concurrent.SettableFuture; |
@@ -145,6 +143,30 @@ public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() |
145 | 143 | assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); |
146 | 144 | } |
147 | 145 |
|
| 146 | + @Test |
| 147 | + public void clientSecurityProtocolNegotiatorNewHandler_autoHostSni_hostnameIsPassedToClientSecurityHandler() { |
| 148 | + UpstreamTlsContext upstreamTlsContext = |
| 149 | + CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, true); |
| 150 | + ClientSecurityProtocolNegotiator pn = |
| 151 | + new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); |
| 152 | + GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); |
| 153 | + ChannelLogger logger = mock(ChannelLogger.class); |
| 154 | + doNothing().when(logger).log(any(ChannelLogLevel.class), anyString()); |
| 155 | + when(mockHandler.getNegotiationLogger()).thenReturn(logger); |
| 156 | + TlsContextManager mockTlsContextManager = mock(TlsContextManager.class); |
| 157 | + when(mockHandler.getEagAttributes()) |
| 158 | + .thenReturn( |
| 159 | + Attributes.newBuilder() |
| 160 | + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, |
| 161 | + new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) |
| 162 | + .set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, FAKE_AUTHORITY) |
| 163 | + .build()); |
| 164 | + ChannelHandler newHandler = pn.newHandler(mockHandler); |
| 165 | + assertThat(newHandler).isNotNull(); |
| 166 | + assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); |
| 167 | + assertThat(((ClientSecurityHandler) newHandler).getSni()).isEqualTo(FAKE_AUTHORITY); |
| 168 | + } |
| 169 | + |
148 | 170 | @Test |
149 | 171 | public void clientSecurityHandler_addLast() |
150 | 172 | throws InterruptedException, TimeoutException, ExecutionException { |
@@ -197,59 +219,6 @@ protected void onException(Throwable throwable) { |
197 | 219 | .contains("ProtocolNegotiators.ClientTlsHandler"); |
198 | 220 | CommonCertProviderTestUtils.register0(); |
199 | 221 | } |
200 | | - |
201 | | - @Test |
202 | | - public void clientSecurityHandler_addLast() |
203 | | - throws InterruptedException, TimeoutException, ExecutionException { |
204 | | - FakeClock executor = new FakeClock(); |
205 | | - CommonCertProviderTestUtils.register(executor); |
206 | | - Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils |
207 | | - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, |
208 | | - CA_PEM_FILE, null, null, null, null, null); |
209 | | - UpstreamTlsContext upstreamTlsContext = |
210 | | - CommonTlsContextTestsUtil |
211 | | - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); |
212 | | - |
213 | | - SslContextProviderSupplier sslContextProviderSupplier = |
214 | | - new SslContextProviderSupplier(upstreamTlsContext, |
215 | | - new TlsContextManagerImpl(bootstrapInfoForClient)); |
216 | | - ClientSecurityHandler clientSecurityHandler = |
217 | | - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); |
218 | | - pipeline.addLast(clientSecurityHandler); |
219 | | - channelHandlerCtx = pipeline.context(clientSecurityHandler); |
220 | | - assertNotNull(channelHandlerCtx); |
221 | | - |
222 | | - // kick off protocol negotiation. |
223 | | - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); |
224 | | - final SettableFuture<Object> future = SettableFuture.create(); |
225 | | - sslContextProviderSupplier |
226 | | - .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { |
227 | | - @Override |
228 | | - public void updateSslContext(SslContext sslContext) { |
229 | | - future.set(sslContext); |
230 | | - } |
231 | | - |
232 | | - @Override |
233 | | - protected void onException(Throwable throwable) { |
234 | | - future.set(throwable); |
235 | | - } |
236 | | - }, null); |
237 | | - assertThat(executor.runDueTasks()).isEqualTo(1); |
238 | | - channel.runPendingTasks(); |
239 | | - Object fromFuture = future.get(2, TimeUnit.SECONDS); |
240 | | - assertThat(fromFuture).isInstanceOf(SslContext.class); |
241 | | - channel.runPendingTasks(); |
242 | | - channelHandlerCtx = pipeline.context(clientSecurityHandler); |
243 | | - assertThat(channelHandlerCtx).isNull(); |
244 | | - |
245 | | - // pipeline should have SslHandler and ClientTlsHandler |
246 | | - Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator(); |
247 | | - assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class); |
248 | | - // ProtocolNegotiators.ClientTlsHandler.class not accessible, get canonical name |
249 | | - assertThat(iterator.next().getValue().getClass().getCanonicalName()) |
250 | | - .contains("ProtocolNegotiators.ClientTlsHandler"); |
251 | | - CommonCertProviderTestUtils.register0(); |
252 | | - } |
253 | 222 |
|
254 | 223 | @Test |
255 | 224 | public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() { |
|
0 commit comments