1818
1919import static com .google .common .truth .Truth .assertThat ;
2020import static java .nio .charset .StandardCharsets .UTF_8 ;
21+ import static org .junit .Assert .assertThrows ;
2122import static org .junit .Assert .fail ;
2223import static org .mockito .ArgumentMatchers .anyString ;
2324import static org .mockito .ArgumentMatchers .eq ;
2425import static org .mockito .ArgumentMatchers .isA ;
25- import static org .mockito .Mockito .times ;
26- import static org .mockito .Mockito .verify ;
27- import static org .mockito .Mockito .verifyNoMoreInteractions ;
28- import static org .mockito .Mockito .when ;
26+ import static org .mockito .Mockito .*;
2927
3028import com .google .api .client .googleapis .auth .oauth2 .GoogleCredential ;
3129import com .google .api .client .googleapis .json .GoogleJsonError ;
3230import com .google .api .client .googleapis .json .GoogleJsonError .ErrorInfo ;
3331import com .google .api .client .googleapis .json .GoogleJsonResponseException ;
34- import com .google .api .client .http .HttpRequest ;
35- import com .google .api .client .http .HttpResponse ;
36- import com .google .api .client .http .HttpStatusCodes ;
37- import com .google .api .client .http .HttpTransport ;
38- import com .google .api .client .http .LowLevelHttpRequest ;
32+ import com .google .api .client .http .*;
3933import com .google .api .client .json .GenericJson ;
4034import com .google .api .client .json .Json ;
4135import com .google .api .client .json .JsonFactory ;
@@ -122,8 +116,6 @@ public class CoreSocketFactoryTest {
122116 @ Mock
123117 private CredentialFactory credentialFactory ;
124118 @ Mock
125- private GoogleCredential credential ;
126- @ Mock
127119 private SQLAdmin adminApi ;
128120 @ Mock
129121 private SQLAdmin .Connect adminApiConnect ;
@@ -191,6 +183,7 @@ public void setup()
191183 new IpMapping ().setIpAddress (PUBLIC_IP ).setType ("PRIMARY" ),
192184 new IpMapping ().setIpAddress (PRIVATE_IP ).setType ("PRIVATE" )))
193185 .setServerCaCert (new SslCert ().setCert (TestKeys .SERVER_CA_CERT ))
186+ .setDatabaseVersion ("POSTGRES14" )
194187 .setRegion ("myRegion" ));
195188 when (adminApiConnectGenerateEphemeralCert .execute ())
196189 .thenReturn (generateEphemeralCertResponse );
@@ -253,10 +246,7 @@ public void create_successfulPrivateConnection()
253246 .generateEphemeralCert (
254247 eq ("myProject" ), eq ("myRegion~myInstance" ), isA (GenerateEphemeralCertRequest .class ));
255248
256- BufferedReader bufferedReader =
257- new BufferedReader (new InputStreamReader (socket .getInputStream (), UTF_8 ));
258- String line = bufferedReader .readLine ();
259- assertThat (line ).isEqualTo (SERVER_MESSAGE );
249+ assertThat (readLine (socket )).isEqualTo (SERVER_MESSAGE );
260250 }
261251
262252 @ Test
@@ -275,10 +265,7 @@ public void create_successfulConnection() throws IOException, InterruptedExcepti
275265 .generateEphemeralCert (
276266 eq ("myProject" ), eq ("myRegion~myInstance" ), isA (GenerateEphemeralCertRequest .class ));
277267
278- BufferedReader bufferedReader =
279- new BufferedReader (new InputStreamReader (socket .getInputStream (), UTF_8 ));
280- String line = bufferedReader .readLine ();
281- assertThat (line ).isEqualTo (SERVER_MESSAGE );
268+ assertThat (readLine (socket )).isEqualTo (SERVER_MESSAGE );
282269 }
283270
284271 @ Test
@@ -298,10 +285,7 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr
298285 eq ("example.com:myProject" ), eq ("myRegion~myInstance" ),
299286 isA (GenerateEphemeralCertRequest .class ));
300287
301- BufferedReader bufferedReader =
302- new BufferedReader (new InputStreamReader (socket .getInputStream (), UTF_8 ));
303- String line = bufferedReader .readLine ();
304- assertThat (line ).isEqualTo (SERVER_MESSAGE );
288+ assertThat (readLine (socket )).isEqualTo (SERVER_MESSAGE );
305289 }
306290
307291 @ Test
@@ -332,10 +316,7 @@ public void create_expiredCertificateOnFirstConnection_certificateRenewed()
332316 .generateEphemeralCert (
333317 eq ("myProject" ), eq ("myRegion~myInstance" ), isA (GenerateEphemeralCertRequest .class ));
334318
335- BufferedReader bufferedReader =
336- new BufferedReader (new InputStreamReader (socket .getInputStream (), UTF_8 ));
337- String line = bufferedReader .readLine ();
338- assertThat (line ).isEqualTo (SERVER_MESSAGE );
319+ assertThat (readLine (socket )).isEqualTo (SERVER_MESSAGE );
339320 }
340321
341322 @ Test
@@ -396,6 +377,47 @@ public void create_notAuthorized() throws IOException {
396377 }
397378 }
398379
380+ @ Test
381+ public void supportsCustomCredentialFactoryWithIAM () throws InterruptedException , IOException {
382+ GoogleCredential customCredential = mock (GoogleCredential .class );
383+ when (credentialFactory .create ()).thenReturn (customCredential );
384+
385+ when (customCredential .getAccessToken ()).thenReturn ("foo" );
386+ when (customCredential .getExpirationTimeMilliseconds ()).thenReturn (new Date ().getTime ());
387+
388+ FakeSslServer sslServer = new FakeSslServer ();
389+ int port = sslServer .start ();
390+
391+ CoreSocketFactory coreSocketFactory =
392+ new CoreSocketFactory (clientKeyPair , adminApi , credentialFactory , port , defaultExecutor );
393+ Socket socket =
394+ coreSocketFactory .createSslSocket (
395+ "myProject:myRegion:myInstance" , Arrays .asList ("PRIMARY" ), true );
396+
397+ assertThat (readLine (socket )).isEqualTo (SERVER_MESSAGE );
398+ }
399+
400+ @ Test
401+ public void doesNotSupportNonGoogleCredentialWithIAM () throws InterruptedException , IOException {
402+ BasicAuthentication nonGoogleCredential = mock (BasicAuthentication .class );
403+ when (credentialFactory .create ()).thenReturn (nonGoogleCredential );
404+
405+ FakeSslServer sslServer = new FakeSslServer ();
406+ int port = sslServer .start ();
407+
408+ CoreSocketFactory coreSocketFactory =
409+ new CoreSocketFactory (clientKeyPair , adminApi , credentialFactory , port , defaultExecutor );
410+ assertThrows (RuntimeException .class , () -> {
411+ coreSocketFactory .createSslSocket (
412+ "myProject:myRegion:myInstance" , Arrays .asList ("PRIMARY" ), true );
413+ });
414+ }
415+
416+ private String readLine (Socket socket ) throws IOException {
417+ BufferedReader bufferedReader = new BufferedReader (new InputStreamReader (socket .getInputStream (), UTF_8 ));
418+ return bufferedReader .readLine ();
419+ }
420+
399421 // Creates a fake "accessNotConfigured" exception that can be used for testing.
400422 private static GoogleJsonResponseException fakeNotConfiguredException () throws IOException {
401423 return fakeGoogleJsonResponseException (
0 commit comments