Skip to content

Commit cf024eb

Browse files
fix: support for different types of credentials (#1097)
Co-authored-by: Shubha Rajan <[email protected]>
1 parent 42b773b commit cf024eb

File tree

3 files changed

+93
-32
lines changed

3 files changed

+93
-32
lines changed

core/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,12 @@
200200
<version>1.13.0</version>
201201
</dependency>
202202

203+
<dependency>
204+
<groupId>com.google.oauth-client</groupId>
205+
<artifactId>google-oauth-client</artifactId>
206+
<version>1.34.1</version>
207+
</dependency>
208+
203209
<!-- com.google.cloud.sql.nativeimage.CloudSqlFeature needs the GraalVM
204210
dependencies for compilation. The provided-scope dependencies do not add
205211
additional dependencies to library users. -->

core/src/main/java/com/google/cloud/sql/core/CloudSqlInstance.java

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@
1818

1919
import static com.google.common.base.Preconditions.checkArgument;
2020

21+
import com.google.api.client.auth.oauth2.Credential;
2122
import com.google.api.client.googleapis.json.GoogleJsonResponseException;
23+
import com.google.api.client.http.HttpRequestInitializer;
2224
import com.google.api.services.sqladmin.SQLAdmin;
2325
import com.google.api.services.sqladmin.model.ConnectSettings;
2426
import com.google.api.services.sqladmin.model.GenerateEphemeralCertRequest;
2527
import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse;
2628
import com.google.api.services.sqladmin.model.IpMapping;
2729
import com.google.auth.http.HttpCredentialsAdapter;
30+
import com.google.auth.oauth2.AccessToken;
2831
import com.google.auth.oauth2.GoogleCredentials;
2932
import com.google.auth.oauth2.OAuth2Credentials;
3033
import com.google.cloud.sql.CredentialFactory;
@@ -154,9 +157,9 @@ class CloudSqlInstance {
154157
this.keyPair = keyPair;
155158

156159
if (enableIamAuth) {
157-
HttpCredentialsAdapter credentialsAdapter = (HttpCredentialsAdapter) tokenSourceFactory
158-
.create();
159-
this.credentials = Optional.of((OAuth2Credentials) credentialsAdapter.getCredentials());
160+
HttpRequestInitializer source = tokenSourceFactory.create();
161+
162+
this.credentials = Optional.of(parseCredentials(source));
160163
this.credentials.get().refresh();
161164
} else {
162165
this.credentials = Optional.empty();
@@ -169,6 +172,36 @@ class CloudSqlInstance {
169172
}
170173
}
171174

175+
private OAuth2Credentials parseCredentials(HttpRequestInitializer source) {
176+
if (source instanceof HttpCredentialsAdapter) {
177+
HttpCredentialsAdapter adapter = (HttpCredentialsAdapter) source;
178+
return (OAuth2Credentials) adapter.getCredentials();
179+
}
180+
181+
if (source instanceof Credential) {
182+
Credential credential = (Credential) source;
183+
AccessToken accessToken = new AccessToken(
184+
credential.getAccessToken(),
185+
new Date(credential.getExpirationTimeMilliseconds())
186+
);
187+
GoogleCredentials googleCredentials = new GoogleCredentials(accessToken) {
188+
189+
@Override
190+
public AccessToken refreshAccessToken() throws IOException {
191+
credential.refreshToken();
192+
193+
return new AccessToken(
194+
credential.getAccessToken(),
195+
new Date(credential.getExpirationTimeMilliseconds()));
196+
}
197+
};
198+
199+
return googleCredentials;
200+
}
201+
202+
throw new RuntimeException("Not supporting credentials of type " + source.getClass().getName());
203+
}
204+
172205
/**
173206
* Generates public key certificate for which the instance has the matching private key.
174207
*
@@ -326,7 +359,7 @@ boolean forceRefresh() {
326359
* would expire.
327360
*/
328361
private ListenableFuture<InstanceData> performRefresh() {
329-
// To avoid unreasonable SQL Admin API usage, use a rate limit to throttle our usage.
362+
// To avoid unreasonable SQL Admin API usage, use a rate limit to throttle our usage.
330363
forcedRenewRateLimiter.acquire(1);
331364
// Use the Cloud SQL Admin API to return the Metadata and Certificate
332365
ListenableFuture<Metadata> metadataFuture = executor.submit(this::fetchMetadata);
@@ -492,7 +525,7 @@ private Metadata fetchMetadata() {
492525
+ "instance.",
493526
connectionName));
494527
}
495-
528+
496529
checkDatabaseCompatibility(instanceMetadata, enableIamAuth, connectionName);
497530

498531

core/src/test/java/com/google/cloud/sql/core/CoreSocketFactoryTest.java

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,18 @@
1818

1919
import static com.google.common.truth.Truth.assertThat;
2020
import static java.nio.charset.StandardCharsets.UTF_8;
21+
import static org.junit.Assert.assertThrows;
2122
import static org.junit.Assert.fail;
2223
import static org.mockito.ArgumentMatchers.anyString;
2324
import static org.mockito.ArgumentMatchers.eq;
2425
import static org.mockito.ArgumentMatchers.isA;
25-
import static org.mockito.Mockito.times;
26-
import static org.mockito.Mockito.verify;
27-
import static org.mockito.Mockito.verifyNoMoreInteractions;
28-
import static org.mockito.Mockito.when;
26+
import static org.mockito.Mockito.*;
2927

3028
import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
3129
import com.google.api.client.googleapis.json.GoogleJsonError;
3230
import com.google.api.client.googleapis.json.GoogleJsonError.ErrorInfo;
3331
import com.google.api.client.googleapis.json.GoogleJsonResponseException;
34-
import com.google.api.client.http.HttpRequest;
35-
import com.google.api.client.http.HttpResponse;
36-
import com.google.api.client.http.HttpStatusCodes;
37-
import com.google.api.client.http.HttpTransport;
38-
import com.google.api.client.http.LowLevelHttpRequest;
32+
import com.google.api.client.http.*;
3933
import com.google.api.client.json.GenericJson;
4034
import com.google.api.client.json.Json;
4135
import com.google.api.client.json.JsonFactory;
@@ -122,8 +116,6 @@ public class CoreSocketFactoryTest {
122116
@Mock
123117
private CredentialFactory credentialFactory;
124118
@Mock
125-
private GoogleCredential credential;
126-
@Mock
127119
private SQLAdmin adminApi;
128120
@Mock
129121
private SQLAdmin.Connect adminApiConnect;
@@ -191,6 +183,7 @@ public void setup()
191183
new IpMapping().setIpAddress(PUBLIC_IP).setType("PRIMARY"),
192184
new IpMapping().setIpAddress(PRIVATE_IP).setType("PRIVATE")))
193185
.setServerCaCert(new SslCert().setCert(TestKeys.SERVER_CA_CERT))
186+
.setDatabaseVersion("POSTGRES14")
194187
.setRegion("myRegion"));
195188
when(adminApiConnectGenerateEphemeralCert.execute())
196189
.thenReturn(generateEphemeralCertResponse);
@@ -253,10 +246,7 @@ public void create_successfulPrivateConnection()
253246
.generateEphemeralCert(
254247
eq("myProject"), eq("myRegion~myInstance"), isA(GenerateEphemeralCertRequest.class));
255248

