Skip to content

Commit 071085c

Browse files
fix: only keep track of sockets opened using domain name (#2130)
The list `sockets` on the `MonitoredCache` should only keep track of sockets opened for connections using a domain name. It should not keep track of all sockets for regular connections. This is what it is currently doing: https://github.com/GoogleCloudPlatform/cloud-sql-jdbc-socket-factory/blob/a5329ac62ce877783c447f4192b62e72f39b63d4/core/src/main/java/com/google/cloud/sql/core/Connector.java#L141 Sockets are only removed from the list if the connections were made using a domain name: https://github.com/GoogleCloudPlatform/cloud-sql-jdbc-socket-factory/blob/a5329ac62ce877783c447f4192b62e72f39b63d4/core/src/main/java/com/google/cloud/sql/core/MonitoredCache.java#L92-L96 Thus for sockets not using a domain name and using the default instance connection name pattern, the sockets list will continue to grow and grow without being cleaned up. This PR adds the proper check to gate adding to the `sockets` list for only connections using a domain name. Also, this switches MonitoredCache to use weak references for the list of sockets. Fixes #2129
1 parent a5329ac commit 071085c

File tree

2 files changed

+228
-4
lines changed

2 files changed

+228
-4
lines changed

core/src/main/java/com/google/cloud/sql/core/MonitoredCache.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616

1717
package com.google.cloud.sql.core;
1818

19+
import com.google.common.annotations.VisibleForTesting;
1920
import com.google.common.base.Strings;
2021
import java.io.IOException;
2122
import java.net.Socket;
22-
import java.util.ArrayList;
2323
import java.util.Collections;
2424
import java.util.Iterator;
25-
import java.util.List;
25+
import java.util.Set;
2626
import java.util.Timer;
2727
import java.util.TimerTask;
28+
import java.util.WeakHashMap;
2829
import java.util.function.Function;
2930
import javax.net.ssl.SSLSocket;
3031
import org.slf4j.Logger;
@@ -38,7 +39,11 @@
3839
class MonitoredCache implements ConnectionInfoCache {
3940
private static final Logger logger = LoggerFactory.getLogger(Connector.class);
4041
private final ConnectionInfoCache cache;
41-
private final List<Socket> sockets = Collections.synchronizedList(new ArrayList<>());
42+
// Use weak references to hold the open sockets. If a socket is no longer in
43+
// use by the application, the garabage collector will automatically remove
44+
// it from this set.
45+
private final Set<Socket> sockets =
46+
Collections.synchronizedSet(Collections.newSetFromMap(new WeakHashMap<>()));
4247
private final Function<ConnectionConfig, CloudSqlInstanceName> resolve;
4348
private final TimerTask task;
4449

@@ -49,6 +54,8 @@ class MonitoredCache implements ConnectionInfoCache {
4954
this.cache = cache;
5055
this.resolve = resolve;
5156

57+
// If this was configured with a domain name, start the domain name check
58+
// and socket cleanup periodic task.
5259
if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
5360
long failoverPeriod = cache.getConfig().getConnectorConfig().getFailoverPeriod().toMillis();
5461
this.task =
@@ -64,6 +71,11 @@ public void run() {
6471
}
6572
}
6673

74+
@VisibleForTesting
75+
int getOpenSocketCount() {
76+
return sockets.size();
77+
}
78+
6779
private void checkDomainName() {
6880
// Resolve the domain name again. If it changed, close the sockets
6981
try {
@@ -149,6 +161,10 @@ public synchronized boolean isClosed() {
149161
}
150162

151163
synchronized void addSocket(SSLSocket socket) {
152-
sockets.add(socket);
164+
// Only add the socket if this was configured using a domain name,
165+
// and therefore the background socket cleanup task is running.
166+
if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
167+
sockets.add(socket);
168+
}
153169
}
154170
}
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.sql.core;
18+
19+
import com.google.cloud.sql.ConnectorConfig;
20+
import java.io.IOException;
21+
import java.time.Duration;
22+
import java.util.Timer;
23+
import javax.net.ssl.HandshakeCompletedListener;
24+
import javax.net.ssl.SSLSession;
25+
import javax.net.ssl.SSLSocket;
26+
import org.junit.AfterClass;
27+
import org.junit.Assert;
28+
import org.junit.Test;
29+
30+
public class MonitoredCacheTest {
31+
private static final Timer timer = new Timer(true);
32+
33+
@AfterClass
34+
public static void afterClass() {
35+
timer.cancel();
36+
}
37+
38+
@Test
39+
public void testMonitoredCacheHoldsSocketsWithDomainName() {
40+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com");
41+
ConnectionConfig config =
42+
new ConnectionConfig.Builder()
43+
.withCloudSqlInstance("proj:reg:inst")
44+
.withDomainName("db.example.com")
45+
.build();
46+
MockCache mockCache = new MockCache(config);
47+
48+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
49+
MockSslSocket socket = new MockSslSocket();
50+
cache.addSocket(socket);
51+
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
52+
cache.close();
53+
Assert.assertTrue("socket closed", socket.closed);
54+
}
55+
56+
@Test
57+
public void testMonitoredCachePurgesClosedSockets() throws InterruptedException {
58+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com");
59+
// Purge sockets every 10ms.
60+
ConnectionConfig config =
61+
new ConnectionConfig.Builder()
62+
.withCloudSqlInstance("proj:reg:inst")
63+
.withDomainName("db.example.com")
64+
.withConnectorConfig(
65+
new ConnectorConfig.Builder().withFailoverPeriod(Duration.ofMillis(10)).build())
66+
.build();
67+
MockCache mockCache = new MockCache(config);
68+
69+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
70+
MockSslSocket socket = new MockSslSocket();
71+
cache.addSocket(socket);
72+
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
73+
socket.close();
74+
Thread.sleep(20);
75+
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
76+
}
77+
78+
@Test
79+
public void testMonitoredCacheWithoutDomainNameIgnoresSockets() {
80+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst");
81+
ConnectionConfig config =
82+
new ConnectionConfig.Builder().withCloudSqlInstance("proj:reg:inst").build();
83+
MockCache mockCache = new MockCache(config);
84+
85+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
86+
MockSslSocket socket = new MockSslSocket();
87+
cache.addSocket(socket);
88+
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
89+
}
90+
91+
private static class MockSslSocket extends SSLSocket {
92+
boolean closed;
93+
94+
@Override
95+
public synchronized boolean isClosed() {
96+
return closed;
97+
}
98+
99+
@Override
100+
public synchronized void close() {
101+
this.closed = true;
102+
}
103+
104+
@Override
105+
public String[] getSupportedCipherSuites() {
106+
return new String[0];
107+
}
108+
109+
@Override
110+
public String[] getEnabledCipherSuites() {
111+
return new String[0];
112+
}
113+
114+
@Override
115+
public void setEnabledCipherSuites(String[] suites) {}
116+
117+
@Override
118+
public String[] getSupportedProtocols() {
119+
return new String[0];
120+
}
121+
122+
@Override
123+
public String[] getEnabledProtocols() {
124+
return new String[0];
125+
}
126+
127+
@Override
128+
public void setEnabledProtocols(String[] protocols) {}
129+
130+
@Override
131+
public SSLSession getSession() {
132+
return null;
133+
}
134+
135+
@Override
136+
public void addHandshakeCompletedListener(HandshakeCompletedListener listener) {}
137+
138+
@Override
139+
public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) {}
140+
141+
@Override
142+
public void startHandshake() throws IOException {}
143+
144+
@Override
145+
public void setUseClientMode(boolean mode) {}
146+
147+
@Override
148+
public boolean getUseClientMode() {
149+
return false;
150+
}
151+
152+
@Override
153+
public void setNeedClientAuth(boolean need) {}
154+
155+
@Override
156+
public boolean getNeedClientAuth() {
157+
return false;
158+
}
159+
160+
@Override
161+
public void setWantClientAuth(boolean want) {}
162+
163+
@Override
164+
public boolean getWantClientAuth() {
165+
return false;
166+
}
167+
168+
@Override
169+
public void setEnableSessionCreation(boolean flag) {}
170+
171+
@Override
172+
public boolean getEnableSessionCreation() {
173+
return false;
174+
}
175+
}
176+
177+
private static class MockCache implements ConnectionInfoCache {
178+
private final ConnectionConfig config;
179+
180+
MockCache(ConnectionConfig config) {
181+
this.config = config;
182+
}
183+
184+
@Override
185+
public ConnectionMetadata getConnectionMetadata(long timeoutMs) {
186+
return null;
187+
}
188+
189+
@Override
190+
public void forceRefresh() {}
191+
192+
@Override
193+
public void refreshIfExpired() {}
194+
195+
@Override
196+
public void close() {}
197+
198+
@Override
199+
public boolean isClosed() {
200+
return false;
201+
}
202+
203+
@Override
204+
public ConnectionConfig getConfig() {
205+
return config;
206+
}
207+
}
208+
}

0 commit comments

Comments
 (0)