Skip to content

Commit b64406f

Browse files
authored
Add ipTypes argument (#78) (#100)
1 parent 94331f9 commit b64406f

File tree

6 files changed

+169
-53
lines changed

6 files changed

+169
-53
lines changed

connector-j-5/src/main/java/com/google/cloud/sql/mysql/SocketFactory.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.io.File;
2323
import java.io.IOException;
2424
import java.net.Socket;
25+
import java.util.List;
2526
import java.util.Properties;
2627
import java.util.logging.Logger;
2728
import jnr.unixsocket.UnixSocketAddress;
@@ -65,7 +66,10 @@ public Socket connect(String hostname, int portNumber, Properties props) throws
6566
// Default to SSL Socket
6667
logger.info(String.format(
6768
"Connecting to Cloud SQL instance [%s] via ssl socket.", instanceName));
68-
this.socket = SslSocketFactory.getInstance().create(instanceName);
69+
List<String> ipTypes =
70+
SslSocketFactory.listIpTypes(
71+
props.getProperty("ipTypes", SslSocketFactory.DEFAULT_IP_TYPES));
72+
this.socket = SslSocketFactory.getInstance().create(instanceName, ipTypes);
6973
}
7074
return this.socket;
7175
}

connector-j-6/src/main/java/com/google/cloud/sql/mysql/SocketFactory.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.io.File;
2323
import java.io.IOException;
2424
import java.net.Socket;
25+
import java.util.List;
2526
import java.util.Properties;
2627
import java.util.logging.Logger;
2728
import jnr.unixsocket.UnixSocketAddress;
@@ -65,14 +66,16 @@ public Socket connect(String host, int portNumber, Properties props, int loginTi
6566
// Default to SSL Socket
6667
logger.info(String.format(
6768
"Connecting to Cloud SQL instance [%s] via ssl socket.", instanceName));
68-
this.socket = SslSocketFactory.getInstance().create(instanceName);
69+
List<String> ipTypes =
70+
SslSocketFactory.listIpTypes(
71+
props.getProperty("ipTypes", SslSocketFactory.DEFAULT_IP_TYPES));
72+
this.socket = SslSocketFactory.getInstance().create(instanceName, ipTypes);
6973
}
7074
return this.socket;
7175
}
7276

7377
// Cloud SQL sockets always use TLS and the socket returned by connect above is already TLS-ready. It is fine
7478
// to implement these as no-ops.
75-
7679
@Override
7780
public Socket beforeHandshake() {
7881
return socket;

connector-j-8/src/main/java/com/google/cloud/sql/mysql/SocketFactory.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.io.File;
2626
import java.io.IOException;
2727
import java.net.Socket;
28+
import java.util.List;
2829
import java.util.Properties;
2930
import java.util.logging.Logger;
3031
import jnr.unixsocket.UnixSocketAddress;
@@ -68,7 +69,10 @@ public Socket connect(String host, int portNumber, Properties props, int loginTi
6869
// Default to SSL Socket
6970
logger.info(String.format(
7071
"Connecting to Cloud SQL instance [%s] via ssl socket.", instanceName));
71-
this.socket = SslSocketFactory.getInstance().create(instanceName);
72+
List<String> ipTypes =
73+
SslSocketFactory.listIpTypes(
74+
props.getProperty("ipTypes", SslSocketFactory.DEFAULT_IP_TYPES));
75+
this.socket = SslSocketFactory.getInstance().create(instanceName, ipTypes);
7276
}
7377
return this.socket;
7478
}

core/src/main/java/com/google/cloud/sql/core/SslSocketFactory.java

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import com.google.api.services.sqladmin.SQLAdmin.Builder;
2929
import com.google.api.services.sqladmin.SQLAdminScopes;
3030
import com.google.api.services.sqladmin.model.DatabaseInstance;
31+
import com.google.api.services.sqladmin.model.IpMapping;
3132
import com.google.api.services.sqladmin.model.SslCert;
3233
import com.google.api.services.sqladmin.model.SslCertsCreateEphemeralRequest;
3334
import com.google.cloud.sql.CredentialFactory;
@@ -58,11 +59,13 @@
5859
import java.security.cert.CertificateException;
5960
import java.security.cert.CertificateFactory;
6061
import java.security.cert.X509Certificate;
62+
import java.util.ArrayList;
6163
import java.util.Base64;
6264
import java.util.Calendar;
6365
import java.util.Collections;
6466
import java.util.GregorianCalendar;
6567
import java.util.HashMap;
68+
import java.util.List;
6669
import java.util.Map;
6770
import java.util.logging.Logger;
6871