256-
BufferedReader bufferedReader =
257-
new BufferedReader(new InputStreamReader(socket.getInputStream(), UTF_8));
258-
String line = bufferedReader.readLine();
259-
assertThat(line).isEqualTo(SERVER_MESSAGE);
249+
assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
260250
}
261251

262252
@Test
@@ -275,10 +265,7 @@ public void create_successfulConnection() throws IOException, InterruptedExcepti
275265
.generateEphemeralCert(
276266
eq("myProject"), eq("myRegion~myInstance"), isA(GenerateEphemeralCertRequest.class));
277267

278-
BufferedReader bufferedReader =
279-
new BufferedReader(new InputStreamReader(socket.getInputStream(), UTF_8));
280-
String line = bufferedReader.readLine();
281-
assertThat(line).isEqualTo(SERVER_MESSAGE);
268+
assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
282269
}
283270

284271
@Test
@@ -298,10 +285,7 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr
298285
eq("example.com:myProject"), eq("myRegion~myInstance"),
299286
isA(GenerateEphemeralCertRequest.class));
300287

301-
BufferedReader bufferedReader =
302-
new BufferedReader(new InputStreamReader(socket.getInputStream(), UTF_8));
303-
String line = bufferedReader.readLine();
304-
assertThat(line).isEqualTo(SERVER_MESSAGE);
288+
assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
305289
}
306290

