Skip to content

Commit acb57cb

Browse files
authored
fix: downscope credentials used for IAM AuthN login (#999)
1 parent 0e532a1 commit acb57cb

File tree

6 files changed

+155
-25
lines changed

6 files changed

+155
-25
lines changed

core/src/main/java/com/google/cloud/sql/CredentialFactory.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,16 @@
1818

1919
import com.google.api.client.auth.oauth2.Credential;
2020
import com.google.api.client.http.HttpRequestInitializer;
21+
import java.io.IOException;
2122

22-
/** Factory for creating {@link Credential}s for interaction with Cloud SQL Admin API. */
23+
/**
24+
* Factory for creating {@link Credential}s for interaction with Cloud SQL Admin API.
25+
*/
2326
public interface CredentialFactory {
24-
/** Name of system property that can specify an alternative credential factory. */
27+
28+
/**
29+
* Name of system property that can specify an alternative credential factory.
30+
*/
2531
String CREDENTIAL_FACTORY_PROPERTY = "cloudSql.socketFactory.credentialFactory";
2632

2733
HttpRequestInitializer create();

core/src/main/java/com/google/cloud/sql/core/CloudSqlInstance.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse;
2626
import com.google.api.services.sqladmin.model.IpMapping;
2727
import com.google.auth.http.HttpCredentialsAdapter;
28+
import com.google.auth.oauth2.GoogleCredentials;
2829
import com.google.auth.oauth2.OAuth2Credentials;
2930
import com.google.cloud.sql.CredentialFactory;
3031
import com.google.common.base.CharMatcher;
@@ -81,6 +82,7 @@
8182
*/
8283
class CloudSqlInstance {
8384

85+
private static final String SQL_LOGIN_SCOPE = "https://www.googleapis.com/auth/sqlservice.login";
8486
private static final Logger logger = Logger.getLogger(CloudSqlInstance.class.getName());
8587

8688
// Unique identifier for each Cloud SQL instance in the format "PROJECT:REGION:INSTANCE"
@@ -131,7 +133,7 @@ class CloudSqlInstance {
131133
boolean enableIamAuth,
132134
CredentialFactory tokenSourceFactory,
133135
ListeningScheduledExecutorService executor,
134-
ListenableFuture<KeyPair> keyPair) {
136+
ListenableFuture<KeyPair> keyPair) throws IOException {
135137

136138
Matcher matcher = CONNECTION_NAME.matcher(connectionName);
137139
checkArgument(
@@ -155,6 +157,7 @@ class CloudSqlInstance {
155157
HttpCredentialsAdapter credentialsAdapter = (HttpCredentialsAdapter) tokenSourceFactory
156158
.create();
157159
this.credentials = Optional.of((OAuth2Credentials) credentialsAdapter.getCredentials());
160+
this.credentials.get().refresh();
158161
} else {
159162
this.credentials = Optional.empty();
160163
}
@@ -277,10 +280,10 @@ SSLSocket createSslSocket() throws IOException {
277280
* preferredTypes.
278281
*
279282
* @param preferredTypes Preferred instance IP types to use. Valid IP types include "Public" and
280-
* "Private".
283+
* "Private".
281284
* @return returns a string representing the IP address for the instance
282285
* @throws IllegalArgumentException If the instance has no IP addresses matching the provided
283-
* preferences.
286+
* preferences.
284287
*/
285288
String getPreferredIp(List<String> preferredTypes) {
286289
Map<String, String> ipAddrs = getInstanceData().getIpAddrs();
@@ -525,8 +528,9 @@ private Certificate fetchEphemeralCertificate(KeyPair keyPair) {
525528

526529
if (enableIamAuth) {
527530
try {
528-
credentials.get().refresh();
529-
String token = credentials.get().getAccessToken().getTokenValue();
531+
GoogleCredentials downscoped = getDownscopedCredentials(credentials.get());
532+
downscoped.refresh();
533+
String token = downscoped.getAccessToken().getTokenValue();
530534
// TODO: remove this once issue with OAuth2 Tokens is resolved.
531535
// See: https://github.com/GoogleCloudPlatform/cloud-sql-jdbc-socket-factory/issues/565
532536
request.setAccessToken(CharMatcher.is('.').trimTrailingFrom(token));
@@ -563,6 +567,19 @@ private Certificate fetchEphemeralCertificate(KeyPair keyPair) {
563567
return ephemeralCertificate;
564568
}
565569

570+
static GoogleCredentials getDownscopedCredentials(OAuth2Credentials credentials) {
571+
GoogleCredentials downscoped;
572+
try {
573+
GoogleCredentials oldCredentials = (GoogleCredentials) credentials;
574+
downscoped = oldCredentials.createScoped(SQL_LOGIN_SCOPE);
575+
} catch (ClassCastException ex) {
576+
throw new RuntimeException(
577+
"Failed to downscope credentials for IAM Authentication:",
578+
ex);
579+
}
580+
return downscoped;
581+
}
582+
566583
private Date getTokenExpirationTime() {
567584
return credentials.get().getAccessToken().getExpirationTime();
568585
}
@@ -590,7 +607,7 @@ private long secondsUntilRefresh() {
590607
*
591608
* @param ex exception thrown by the Admin API request
592609
* @param fallbackDesc generic description used as a fallback if no additional information can be
593-
* provided to the user
610+
* provided to the user
594611
*/
595612
private RuntimeException addExceptionContext(IOException ex, String fallbackDesc) {
596613
// Verify we are able to extract a reason from an exception, or fallback to a generic desc

core/src/main/java/com/google/cloud/sql/core/CoreSocketFactory.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public final class CoreSocketFactory {
7070
* Property used to set the application name for the underlying SQLAdmin client.
7171
*
7272
* @deprecated Use {@link #setApplicationName(String)} to set the application name
73-
* programmatically.
73+
* programmatically.
7474
*/
7575

7676
@Deprecated
@@ -117,7 +117,7 @@ public final class CoreSocketFactory {
117117
/**
118118
* Returns the {@link CoreSocketFactory} singleton.
119119
*/
120-
public static synchronized CoreSocketFactory getInstance() {
120+
public static synchronized CoreSocketFactory getInstance() throws IOException {
121121
if (coreSocketFactory == null) {
122122
logger.info("First Cloud SQL connection, generating RSA key pair.");
123123

@@ -155,14 +155,27 @@ public static synchronized CoreSocketFactory getInstance() {
155155
private CloudSqlInstance getCloudSqlInstance(String instanceName, boolean enableIamAuth) {
156156
return instances.computeIfAbsent(
157157
instanceName,
158-
k -> new CloudSqlInstance(k, adminApi, enableIamAuth, credentialFactory, executor,
159-
localKeyPair));
158+
k -> {
159+
try {
160+
return new CloudSqlInstance(k, adminApi, enableIamAuth, credentialFactory, executor,
161+
localKeyPair);
162+
} catch (IOException e) {
163+
throw new RuntimeException(e);
164+
}
165+
});
160166
}
161167

162168
private CloudSqlInstance getCloudSqlInstance(String instanceName) {
163169
return instances.computeIfAbsent(
164170
instanceName,
165-
k -> new CloudSqlInstance(k, adminApi, false, credentialFactory, executor, localKeyPair));
171+
k -> {
172+
try {
173+
return new CloudSqlInstance(k, adminApi, false, credentialFactory, executor,
174+
localKeyPair);
175+
} catch (IOException e) {
176+
throw new RuntimeException(e);
177+
}
178+
});
166179
}
167180

168181
static int getDefaultServerProxyPort() {
@@ -215,7 +228,7 @@ public static Socket connect(Properties props) throws IOException {
215228
*
216229
* <p>Depending on the given properties, it may return either a SSL Socket or a Unix Socket.
217230
*
218-
* @param props Properties used to configure the connection.
231+
* @param props Properties used to configure the connection.
219232
* @param unixPathSuffix suffix to add the the Unix socket path. Unused if null.
220233
* @return the newly created Socket.
221234
* @throws IOException if error occurs during socket creation.
@@ -255,18 +268,19 @@ public static Socket connect(Properties props, String unixPathSuffix) throws IOE
255268
/**
256269
* Returns data that can be used to establish Cloud SQL SSL connection.
257270
*/
258-
public static SslData getSslData(String csqlInstanceName, boolean enableIamAuth) {
271+
public static SslData getSslData(String csqlInstanceName, boolean enableIamAuth)
272+
throws IOException {
259273
return getInstance().getCloudSqlInstance(csqlInstanceName, enableIamAuth).getSslData();
260274
}
261275

262-
public static SslData getSslData(String csqlInstanceName) {
276+
public static SslData getSslData(String csqlInstanceName) throws IOException {
263277
return getSslData(csqlInstanceName, false);
264278
}
265279

266280
/**
267281
* Returns preferred ip address that can be used to establish Cloud SQL connection.
268282
*/
269-
public static String getHostIp(String csqlInstanceName) {
283+
public static String getHostIp(String csqlInstanceName) throws IOException {
270284
return getInstance().getHostIp(csqlInstanceName, listIpTypes(DEFAULT_IP_TYPES));
271285
}
272286

@@ -280,7 +294,7 @@ private String getHostIp(String instanceName, List<String> ipTypes) {
280294
* Creates a secure socket representing a connection to a Cloud SQL instance.
281295
*
282296
* @param instanceName Name of the Cloud SQL instance.
283-
* @param ipTypes Preferred type of IP to use ("PRIVATE", "PUBLIC")
297+
* @param ipTypes Preferred type of IP to use ("PRIVATE", "PUBLIC")
284298
* @return the newly created Socket.
285299
* @throws IOException if error occurs during socket creation.
286300
*/
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright 2022 Google LLC. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.cloud.sql.core;
17+
18+
import static com.google.common.truth.Truth.assertThat;
19+
import static org.mockito.Mockito.times;
20+
import static org.mockito.Mockito.verify;
21+
import static org.mockito.Mockito.when;
22+
23+
import com.google.auth.oauth2.GoogleCredentials;
24+
import com.google.auth.oauth2.OAuth2Credentials;
25+
import java.io.IOException;
26+
import org.junit.Before;
27+
import org.junit.Test;
28+
import org.junit.runner.RunWith;
29+
import org.junit.runners.JUnit4;
30+
import org.mockito.Mock;
31+
import org.mockito.MockitoAnnotations;
32+
33+
@RunWith(JUnit4.class)
34+
public class CloudSqlInstanceTest {
35+
36+
@Mock
37+
private GoogleCredentials googleCredentials;
38+
39+
@Mock
40+
private GoogleCredentials scopedCredentials;
41+
42+
@Mock
43+
private OAuth2Credentials oAuth2Credentials;
44+
45+
@Before
46+
public void setup() throws IOException {
47+
MockitoAnnotations.openMocks(this);
48+
when(googleCredentials.createScoped(
49+
"https://www.googleapis.com/auth/sqlservice.login")).thenReturn(scopedCredentials);
50+
}
51+
52+
@Test
53+
public void downscopesGoogleCredentials() {
54+
GoogleCredentials downscoped = CloudSqlInstance.getDownscopedCredentials(googleCredentials);
55+
assertThat(downscoped).isEqualTo(scopedCredentials);
56+
verify(googleCredentials, times(1)).createScoped(
57+
"https://www.googleapis.com/auth/sqlservice.login");
58+
}
59+
60+
61+
@Test
62+
public void throwsErrorForWrongCredentialType() {
63+
try {
64+
CloudSqlInstance.getDownscopedCredentials(oAuth2Credentials);
65+
} catch (RuntimeException ex) {
66+
assertThat(ex)
67+
.hasMessageThat()
68+
.contains("Failed to downscope credentials for IAM Authentication");
69+
}
70+
}
71+
72+
73+
}

r2dbc/core/src/main/java/com/google/cloud/sql/core/CloudSqlConnectionFactory.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@
2424
import io.r2dbc.spi.ConnectionFactoryMetadata;
2525
import io.r2dbc.spi.ConnectionFactoryOptions;
2626
import io.r2dbc.spi.ConnectionFactoryOptions.Builder;
27+
import java.io.IOException;
2728
import java.util.function.Function;
2829
import org.reactivestreams.Publisher;
2930

3031
/**
31-
* * {@link ConnectionFactory} for accessing Cloud SQL instances via R2DBC protocol.
32+
* * {@link ConnectionFactory} for accessing Cloud SQL instances via R2DBC protocol.
3233
*/
3334
public class CloudSqlConnectionFactory implements ConnectionFactory {
3435

@@ -50,15 +51,23 @@ public CloudSqlConnectionFactory(
5051

5152
@Override
5253
public Publisher<? extends Connection> create() {
53-
return getConnectionFactory().create();
54+
try {
55+
return getConnectionFactory().create();
56+
} catch (IOException e) {
57+
throw new RuntimeException(e);
58+
}
5459
}
5560

5661
@Override
5762
public ConnectionFactoryMetadata getMetadata() {
58-
return getConnectionFactory().getMetadata();
63+
try {
64+
return getConnectionFactory().getMetadata();
65+
} catch (IOException e) {
66+
throw new RuntimeException(e);
67+
}
5968
}
6069

61-
private ConnectionFactory getConnectionFactory() {
70+
private ConnectionFactory getConnectionFactory() throws IOException {
6271
String hostIp = CoreSocketFactory.getHostIp(csqlHostName);
6372
builder.option(HOST, hostIp).option(PORT, CoreSocketFactory.getDefaultServerProxyPort());
6473
return connectionFactoryFactory.apply(builder.build());

r2dbc/core/src/main/java/com/google/cloud/sql/core/GcpConnectionFactoryProvider.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import io.r2dbc.spi.ConnectionFactoryOptions;
2727
import io.r2dbc.spi.ConnectionFactoryProvider;
2828
import io.r2dbc.spi.Option;
29+
import java.io.IOException;
2930
import java.util.function.Function;
3031
import reactor.core.publisher.Mono;
3132
import reactor.core.scheduler.Schedulers;
@@ -45,7 +46,13 @@ private static Function<SslContextBuilder, SslContextBuilder> createSslCustomize
4546
sslContextBuilder -> {
4647
// Execute in a default scheduler to prevent it from blocking event loop
4748
SslData sslData = Mono
48-
.fromSupplier(() -> CoreSocketFactory.getSslData(connectionName, enableIamAuth))
49+
.fromSupplier(() -> {
50+
try {
51+
return CoreSocketFactory.getSslData(connectionName, enableIamAuth);
52+
} catch (IOException e) {
53+
throw new RuntimeException(e);
54+
}
55+
})
4956
.subscribeOn(Schedulers.boundedElastic())
5057
.share()
5158
.block();
@@ -93,11 +100,15 @@ public ConnectionFactory create(ConnectionFactoryOptions connectionFactoryOption
93100
"Cannot create ConnectionFactory: unsupported protocol (" + protocol + ")");
94101
}
95102

96-
return createFactory(connectionFactoryOptions);
103+
try {
104+
return createFactory(connectionFactoryOptions);
105+
} catch (IOException e) {
106+
throw new RuntimeException(e);
107+
}
97108
}
98109

99110
private ConnectionFactory createFactory(
100-
ConnectionFactoryOptions connectionFactoryOptions) {
111+
ConnectionFactoryOptions connectionFactoryOptions) throws IOException {
101112
String connectionName = (String) connectionFactoryOptions.getRequiredValue(HOST);
102113
String socket = (String) connectionFactoryOptions.getValue(UNIX_SOCKET);
103114

0 commit comments

Comments
 (0)