4646import com .google .auth .ApiKeyCredentials ;
4747import com .google .auth .Credentials ;
4848import com .google .auth .oauth2 .ComputeEngineCredentials ;
49+ import com .google .auth .oauth2 .S2A ;
4950import com .google .common .annotations .VisibleForTesting ;
5051import com .google .common .base .Preconditions ;
5152import com .google .common .collect .ImmutableList ;
5455import io .grpc .CallCredentials ;
5556import io .grpc .ChannelCredentials ;
5657import io .grpc .Grpc ;
58+ import io .grpc .InsecureChannelCredentials ;
5759import io .grpc .ManagedChannel ;
5860import io .grpc .ManagedChannelBuilder ;
5961import io .grpc .TlsChannelCredentials ;
6062import io .grpc .alts .GoogleDefaultChannelCredentials ;
6163import io .grpc .auth .MoreCallCredentials ;
64+ import io .grpc .s2a .S2AChannelCredentials ;
6265import java .io .File ;
6366import java .io .IOException ;
6467import java .nio .charset .StandardCharsets ;
@@ -99,6 +102,12 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
99102 @ VisibleForTesting
100103 static final String DIRECT_PATH_ENV_ENABLE_XDS = "GOOGLE_CLOUD_ENABLE_DIRECT_PATH_XDS" ;
101104
105+ private static final String S2A_ENV_ENABLE_USE_S2A = "EXPERIMENTAL_GOOGLE_API_USE_S2A" ;
106+ private static final String MTLS_MDS_ROOT = "/run/google-mds-mtls/root.crt" ;
107+ // The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
108+ // followed by a PEM-encoded private key.
109+ private static final String MTLS_MDS_CERT_CHAIN_AND_KEY = "/run/google-mds-mtls/client.key" ;
110+
102111 static final long DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS = 3600 ;
103112 static final long DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS = 20 ;
104113 static final String GCE_PRODUCTION_NAME_PRIOR_2016 = "Google" ;
@@ -108,6 +117,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
108117 private final Executor executor ;
109118 private final HeaderProvider headerProvider ;
110119 private final String endpoint ;
120+ private final String mtlsEndpoint ;
111121 // TODO: remove. envProvider currently provides DirectPath environment variable, and is only used
112122 // during initial rollout for DirectPath. This provider will be removed once the DirectPath
113123 // environment is not used.
@@ -136,6 +146,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
136146 this .executor = builder .executor ;
137147 this .headerProvider = builder .headerProvider ;
138148 this .endpoint = builder .endpoint ;
149+ this .mtlsEndpoint = builder .mtlsEndpoint ;
139150 this .mtlsProvider = builder .mtlsProvider ;
140151 this .envProvider = builder .envProvider ;
141152 this .interceptorProvider = builder .interceptorProvider ;
@@ -211,6 +222,10 @@ public boolean needsEndpoint() {
211222 return endpoint == null ;
212223 }
213224
225+ public boolean needsMtlsEndpoint () {
226+ return mtlsEndpoint == null ;
227+ }
228+
214229 /**
215230 * Specify the endpoint the channel should connect to.
216231 *
@@ -225,6 +240,20 @@ public TransportChannelProvider withEndpoint(String endpoint) {
225240 return toBuilder ().setEndpoint (endpoint ).build ();
226241 }
227242
243+ /**
244+ * Specify the MTLS endpoint.
245+ *
246+ * <p>The value of {@code mtlsEndpoint} must be of the form {@code host:port}.
247+ *
248+ * @param mtlsEndpoint
249+ * @return A new {@link InstantiatingGrpcChannelProvider} with the specified MTLS endpoint
250+ * configured
251+ */
252+ public TransportChannelProvider withMtlsEndpoint (String mtlsEndpoint ) {
253+ validateEndpoint (mtlsEndpoint );
254+ return toBuilder ().setMtlsEndpoint (mtlsEndpoint ).build ();
255+ }
256+
228257 /** @deprecated Please modify pool settings via {@link #toBuilder()} */
229258 @ Deprecated
230259 @ Override
@@ -410,6 +439,83 @@ ChannelCredentials createMtlsChannelCredentials() throws IOException, GeneralSec
410439 return null ;
411440 }
412441
442+ @ VisibleForTesting
443+ boolean isGoogleS2AEnabled () {
444+ String S2AEnv = envProvider .getenv (S2A_ENV_ENABLE_USE_S2A );
445+ boolean isS2AEnv = Boolean .parseBoolean (S2AEnv );
446+ if (isS2AEnv ) {
447+ return true ;
448+ }
449+ return false ;
450+ }
451+
452+ @ VisibleForTesting
453+ boolean shouldUseS2A () {
454+ // If EXPERIMENTAL_GOOGLE_API_USE_S2A is not set to true, skip S2A.
455+ if (!isGoogleS2AEnabled ()) {
456+ return false ;
457+ }
458+
459+ // If {@link mtlsEndpoint} is not set, skip S2A. S2A is also skipped when there is endpoint
460+ // override. Endpoint override is respected when the {@link endpoint} is resolved via AIP#4114,
461+ // see EndpointContext.java
462+ if (endpoint != mtlsEndpoint ) {
463+ return false ;
464+ }
465+
466+ // mTLS via S2A is not supported in any universe other than googleapis.com.
467+ if (!endpoint .contains (Credentials .GOOGLE_DEFAULT_UNIVERSE )) {
468+ return false ;
469+ }
470+
471+ return true ;
472+ }
473+
474+ @ VisibleForTesting
475+ ChannelCredentials createMtlsToS2AChannelCredentials () throws IOException {
476+ if (!isOnComputeEngine ()) {
477+ // Currently, MTLS to MDS is only available on GCE. See:
478+ // https://cloud.google.com/compute/docs/metadata/overview#https-mds
479+ return null ;
480+ }
481+ File privateKeyFile = new File (MTLS_MDS_CERT_CHAIN_AND_KEY );
482+ File certChainFile = new File (MTLS_MDS_CERT_CHAIN_AND_KEY );
483+ File trustBundleFile = new File (MTLS_MDS_ROOT );
484+ if (!privateKeyFile .isFile () || !certChainFile .isFile () || !trustBundleFile .isFile ()) {
485+ return null ;
486+ }
487+ return TlsChannelCredentials .newBuilder ()
488+ .keyManager (privateKeyFile , certChainFile )
489+ .trustManager (trustBundleFile )
490+ .build ();
491+ }
492+
493+ @ VisibleForTesting
494+ ChannelCredentials createS2ASecuredChannelCredentials () {
495+ S2A s2aUtils = S2A .newBuilder ().build ();
496+ String plaintextAddress = s2aUtils .getPlaintextS2AAddress ();
497+ String mtlsAddress = s2aUtils .getMtlsS2AAddress ();
498+ if (!mtlsAddress .isEmpty ()) {
499+ try {
500+ // Try to connect to S2A using mTLS.
501+ ChannelCredentials mtlsToS2AChannelCredentials = createMtlsToS2AChannelCredentials ();
502+ if (mtlsToS2AChannelCredentials != null ) {
503+ return S2AChannelCredentials .newBuilder (mtlsAddress , mtlsToS2AChannelCredentials ).build ();
504+ }
505+ } catch (IOException ignore ) {
506+ // Fallback to plaintext connection to S2A.
507+ }
508+ }
509+
510+ if (!plaintextAddress .isEmpty ()) {
511+ // Fallback to plaintext connection to S2A.
512+ return S2AChannelCredentials .newBuilder (plaintextAddress , InsecureChannelCredentials .create ())
513+ .build ();
514+ }
515+
516+ return null ;
517+ }
518+
413519 private ManagedChannel createSingleChannel () throws IOException {
414520 GrpcHeaderInterceptor headerInterceptor =
415521 new GrpcHeaderInterceptor (headersWithDuplicatesRemoved );
@@ -447,16 +553,30 @@ private ManagedChannel createSingleChannel() throws IOException {
447553 builder .keepAliveTime (DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS , TimeUnit .SECONDS );
448554 builder .keepAliveTimeout (DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS , TimeUnit .SECONDS );
449555 } else {
556+ // Try and create credentials via DCA. See https://google.aip.dev/auth/4114.
450557 ChannelCredentials channelCredentials ;
451558 try {
452559 channelCredentials = createMtlsChannelCredentials ();
453560 } catch (GeneralSecurityException e ) {
454561 throw new IOException (e );
455562 }
456563 if (channelCredentials != null ) {
564+ // Create the channel using channel credentials created via DCA.
457565 builder = Grpc .newChannelBuilder (endpoint , channelCredentials );
458566 } else {
459- builder = ManagedChannelBuilder .forAddress (serviceAddress , port );
567+ // Could not create channel credentials via DCA. In accordance with
568+ // https://google.aip.dev/auth/4115, if credentials not available through
569+ // DCA, try mTLS with credentials held by the S2A (Secure Session Agent).
570+ if (shouldUseS2A ()) {
571+ channelCredentials = createS2ASecuredChannelCredentials ();
572+ }
573+ if (channelCredentials != null ) {
574+ // Create the channel using S2A-secured channel credentials.
575+ builder = Grpc .newChannelBuilder (mtlsEndpoint , channelCredentials );
576+ } else {
577+ // Use default if we cannot initialize channel credentials via DCA or S2A.
578+ builder = ManagedChannelBuilder .forAddress (serviceAddress , port );
579+ }
460580 }
461581 }
462582 // google-c2p resolver requires service config lookup
@@ -547,6 +667,11 @@ public String getEndpoint() {
547667 return endpoint ;
548668 }
549669
670+ /** The mTLS endpoint. */
671+ public String getMtlsEndpoint () {
672+ return mtlsEndpoint ;
673+ }
674+
550675 /** This method is obsolete. Use {@link #getKeepAliveTimeDuration()} instead. */
551676 @ ObsoleteApi ("Use getKeepAliveTimeDuration() instead" )
552677 public org .threeten .bp .Duration getKeepAliveTime () {
@@ -604,6 +729,7 @@ public static final class Builder {
604729 private Executor executor ;
605730 private HeaderProvider headerProvider ;
606731 private String endpoint ;
732+ private String mtlsEndpoint ;
607733 private EnvironmentProvider envProvider ;
608734 private MtlsProvider mtlsProvider = new MtlsProvider ();
609735 @ Nullable private GrpcInterceptorProvider interceptorProvider ;
@@ -632,6 +758,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
632758 this .executor = provider .executor ;
633759 this .headerProvider = provider .headerProvider ;
634760 this .endpoint = provider .endpoint ;
761+ this .mtlsEndpoint = provider .mtlsEndpoint ;
635762 this .envProvider = provider .envProvider ;
636763 this .interceptorProvider = provider .interceptorProvider ;
637764 this .maxInboundMessageSize = provider .maxInboundMessageSize ;
@@ -700,6 +827,12 @@ public Builder setEndpoint(String endpoint) {
700827 return this ;
701828 }
702829
830+ public Builder setMtlsEndpoint (String mtlsEndpoint ) {
831+ validateEndpoint (mtlsEndpoint );
832+ this .mtlsEndpoint = mtlsEndpoint ;
833+ return this ;
834+ }
835+
703836 @ VisibleForTesting
704837 Builder setMtlsProvider (MtlsProvider mtlsProvider ) {
705838 this .mtlsProvider = mtlsProvider ;
@@ -722,6 +855,10 @@ public String getEndpoint() {
722855 return endpoint ;
723856 }
724857
858+ public String getMtlsEndpoint () {
859+ return mtlsEndpoint ;
860+ }
861+
725862 /** The maximum message size allowed to be received on the channel. */
726863 public Builder setMaxInboundMessageSize (Integer max ) {
727864 this .maxInboundMessageSize = max ;
0 commit comments