2828import com .google .api .services .sqladmin .SQLAdmin .Builder ;
2929import com .google .api .services .sqladmin .SQLAdminScopes ;
3030import com .google .api .services .sqladmin .model .DatabaseInstance ;
31+ import com .google .api .services .sqladmin .model .IpMapping ;
3132import com .google .api .services .sqladmin .model .SslCert ;
3233import com .google .api .services .sqladmin .model .SslCertsCreateEphemeralRequest ;
3334import com .google .cloud .sql .CredentialFactory ;
5859import java .security .cert .CertificateException ;
5960import java .security .cert .CertificateFactory ;
6061import java .security .cert .X509Certificate ;
62+ import java .util .ArrayList ;
6163import java .util .Base64 ;
6264import java .util .Calendar ;
6365import java .util .Collections ;
6466import java .util .GregorianCalendar ;
6567import java .util .HashMap ;
68+ import java .util .List ;
6669import java .util .Map ;
6770import java .util .logging .Logger ;
6871
7881public class SslSocketFactory {
7982 private static final Logger logger = Logger .getLogger (SslSocketFactory .class .getName ());
8083
84+ public static final String DEFAULT_IP_TYPES = "PUBLIC,PRIVATE" ;
8185 public static final String USER_TOKEN_PROPERTY_NAME = "_CLOUD_SQL_USER_TOKEN" ;
8286
8387 static final String ADMIN_API_NOT_ENABLED_REASON = "accessNotConfigured" ;
@@ -151,9 +155,9 @@ public static synchronized SslSocketFactory getInstance() {
151155 }
152156
153157 // TODO(berezv): separate creating socket and performing connection to make it easier to test
154- public Socket create (String instanceName ) throws IOException {
158+ public Socket create (String instanceName , List < String > ipTypes ) throws IOException {
155159 try {
156- return createAndConfigureSocket (instanceName , CertificateCaching .USE_CACHE );
160+ return createAndConfigureSocket (instanceName , ipTypes , CertificateCaching .USE_CACHE );
157161 } catch (SSLHandshakeException e ) {
158162 logger .warning (
159163 String .format (
@@ -170,7 +174,7 @@ public Socket create(String instanceName) throws IOException {
170174 instanceName ));
171175 forcedRenewRateLimiter .acquire ();
172176 }
173- return createAndConfigureSocket (instanceName , CertificateCaching .BYPASS_CACHE );
177+ return createAndConfigureSocket (instanceName , ipTypes , CertificateCaching .BYPASS_CACHE );
174178 }
175179 }
176180
@@ -183,9 +187,13 @@ private static void logTestPropertyWarning(String property) {
183187 }
184188
185189 private SSLSocket createAndConfigureSocket (
186- String instanceName , CertificateCaching certificateCaching ) throws IOException {
190+ String instanceName ,
191+ List <String > ipTypes ,
192+ CertificateCaching certificateCaching )
193+ throws IOException {
187194 InstanceSslInfo instanceSslInfo = getInstanceSslInfo (instanceName , certificateCaching );
188- String ipAddress = instanceSslInfo .getInstanceIpAddress ();
195+ String ipAddress = getPreferredIp (instanceName , ipTypes , instanceSslInfo );
196+
189197 logger .info (
190198 String .format (
191199 "Connecting to Cloud SQL instance [%s] on IP [%s]." , instanceName , ipAddress ));
@@ -203,6 +211,43 @@ private SSLSocket createAndConfigureSocket(
203211 return sslSocket ;
204212 }
205213
214+ /**
215+ * Converts the string property of IP types to a list by splitting by commas, and upper-casing.
216+ */
217+ public static List <String > listIpTypes (String cloudSqlIpTypes ) {
218+ String [] rawTypes = cloudSqlIpTypes .split ("," );
219+ ArrayList <String > result = new ArrayList <>(rawTypes .length );
220+ for (int i = 0 ; i < rawTypes .length ; i ++) {
221+ if (rawTypes [i ].trim ().equalsIgnoreCase ("PUBLIC" )) {
222+ result .add (i , "PRIMARY" );
223+ } else {
224+ result .add (i , rawTypes [i ].trim ().toUpperCase ());
225+ }
226+ }
227+ return result ;
228+ }
229+
230+ @ Nullable
231+ private String getPreferredIp (
232+ String instanceName , List <String > ipTypes , InstanceSslInfo instanceSslInfo ) {
233+ String ipAddress = null ;
234+ for (String ipType : ipTypes ) {
235+ ipAddress = instanceSslInfo .getInstanceIpAddress (ipType );
236+ if (ipAddress != null ) {
237+ break ;
238+ }
239+ }
240+
241+ if (ipAddress == null ) {
242+ throw new RuntimeException (
243+ String .format (
244+ "Cloud SQL instance [%s] does not have any IP addresses matching preference: [ %s ]" ,
245+ instanceName , String .join (", " , ipTypes )));
246+ }
247+
248+ return ipAddress ;
249+ }
250+
206251 // TODO(berezv): synchronize per instance, instead of globally
207252 @ VisibleForTesting
208253 synchronized InstanceSslInfo getInstanceSslInfo (
@@ -287,11 +332,9 @@ private InstanceSslInfo fetchInstanceSslInfo(
287332 DatabaseInstance instance =
288333 obtainInstanceMetadata (adminApi , instanceConnectionString , projectId , instanceName );
289334 if (instance .getIpAddresses ().isEmpty ()) {
290- throw
291- new RuntimeException (
292- String .format (
293- "Cloud SQL instance [%s] does not have any external IP addresses" ,
294- instanceConnectionString ));
335+ throw new RuntimeException (
336+ String .format (
337+ "Cloud SQL instance [%s] does not have any IP addresses" , instanceConnectionString ));
295338 }
296339 if (!instance .getRegion ().equals (region )) {
297340 throw
@@ -323,12 +366,13 @@ private InstanceSslInfo fetchInstanceSslInfo(
323366
324367 SSLContext sslContext = createSslContext (ephemeralCertificate , instanceCaCertificate );
325368
326- return
327- new InstanceSslInfo (
328- instance .getIpAddresses (). get ( 0 ). getIpAddress (),
329- ephemeralCertificate ,
330- sslContext . getSocketFactory ());
369+ InstanceSslInfo info = new InstanceSslInfo ( ephemeralCertificate , sslContext . getSocketFactory ());
370+
371+ for ( IpMapping ip : instance .getIpAddresses ()) {
372+ info . putInstanceIpAddress ( ip . getType (), ip . getIpAddress ());
373+ }
331374
375+ return info ;
332376 }
333377
334378 private SSLContext createSslContext (
@@ -613,21 +657,23 @@ public long getLastFailureMillis() {
613657 }
614658
615659 private static class InstanceSslInfo {
616- private final String instanceIpAddress ;
660+ private final HashMap < String , String > instanceIpAddresses = new HashMap <>() ;
617661 private final X509Certificate ephemeralCertificate ;
618662 private final SSLSocketFactory sslSocketFactory ;
619663
620664 InstanceSslInfo (
621- String instanceIpAddress ,
622665 X509Certificate ephemeralCertificate ,
623666 SSLSocketFactory sslSocketFactory ) {
624- this .instanceIpAddress = instanceIpAddress ;
625667 this .ephemeralCertificate = ephemeralCertificate ;
626668 this .sslSocketFactory = sslSocketFactory ;
627669 }
628670
629- public String getInstanceIpAddress () {
630- return instanceIpAddress ;
671+ public void putInstanceIpAddress (String type , String ipAddress ) {
672+ instanceIpAddresses .put (type == null ? "" : type .toUpperCase (), ipAddress );
673+ }
674+
675+ public String getInstanceIpAddress (String type ) {
676+ return instanceIpAddresses .get (type .toUpperCase ());
631677 }
632678
633679 public X509Certificate getEphemeralCertificate () {
0 commit comments