307291
@Test
@@ -332,10 +316,7 @@ public void create_expiredCertificateOnFirstConnection_certificateRenewed()
332316
.generateEphemeralCert(
333317
eq("myProject"), eq("myRegion~myInstance"), isA(GenerateEphemeralCertRequest.class));
334318

335-
BufferedReader bufferedReader =
336-
new BufferedReader(new InputStreamReader(socket.getInputStream(), UTF_8));
337-
String line = bufferedReader.readLine();
338-
assertThat(line).isEqualTo(SERVER_MESSAGE);
319+
assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
339320
}
340321

341322
@Test
@@ -396,6 +377,47 @@ public void create_notAuthorized() throws IOException {
396377
}
397378
}
398379

380+
@Test
381+
public void supportsCustomCredentialFactoryWithIAM() throws InterruptedException, IOException {
382+
GoogleCredential customCredential = mock(GoogleCredential.class);
383+
when(credentialFactory.create()).thenReturn(customCredential);
384+
385+
when(customCredential.getAccessToken()).thenReturn("foo");
386+
when(customCredential.getExpirationTimeMilliseconds()).thenReturn(new Date().getTime());
387+
388+
FakeSslServer sslServer = new FakeSslServer();
389+
int port = sslServer.start();
390+
391+
CoreSocketFactory coreSocketFactory =
392+
new CoreSocketFactory(clientKeyPair, adminApi, credentialFactory, port, defaultExecutor);
393+
Socket socket =
394+
coreSocketFactory.createSslSocket(
395+
"myProject:myRegion:myInstance", Arrays.asList("PRIMARY"), true);
396+
397+
assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
398+
}
399+
400+
@Test
401+
public void doesNotSupportNonGoogleCredentialWithIAM() throws InterruptedException, IOException {
402+
BasicAuthentication nonGoogleCredential = mock(BasicAuthentication.class);
403+
when(credentialFactory.create()).thenReturn(nonGoogleCredential);
404+
405+
FakeSslServer sslServer = new FakeSslServer();
406+
int port = sslServer.start();
407+
408+
CoreSocketFactory coreSocketFactory =
409+
new CoreSocketFactory(clientKeyPair, adminApi, credentialFactory, port, defaultExecutor);
410+
assertThrows(RuntimeException.class, () -> {
411+
coreSocketFactory.createSslSocket(
412+
"myProject:myRegion:myInstance", Arrays.asList("PRIMARY"), true);
413+
});
414+
}
415+
416+
private String readLine(Socket socket) throws IOException {
417+
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(socket.getInputStream(), UTF_8));
418+
return bufferedReader.readLine();
419+
}
420+
399421
// Creates a fake "accessNotConfigured" exception that can be used for testing.
400422
private static GoogleJsonResponseException fakeNotConfiguredException() throws IOException {
401423
return fakeGoogleJsonResponseException(

0 commit comments

Comments
 (0)