@@ -78,6 +81,7 @@
7881
public class SslSocketFactory {
7982
private static final Logger logger = Logger.getLogger(SslSocketFactory.class.getName());
8083

84+
public static final String DEFAULT_IP_TYPES = "PUBLIC,PRIVATE";
8185
public static final String USER_TOKEN_PROPERTY_NAME = "_CLOUD_SQL_USER_TOKEN";
8286

8387
static final String ADMIN_API_NOT_ENABLED_REASON = "accessNotConfigured";
@@ -151,9 +155,9 @@ public static synchronized SslSocketFactory getInstance() {
151155
}
152156

153157
// TODO(berezv): separate creating socket and performing connection to make it easier to test
154-
public Socket create(String instanceName) throws IOException {
158+
public Socket create(String instanceName, List<String> ipTypes) throws IOException {
155159
try {
156-
return createAndConfigureSocket(instanceName, CertificateCaching.USE_CACHE);
160+
return createAndConfigureSocket(instanceName, ipTypes, CertificateCaching.USE_CACHE);
157161
} catch (SSLHandshakeException e) {
158162
logger.warning(
159163
String.format(
@@ -170,7 +174,7 @@ public Socket create(String instanceName) throws IOException {
170174
instanceName));
171175
forcedRenewRateLimiter.acquire();
172176
}
173-
return createAndConfigureSocket(instanceName, CertificateCaching.BYPASS_CACHE);
177+
return createAndConfigureSocket(instanceName, ipTypes, CertificateCaching.BYPASS_CACHE);
174178
}
175179
}
176180

@@ -183,9 +187,13 @@ private static void logTestPropertyWarning(String property) {
183187
}
184188

185189
private SSLSocket createAndConfigureSocket(
186-
String instanceName, CertificateCaching certificateCaching) throws IOException {
190+
String instanceName,
191+
List<String> ipTypes,
192+
CertificateCaching certificateCaching)
193+
throws IOException {
187194
InstanceSslInfo instanceSslInfo = getInstanceSslInfo(instanceName, certificateCaching);
188-
String ipAddress = instanceSslInfo.getInstanceIpAddress();
195+
String ipAddress = getPreferredIp(instanceName, ipTypes, instanceSslInfo);
196+
189197
logger.info(
190198
String.format(
191199
"Connecting to Cloud SQL instance [%s] on IP [%s].", instanceName, ipAddress));
@@ -203,6 +211,43 @@ private SSLSocket createAndConfigureSocket(
203211
return sslSocket;
204212
}
205213

214+
/**
215+
* Converts the string property of IP types to a list by splitting by commas, and upper-casing.
216+
*/
217+
public static List<String> listIpTypes(String cloudSqlIpTypes) {
218+
String[] rawTypes = cloudSqlIpTypes.split(",");
219+
ArrayList<String> result = new ArrayList<>(rawTypes.length);
220+
for (int i = 0; i < rawTypes.length; i++) {
221+
if (rawTypes[i].trim().equalsIgnoreCase("PUBLIC")) {
222+
result.add(i, "PRIMARY");
223+
} else {
224+
result.add(i, rawTypes[i].trim().toUpperCase());
225+
}
226+
}
227+
return result;
228+
}
229+
230+
@Nullable
231+
private String getPreferredIp(
232+
String instanceName, List<String> ipTypes, InstanceSslInfo instanceSslInfo) {
233+
String ipAddress = null;
234+
for (String ipType : ipTypes) {
235+
ipAddress = instanceSslInfo.getInstanceIpAddress(ipType);
236+
if (ipAddress != null) {
237+
break;
238+
}
239+
}
240+
241+
if (ipAddress == null) {
242+
throw new RuntimeException(
243+
String.format(
244+
"Cloud SQL instance [%s] does not have any IP addresses matching preference: [ %s ]",
245+
instanceName, String.join(", ", ipTypes)));
246+
}
247+
248+
return ipAddress;
249+
}
250+
206251
// TODO(berezv): synchronize per instance, instead of globally
207252
@VisibleForTesting
208253
synchronized InstanceSslInfo getInstanceSslInfo(
@@ -287,11 +332,9 @@ private InstanceSslInfo fetchInstanceSslInfo(
287332
DatabaseInstance instance =
288333
obtainInstanceMetadata(adminApi, instanceConnectionString, projectId, instanceName);
289334
if (instance.getIpAddresses().isEmpty()) {
290-
throw
291-
new RuntimeException(
292-
String.format(
293-
"Cloud SQL instance [%s] does not have any external IP addresses",
294-
instanceConnectionString));
335+
throw new RuntimeException(
336+
String.format(
337+
"Cloud SQL instance [%s] does not have any IP addresses", instanceConnectionString));
295338
}
296339
if (!instance.getRegion().equals(region)) {
297340
throw
@@ -323,12 +366,13 @@ private InstanceSslInfo fetchInstanceSslInfo(
323366

324367
SSLContext sslContext = createSslContext(ephemeralCertificate, instanceCaCertificate);
325368

326-
return
327-
new InstanceSslInfo(
328-
instance.getIpAddresses().get(0).getIpAddress(),
329-
ephemeralCertificate,
330-
sslContext.getSocketFactory());
369+
InstanceSslInfo info = new InstanceSslInfo(ephemeralCertificate, sslContext.getSocketFactory());
370+
371+
for (IpMapping ip : instance.getIpAddresses()) {
372+
info.putInstanceIpAddress(ip.getType(), ip.getIpAddress());
373+
}
331374

375+
return info;
332376
}
333377

334378
private SSLContext createSslContext(
@@ -613,21 +657,23 @@ public long getLastFailureMillis() {
613657
}
614658

615659
private static class InstanceSslInfo {
616-
private final String instanceIpAddress;
660+
private final HashMap<String, String> instanceIpAddresses = new HashMap<>();
617661
private final X509Certificate ephemeralCertificate;
618662
private final SSLSocketFactory sslSocketFactory;
619663

620664
InstanceSslInfo(
621-
String instanceIpAddress,
622665
X509Certificate ephemeralCertificate,
623666
SSLSocketFactory sslSocketFactory) {
624-
this.instanceIpAddress = instanceIpAddress;
625667
this.ephemeralCertificate = ephemeralCertificate;
626668
this.sslSocketFactory = sslSocketFactory;
627669
}
628670

629-
public String getInstanceIpAddress() {
630-
return instanceIpAddress;
671+
public void putInstanceIpAddress(String type, String ipAddress) {
672+
instanceIpAddresses.put(type == null ? "" : type.toUpperCase(), ipAddress);
673+
}
674+
675+
public String getInstanceIpAddress(String type) {
676+
return instanceIpAddresses.get(type.toUpperCase());
631677
}
632678

633679
public X509Certificate getEphemeralCertificate() {

0 commit comments

Comments
 (0)