Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.auth.mtls;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;

// TODO: Remove once the actual implementation of this class in the auth library has a config
// option that allows us to skip the most expensive code paths during tests.
public class DefaultMtlsProviderFactory {
public static final AtomicBoolean SKIP_MTLS = new AtomicBoolean(false);

public static MtlsProvider create() throws IOException {
if (SKIP_MTLS.get()) {
return null;
}
// Note: The caller should handle CertificateSourceUnavailableException gracefully, since
// it is an expected error case. All other IOExceptions are unexpected and should be surfaced
// up the call stack.
MtlsProvider mtlsProvider = new X509Provider();
if (mtlsProvider.isAvailable()) {
return mtlsProvider;
}
mtlsProvider = new SecureConnectProvider();
if (mtlsProvider.isAvailable()) {
return mtlsProvider;
}
throw new CertificateSourceUnavailableException(
"No Certificate Source is available on this device.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_COUNT;
import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_STATEMENT;

import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
import io.grpc.ManagedChannelBuilder;
Expand All @@ -54,8 +55,12 @@ public abstract class AbstractAsyncTransactionTest {
Spanner spanner;
Spanner spannerWithEmptySessionPool;

private static boolean originalSkipMtls;

@BeforeClass
public static void setup() throws Exception {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D);
mockSpanner.putStatementResult(
Expand Down Expand Up @@ -85,6 +90,7 @@ public static void teardown() throws Exception {
server.shutdown();
server.awaitTermination();
executor.shutdown();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.cloud.spanner;

import com.google.api.gax.grpc.testing.LocalChannelProvider;
import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.admin.database.v1.MockDatabaseAdminImpl;
import com.google.cloud.spanner.admin.instance.v1.MockInstanceAdminImpl;
Expand Down Expand Up @@ -48,8 +49,12 @@ abstract class AbstractMockServerTest {

protected Spanner spanner;

private static boolean originalSkipMtls;

@BeforeClass
public static void startMockServer() throws IOException {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockInstanceAdmin = new MockInstanceAdminImpl();
Expand Down Expand Up @@ -93,6 +98,7 @@ public void getOperation(
public static void stopMockServer() throws InterruptedException {
server.shutdown();
server.awaitTermination();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.cloud.spanner;

import com.google.api.gax.grpc.testing.LocalChannelProvider;
import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import io.grpc.ForwardingServerCall;
import io.grpc.ManagedChannelBuilder;
Expand Down Expand Up @@ -51,8 +52,12 @@ abstract class AbstractNettyMockServerTest {

protected Spanner spanner;

private static boolean originalSkipMtls;

@BeforeClass
public static void startMockServer() throws IOException {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.

Expand Down Expand Up @@ -93,6 +98,7 @@ public static void stopMockServer() throws InterruptedException {
server.shutdown();
server.awaitTermination();
executor.shutdown();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.truth.Truth.assertThat;

import com.google.api.gax.grpc.testing.LocalChannelProvider;
import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.grpc.GrpcTransportOptions;
import com.google.cloud.grpc.GrpcTransportOptions.ExecutorFactory;
Expand Down Expand Up @@ -88,8 +89,12 @@ public class BackendExhaustedTest {
private Spanner spanner;
private DatabaseClientImpl client;

private static boolean originalSkipMtls;

@BeforeClass
public static void startStaticServer() throws IOException {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT));
Expand All @@ -115,6 +120,7 @@ public static void stopServer() throws InterruptedException {
// Force a shutdown as there are still requests stuck in the server.
server.shutdownNow();
server.awaitTermination();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.junit.Assume.assumeFalse;

import com.google.api.gax.grpc.testing.LocalChannelProvider;
import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
Expand Down Expand Up @@ -58,8 +59,12 @@ public class BatchCreateSessionsSlowTest {
private static LocalChannelProvider channelProvider;
private Spanner spanner;

private static boolean originalSkipMtls;

@BeforeClass
public static void startStaticServer() throws IOException {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(
Expand All @@ -82,6 +87,7 @@ public static void startStaticServer() throws IOException {
public static void stopServer() throws InterruptedException {
server.shutdown();
server.awaitTermination();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.hamcrest.MatcherAssert.assertThat;

import com.google.api.gax.grpc.testing.LocalChannelProvider;
import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
Expand Down Expand Up @@ -77,8 +78,12 @@ public class BatchCreateSessionsTest {
private static Server server;
private static LocalChannelProvider channelProvider;

private static boolean originalSkipMtls;

@BeforeClass
public static void startStaticServer() throws IOException {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.query(SELECT1AND2, SELECT1_RESULTSET));
Expand All @@ -97,6 +102,7 @@ public static void startStaticServer() throws IOException {
public static void stopServer() throws InterruptedException {
server.shutdown();
server.awaitTermination();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeFalse;

import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
import com.google.common.util.concurrent.ListeningExecutorService;
Expand Down Expand Up @@ -112,8 +113,12 @@ public static Collection<Object[]> data() {

private static Level originalLogLevel;

private static boolean originalSkipMtls;

@BeforeClass
public static void startServer() throws IOException {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.query(SELECT1, SELECT1_RESULTSET));
Expand Down Expand Up @@ -166,6 +171,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
public static void stopServer() throws InterruptedException {
server.shutdown();
server.awaitTermination();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@BeforeClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.StatusCode;
import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.ByteArray;
import com.google.cloud.NoCredentials;
import com.google.cloud.Timestamp;
Expand Down Expand Up @@ -208,8 +209,12 @@ public class DatabaseClientImplTest {
private Spanner spannerWithEmptySessionPool;
private static ExecutorService executor;

private static boolean originalSkipMtls;

@BeforeClass
public static void startStaticServer() throws IOException {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT));
Expand Down Expand Up @@ -262,6 +267,7 @@ public static void stopServer() throws InterruptedException {
server.shutdown();
server.awaitTermination();
executor.shutdown();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static org.junit.Assert.assertTrue;

import com.google.api.gax.grpc.testing.LocalChannelProvider;
import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
import com.google.cloud.spanner.Options.RpcPriority;
Expand Down Expand Up @@ -82,8 +83,12 @@ public class DatabaseClientImplWithDefaultRWTransactionOptionsTest {
private DatabaseClient clientWithSerializableOption;
private DatabaseClient clientWithSerOptimisticOption;

private static boolean originalSkipMtls;

@BeforeClass
public static void startStaticServer() throws IOException {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT));
Expand Down Expand Up @@ -112,6 +117,7 @@ public static void startStaticServer() throws IOException {
public static void stopServer() throws InterruptedException {
server.shutdown();
server.awaitTermination();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.api.core.ApiFutures;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.grpc.testing.LocalChannelProvider;
import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.AsyncResultSet.CallbackResponse;
import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionStep;
Expand Down Expand Up @@ -139,8 +140,12 @@ public class InlineBeginTransactionTest {

protected Spanner spanner;

private static boolean originalSkipMtls;

@BeforeClass
public static void startStaticServer() throws IOException {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT));
Expand Down Expand Up @@ -177,6 +182,7 @@ public static void startStaticServer() throws IOException {
public static void stopServer() throws InterruptedException {
server.shutdown();
server.awaitTermination();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import com.google.api.gax.core.GaxProperties;
import com.google.api.gax.grpc.testing.LocalChannelProvider;
import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
import com.google.common.base.Stopwatch;
Expand Down Expand Up @@ -166,6 +167,8 @@ public class OpenTelemetrySpanTest {

private int expectedReadWriteTransactionErrorWithBeginTransactionEventsCount = 11;

private static boolean originalSkipMtls;

@BeforeClass
public static void setupOpenTelemetry() {
SpannerOptions.resetActiveTracingFramework();
Expand All @@ -192,6 +195,9 @@ public static void startStaticServer() throws Exception {
modifiersField.setInt(field, field.getModifiers() & ~Modifier.FINAL);
field.set(null, failOnOverkillTraceComponent);

originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);

mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.query(SELECT1, SELECT1_RESULTSET));
Expand All @@ -217,6 +223,7 @@ public static void stopServer() throws InterruptedException {
server.shutdown();
server.awaitTermination();
}
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import com.google.auth.mtls.DefaultMtlsProviderFactory;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -81,8 +82,12 @@ public class PgNumericTest {
private Spanner spanner;
private DatabaseClient databaseClient;

private static boolean originalSkipMtls;

@BeforeClass
public static void beforeClass() throws Exception {
originalSkipMtls = DefaultMtlsProviderFactory.SKIP_MTLS.get();
DefaultMtlsProviderFactory.SKIP_MTLS.set(true);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D);

Expand All @@ -94,6 +99,7 @@ public static void beforeClass() throws Exception {
public static void afterClass() throws Exception {
server.shutdown();
server.awaitTermination();
DefaultMtlsProviderFactory.SKIP_MTLS.set(originalSkipMtls);
}

@Before
Expand Down
Loading
Loading