diff --git a/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java b/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java index 247e2f10f9..164b41b96a 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java @@ -223,7 +223,7 @@ public static ClientContext create(StubSettings settings) throws IOException { transportChannelProvider = transportChannelProvider.withEndpoint(endpoint); } transportChannelProvider = transportChannelProvider.withUseS2A(endpointContext.useS2A()); - if (transportChannelProvider.needsMtlsEndpoint()) { + if (transportChannelProvider.needsMtlsEndpoint() && endpointContext.mtlsEndpoint() != null) { transportChannelProvider = transportChannelProvider.withMtlsEndpoint(endpointContext.mtlsEndpoint()); } diff --git a/gax-java/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java b/gax-java/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java index 826864a49c..c7a6731d03 100644 --- a/gax-java/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java +++ b/gax-java/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java @@ -73,8 +73,9 @@ import org.mockito.Mockito; class ClientContextTest { - private static final String DEFAULT_ENDPOINT = "test.googleapis.com"; private static final String DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"; + private static final String DEFAULT_ENDPOINT = "https://foo.googleapis.com"; + private static final String DEFAULT_MTLS_ENDPOINT = "https://foo.mtls.googleapis.com"; private static class InterceptingExecutor extends ScheduledThreadPoolExecutor { boolean shutdownCalled = false; @@ -115,6 +116,7 @@ private static class FakeTransportProvider implements TransportChannelProvider { final Map headers; final Credentials credentials; final String endpoint; + final String mtlsEndpoint; FakeTransportProvider( FakeTransportChannel transport, @@ -122,7 +124,8 @@ private static class FakeTransportProvider implements TransportChannelProvider { boolean shouldAutoClose, Map headers, Credentials credentials, - String endpoint) { + String endpoint, + String mtlsEndpoint) { this.transport = transport; this.executor = executor; this.shouldAutoClose = shouldAutoClose; @@ -130,6 +133,7 @@ private static class FakeTransportProvider implements TransportChannelProvider { this.transport.setHeaders(headers); this.credentials = credentials; this.endpoint = endpoint; + this.mtlsEndpoint = mtlsEndpoint; } @Override @@ -155,7 +159,8 @@ public TransportChannelProvider withExecutor(Executor executor) { this.shouldAutoClose, this.headers, this.credentials, - this.endpoint); + this.endpoint, + this.mtlsEndpoint); } @Override @@ -171,7 +176,8 @@ public TransportChannelProvider withHeaders(Map headers) { this.shouldAutoClose, headers, this.credentials, - this.endpoint); + this.endpoint, + this.mtlsEndpoint); } @Override @@ -192,7 +198,8 @@ public TransportChannelProvider withEndpoint(String endpoint) { this.shouldAutoClose, this.headers, this.credentials, - endpoint); + endpoint, + this.mtlsEndpoint); } @Override @@ -230,7 +237,30 @@ public TransportChannelProvider withCredentials(Credentials credentials) { this.shouldAutoClose, this.headers, credentials, - this.endpoint); + this.endpoint, + this.mtlsEndpoint); + } + + @Override + public boolean needsMtlsEndpoint() { + return this.mtlsEndpoint == null; + } + + @Override + public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) { + // Throws an exception if this is passed with a null value. This should never + // happen as GAPICs should always have a default mtlsEndpoint value + if (mtlsEndpoint == null) { + throw new IllegalArgumentException("mtlsEndpoint is null"); + } + return new FakeTransportProvider( + this.transport, + this.executor, + this.shouldAutoClose, + this.headers, + this.credentials, + this.endpoint, + mtlsEndpoint); } } @@ -278,7 +308,8 @@ private void runTest( shouldAutoClose, needHeaders ? null : headers, null, - DEFAULT_ENDPOINT); + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); Credentials credentials = Mockito.mock(Credentials.class); ApiClock clock = Mockito.mock(ApiClock.class); Watchdog watchdog = @@ -352,7 +383,8 @@ void testWatchdogProvider() throws IOException { InterceptingExecutor executor = new InterceptingExecutor(1); FakeTransportChannel transportChannel = FakeTransportChannel.create(new FakeChannel()); FakeTransportProvider transportProvider = - new FakeTransportProvider(transportChannel, executor, true, null, null, DEFAULT_ENDPOINT); + new FakeTransportProvider( + transportChannel, executor, true, null, null, DEFAULT_ENDPOINT, DEFAULT_MTLS_ENDPOINT); ApiClock clock = Mockito.mock(ApiClock.class); builder.setClock(clock); @@ -391,7 +423,8 @@ void testMergeHeaders_getQuotaProjectIdFromHeadersProvider() throws IOException InterceptingExecutor executor = new InterceptingExecutor(1); FakeTransportChannel transportChannel = FakeTransportChannel.create(new FakeChannel()); FakeTransportProvider transportProvider = - new FakeTransportProvider(transportChannel, executor, true, null, null, DEFAULT_ENDPOINT); + new FakeTransportProvider( + transportChannel, executor, true, null, null, DEFAULT_ENDPOINT, DEFAULT_MTLS_ENDPOINT); HeaderProvider headerProvider = Mockito.mock(HeaderProvider.class); Mockito.when(headerProvider.getHeaders()).thenReturn(ImmutableMap.of("header_k1", "v1")); @@ -427,7 +460,8 @@ void testMergeHeaders_getQuotaProjectIdFromSettings() throws IOException { InterceptingExecutor executor = new InterceptingExecutor(1); FakeTransportChannel transportChannel = FakeTransportChannel.create(new FakeChannel()); FakeTransportProvider transportProvider = - new FakeTransportProvider(transportChannel, executor, true, null, null, DEFAULT_ENDPOINT); + new FakeTransportProvider( + transportChannel, executor, true, null, null, DEFAULT_ENDPOINT, DEFAULT_MTLS_ENDPOINT); HeaderProvider headerProvider = new HeaderProvider() { @@ -473,7 +507,8 @@ void testMergeHeaders_noQuotaProjectIdSet() throws IOException { InterceptingExecutor executor = new InterceptingExecutor(1); FakeTransportChannel transportChannel = FakeTransportChannel.create(new FakeChannel()); FakeTransportProvider transportProvider = - new FakeTransportProvider(transportChannel, executor, true, null, null, DEFAULT_ENDPOINT); + new FakeTransportProvider( + transportChannel, executor, true, null, null, DEFAULT_ENDPOINT, DEFAULT_MTLS_ENDPOINT); HeaderProvider headerProvider = Mockito.mock(HeaderProvider.class); Mockito.when(headerProvider.getHeaders()).thenReturn(ImmutableMap.of("header_k1", "v1")); @@ -504,7 +539,8 @@ void testHidingQuotaProjectId_quotaSetFromSetting() throws IOException { InterceptingExecutor executor = new InterceptingExecutor(1); FakeTransportChannel transportChannel = FakeTransportChannel.create(new FakeChannel()); FakeTransportProvider transportProvider = - new FakeTransportProvider(transportChannel, executor, true, null, null, DEFAULT_ENDPOINT); + new FakeTransportProvider( + transportChannel, executor, true, null, null, DEFAULT_ENDPOINT, DEFAULT_MTLS_ENDPOINT); Map> metaDataWithQuota = ImmutableMap.of( "k1", @@ -545,7 +581,8 @@ void testHidingQuotaProjectId_noQuotaSetFromSetting() throws IOException { InterceptingExecutor executor = new InterceptingExecutor(1); FakeTransportChannel transportChannel = FakeTransportChannel.create(new FakeChannel()); FakeTransportProvider transportProvider = - new FakeTransportProvider(transportChannel, executor, true, null, null, DEFAULT_ENDPOINT); + new FakeTransportProvider( + transportChannel, executor, true, null, null, DEFAULT_ENDPOINT, DEFAULT_MTLS_ENDPOINT); Map> metaData = ImmutableMap.of("k1", Collections.singletonList("v1")); final Credentials credentialsWithoutQuotaProjectId = Mockito.mock(GoogleCredentials.class); Mockito.when(credentialsWithoutQuotaProjectId.getRequestMetadata(null)).thenReturn(metaData); @@ -581,7 +618,8 @@ void testQuotaProjectId_worksWithNullCredentials() throws IOException { true, null, Mockito.mock(Credentials.class), - DEFAULT_ENDPOINT); + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); final FakeClientSettings.Builder settingsBuilder = new FakeClientSettings.Builder(); @@ -602,7 +640,8 @@ void testUserAgentInternalOnly() throws Exception { true, null, null, - DEFAULT_ENDPOINT); + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); ClientSettings.Builder builder = new FakeClientSettings.Builder() @@ -630,7 +669,8 @@ void testUserAgentExternalOnly() throws Exception { true, null, null, - DEFAULT_ENDPOINT); + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); ClientSettings.Builder builder = new FakeClientSettings.Builder() @@ -658,7 +698,8 @@ void testUserAgentConcat() throws Exception { true, null, null, - DEFAULT_ENDPOINT); + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); ClientSettings.Builder builder = new FakeClientSettings.Builder() @@ -743,7 +784,8 @@ private Map setupTestForCredentialTokenUsageMetricsAndGetTranspo true, null, null, - DEFAULT_ENDPOINT); + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); ClientSettings.Builder builder = new FakeClientSettings.Builder() @@ -759,31 +801,28 @@ private Map setupTestForCredentialTokenUsageMetricsAndGetTranspo return transportChannel.getHeaders(); } - private static String endpoint = "https://foo.googleapis.com"; - private static String mtlsEndpoint = "https://foo.mtls.googleapis.com"; - @Test void testSwitchToMtlsEndpointAllowed() throws IOException { - StubSettings settings = new FakeStubSettings.Builder().setEndpoint(endpoint).build(); + StubSettings settings = new FakeStubSettings.Builder().setEndpoint(DEFAULT_ENDPOINT).build(); assertFalse(settings.getSwitchToMtlsEndpointAllowed()); - assertEquals(endpoint, settings.getEndpoint()); + assertEquals(DEFAULT_ENDPOINT, settings.getEndpoint()); settings = new FakeStubSettings.Builder() - .setEndpoint(endpoint) + .setEndpoint(DEFAULT_ENDPOINT) .setSwitchToMtlsEndpointAllowed(true) .build(); assertTrue(settings.getSwitchToMtlsEndpointAllowed()); - assertEquals(endpoint, settings.getEndpoint()); + assertEquals(DEFAULT_ENDPOINT, settings.getEndpoint()); // Test setEndpoint sets the switchToMtlsEndpointAllowed value to false. settings = new FakeStubSettings.Builder() .setSwitchToMtlsEndpointAllowed(true) - .setEndpoint(endpoint) + .setEndpoint(DEFAULT_ENDPOINT) .build(); assertFalse(settings.getSwitchToMtlsEndpointAllowed()); - assertEquals(endpoint, settings.getEndpoint()); + assertEquals(DEFAULT_ENDPOINT, settings.getEndpoint()); } @Test @@ -795,7 +834,8 @@ void testExecutorSettings() throws Exception { true, null, null, - DEFAULT_ENDPOINT); + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); ClientSettings.Builder builder = new FakeClientSettings.Builder() @@ -842,7 +882,8 @@ void testExecutorSettings() throws Exception { true, null, null, - DEFAULT_ENDPOINT)); + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT)); context = ClientContext.create(builder.build()); transportChannel = (FakeTransportChannel) context.getTransportChannel(); assertThat(transportChannel.getExecutor()).isSameInstanceAs(executorProvider.getExecutor()); @@ -864,7 +905,13 @@ private GdchCredentials getMockGdchCredentials() throws IOException { private TransportChannelProvider getFakeTransportChannelProvider() { return new FakeTransportProvider( - FakeTransportChannel.create(new FakeChannel()), null, true, null, null, DEFAULT_ENDPOINT); + FakeTransportChannel.create(new FakeChannel()), + null, + true, + null, + null, + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); } // EndpointContext will construct a valid endpoint if nothing is provided @@ -872,7 +919,7 @@ private TransportChannelProvider getFakeTransportChannelProvider() { void testCreateClientContext_withGdchCredentialNoAudienceNoEndpoint() throws IOException { TransportChannelProvider transportChannelProvider = new FakeTransportProvider( - FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null); + FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null, null); Credentials creds = getMockGdchCredentials(); CredentialsProvider provider = FixedCredentialsProvider.create(creds); @@ -899,7 +946,7 @@ void testCreateClientContext_withGdchCredentialNoAudienceEmptyEndpoint_throws() throws IOException { TransportChannelProvider transportChannelProvider = new FakeTransportProvider( - FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null); + FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null, null); Credentials creds = getMockGdchCredentials(); CredentialsProvider provider = FixedCredentialsProvider.create(creds); @@ -922,7 +969,7 @@ void testCreateClientContext_withGdchCredentialWithoutAudienceWithEndpoint_corre throws IOException { TransportChannelProvider transportChannelProvider = new FakeTransportProvider( - FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null); + FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null, null); Credentials creds = getMockGdchCredentials(); // it should correctly create a client context with gdch creds and null audience @@ -1034,7 +1081,7 @@ void testCreateClientContext_withNonGdchCredentialAndAnyAudience_throws() throws void testCreateClientContext_SetEndpointViaClientSettings() throws IOException { TransportChannelProvider transportChannelProvider = new FakeTransportProvider( - FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null); + FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null, null); StubSettings settings = new FakeStubSettings.Builder() .setEndpoint(DEFAULT_ENDPOINT) @@ -1060,7 +1107,8 @@ void testCreateClientContext_SetEndpointViaTransportChannelProvider() throws IOE true, null, null, - transportChannelProviderEndpoint); + transportChannelProviderEndpoint, + DEFAULT_MTLS_ENDPOINT); StubSettings settings = new FakeStubSettings.Builder() .setEndpoint(null) @@ -1088,7 +1136,8 @@ void testCreateClientContext_SetEndpointViaClientSettingsAndTransportChannelProv true, null, null, - transportChannelProviderEndpoint); + transportChannelProviderEndpoint, + DEFAULT_MTLS_ENDPOINT); StubSettings settings = new FakeStubSettings.Builder() .setEndpoint(clientSettingsEndpoint) @@ -1111,7 +1160,7 @@ void testCreateClientContext_SetEndpointViaClientSettingsAndTransportChannelProv void testCreateClientContext_doNotSetUniverseDomain() throws IOException { TransportChannelProvider transportChannelProvider = new FakeTransportProvider( - FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null); + FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null, null); StubSettings settings = new FakeStubSettings.Builder() .setEndpoint(null) @@ -1130,7 +1179,7 @@ void testCreateClientContext_doNotSetUniverseDomain() throws IOException { void testCreateClientContext_setUniverseDomain() throws IOException { TransportChannelProvider transportChannelProvider = new FakeTransportProvider( - FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null); + FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null, null); String universeDomain = "testdomain.com"; StubSettings settings = new FakeStubSettings.Builder().setEndpoint(null).setUniverseDomain(universeDomain).build(); @@ -1163,7 +1212,13 @@ void testSetApiKey_createsApiCredentials() throws IOException { FakeTransportChannel transportChannel = FakeTransportChannel.create(new FakeChannel()); FakeTransportProvider transportProvider = new FakeTransportProvider( - transportChannel, executor, true, ImmutableMap.of(), null, DEFAULT_ENDPOINT); + transportChannel, + executor, + true, + ImmutableMap.of(), + null, + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); builder.setTransportChannelProvider(transportProvider); HeaderProvider headerProvider = Mockito.mock(HeaderProvider.class); Mockito.when(headerProvider.getHeaders()).thenReturn(ImmutableMap.of()); @@ -1184,7 +1239,13 @@ void testSetApiKey_withDefaultCredentials_overridesCredentials() throws IOExcept FakeTransportChannel transportChannel = FakeTransportChannel.create(new FakeChannel()); FakeTransportProvider transportProvider = new FakeTransportProvider( - transportChannel, executor, true, ImmutableMap.of(), null, DEFAULT_ENDPOINT); + transportChannel, + executor, + true, + ImmutableMap.of(), + null, + DEFAULT_ENDPOINT, + DEFAULT_MTLS_ENDPOINT); builder.setTransportChannelProvider(transportProvider); HeaderProvider headerProvider = Mockito.mock(HeaderProvider.class); Mockito.when(headerProvider.getHeaders()).thenReturn(ImmutableMap.of()); @@ -1196,4 +1257,34 @@ void testSetApiKey_withDefaultCredentials_overridesCredentials() throws IOExcept FakeCallContext fakeCallContext = (FakeCallContext) context.getDefaultCallContext(); assertThat(fakeCallContext.getCredentials()).isInstanceOf(ApiKeyCredentials.class); } + + // This test case is added to cover a special case with BigTable. BigTable's EnhancedStubSettings + // wrappers do not directly inherit from the generated StubSettings. The wrappers must directly + // set the endpoint values since they are set in the generated StubSettings. This test case mimics + // the old behavior where BigTable doesn't set an mtlsEndpoint value. + @Test + void test_nullMtlsEndpointIsNotPassedToTransportChannel() throws IOException { + // Set the mtlsEndpoint in the TransportChannelProvider as null. This configures the + // ClientContext to attempt to pass the mtlsEndpoint over. + TransportChannelProvider transportChannelProvider = + new FakeTransportProvider( + FakeTransportChannel.create(new FakeChannel()), null, true, null, null, null, null); + // TransportChannelProvider would try to get the resolved mtlsEndpoint + Truth.assertThat(transportChannelProvider.needsMtlsEndpoint()).isTrue(); + + StubSettings settings = + new FakeStubSettings.Builder() + .setEndpoint(DEFAULT_ENDPOINT) + // Set this to be null so that the resolved mtls endpoint is null + // This resolved value should not be passed to the TransportChannelProvider + .setMtlsEndpoint(null) + .build(); + ClientSettings.Builder clientSettingsBuilder = new FakeClientSettings.Builder(settings); + clientSettingsBuilder.setTransportChannelProvider(transportChannelProvider); + ClientSettings clientSettings = clientSettingsBuilder.build(); + + // This call should not result in an exception being thrown as a null resolved mtlsEndpoint + // is not passed to the TransportChannelProvider + ClientContext.create(clientSettings); + } }