Skip to content

Commit 286051d

Browse files
authored
fix: Lazy refresh should refresh tokens 4 minutes before expiration. (#2063)
Refresh tokens and certificates 4 minutes before they expire to avoid a race condition that would allow the connector to create an ephemeral certificate with an expired auth token. IAM auth tokens are now refreshed 4 minutes before they expire. Also, the Lazy Refresh Strategy will now refresh the client certificate 4 minutes before the expiration of the certificate and the IAM auth token. This should mitigate some of the strange certificate expiration errors commonly found in Cloud Run; see #2059.
1 parent 0fbea44 commit 286051d

File tree

6 files changed

+104
-23
lines changed

6 files changed

+104
-23
lines changed

core/src/main/java/com/google/cloud/sql/core/DefaultAccessTokenSupplier.java

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package com.google.cloud.sql.core;
1818

19+
import static com.google.cloud.sql.core.RefreshCalculator.DEFAULT_REFRESH_BUFFER;
20+
1921
import com.google.auth.oauth2.AccessToken;
2022
import com.google.auth.oauth2.GoogleCredentials;
2123
import com.google.cloud.sql.AuthType;
@@ -76,12 +78,12 @@ public Optional<AccessToken> get() throws IOException {
7678
() -> {
7779
final GoogleCredentials credentials = credentialFactory.getCredentials();
7880
try {
79-
credentials.refreshIfExpired();
81+
refreshIfRequired(credentials);
8082
} catch (IllegalStateException e) {
8183
throw new IllegalStateException("Error refreshing credentials " + credentials, e);
8284
}
83-
if (credentials.getAccessToken() == null
84-
|| "".equals(credentials.getAccessToken().getTokenValue())) {
85+
86+
if (isAccessTokenEmpty(credentials)) {
8587

8688
String errorMessage = "Access Token has length of zero";
8789
logger.debug(errorMessage);
@@ -97,19 +99,17 @@ public Optional<AccessToken> get() throws IOException {
9799
// For some implementations of GoogleCredentials, particularly
98100
// ImpersonatedCredentials, down-scoped credentials are not
99101
// initialized with a token and need to be explicitly refreshed.
100-
if (downscoped.getAccessToken() == null
101-
|| "".equals(downscoped.getAccessToken().getTokenValue())) {
102+
if (isAccessTokenEmpty(downscoped)) {
102103
try {
103-
downscoped.refreshIfExpired();
104+
downscoped.refresh();
104105
} catch (Exception e) {
105106
throw new IllegalStateException(
106107
"Error refreshing downscoped credentials " + credentials, e);
107108
}
108109

109110
// After attempting to refresh once, if the downscoped credentials do not have
110111
// an access token after attempting to refresh, then throw an IllegalStateException
111-
if (downscoped.getAccessToken() == null
112-
|| "".equals(downscoped.getAccessToken().getTokenValue())) {
112+
if (isAccessTokenEmpty(downscoped)) {
113113
String errorMessage = "Downscoped access token has length of zero";
114114
logger.debug(errorMessage);
115115

@@ -135,6 +135,28 @@ public Optional<AccessToken> get() throws IOException {
135135
}
136136
}
137137

138+
private static boolean isAccessTokenEmpty(GoogleCredentials credentials) {
139+
return credentials.getAccessToken() == null
140+
|| "".equals(credentials.getAccessToken().getTokenValue());
141+
}
142+
143+
private void refreshIfRequired(GoogleCredentials credentials) throws IOException {
144+
// if the token does not exist, or if the token expires in less than 4 minutes, refresh it.
145+
if (credentials.getAccessToken() == null) {
146+
logger.debug("Current IAM AuthN Token is not set. Refreshing the token.");
147+
credentials.refresh();
148+
} else if (credentials.getAccessToken().getExpirationTime() != null
149+
&& credentials
150+
.getAccessToken()
151+
.getExpirationTime()
152+
.toInstant()
153+
.minus(DEFAULT_REFRESH_BUFFER)
154+
.isBefore(Instant.now())) {
155+
logger.debug("Current IAM AuthN Token expires in less than 4 minutes. Refreshing the token.");
156+
credentials.refresh();
157+
}
158+
}
159+
138160
private void validateAccessTokenExpiration(AccessToken accessToken) {
139161
Date expirationTimeDate = accessToken.getExpirationTime();
140162

core/src/main/java/com/google/cloud/sql/core/DefaultConnectionInfoRepository.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import java.security.cert.X509Certificate;
4747
import java.time.Instant;
4848
import java.util.Arrays;
49+
import java.util.Base64;
4950
import java.util.HashMap;
5051
import java.util.List;
5152
import java.util.Map;
@@ -196,7 +197,15 @@ private static ConnectionInfo createConnectionInfo(
196197
.orElse(x509Certificate.getNotAfter().toInstant());
197198
}
198199

199-
logger.debug(String.format("[%s] INSTANCE DATA DONE", instanceName));
200+
logger.debug(
201+
"[{}] INSTANCE DATA DONE - Ephemeral cert id: {} cert expiration: {} token expiration: {}",
202+
instanceName,
203+
Base64.getEncoder().encodeToString(((X509Certificate) ephemeralCertificate).getSignature()),
204+
token
205+
.map(tok -> tok.getExpirationTime())
206+
.filter(time -> time != null)
207+
.map(time -> time.toInstant().toString())
208+
.orElse("(none)"));
200209

201210
return new ConnectionInfo(metadata, sslContext, expiration);
202211
}

core/src/main/java/com/google/cloud/sql/core/LazyRefreshConnectionInfoCache.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package com.google.cloud.sql.core;
1818

19+
import static com.google.cloud.sql.core.RefreshCalculator.DEFAULT_REFRESH_BUFFER;
20+
1921
import com.google.cloud.sql.CredentialFactory;
2022
import java.security.KeyPair;
2123

@@ -54,7 +56,8 @@ public LazyRefreshConnectionInfoCache(
5456
config.getCloudSqlInstance(),
5557
() ->
5658
connectionInfoRepository.getConnectionInfoSync(
57-
instanceName, accessTokenSupplier, config.getAuthType(), keyPair));
59+
instanceName, accessTokenSupplier, config.getAuthType(), keyPair),
60+
DEFAULT_REFRESH_BUFFER);
5861
}
5962

6063
@Override

core/src/main/java/com/google/cloud/sql/core/LazyRefreshStrategy.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.cloud.sql.core;
1818

1919
import com.google.errorprone.annotations.concurrent.GuardedBy;
20+
import java.time.Duration;
2021
import java.time.Instant;
2122
import java.util.function.Supplier;
2223
import org.slf4j.Logger;
@@ -28,6 +29,7 @@ public class LazyRefreshStrategy implements RefreshStrategy {
2829

2930
private final String name;
3031
private final Supplier<ConnectionInfo> refreshOperation;
32+
private final Duration refreshBuffer;
3133

3234
private final Object connectionInfoGuard = new Object();
3335

@@ -38,9 +40,11 @@ public class LazyRefreshStrategy implements RefreshStrategy {
3840
private boolean closed;
3941

4042
/** Creates a new LazyRefreshStrategy instance. */
41-
public LazyRefreshStrategy(String name, Supplier<ConnectionInfo> refreshOperation) {
43+
public LazyRefreshStrategy(
44+
String name, Supplier<ConnectionInfo> refreshOperation, Duration refreshDuration) {
4245
this.name = name;
4346
this.refreshOperation = refreshOperation;
47+
this.refreshBuffer = refreshDuration;
4448
}
4549

4650
@Override
@@ -59,7 +63,7 @@ public ConnectionInfo getConnectionInfo(long timeoutMs) {
5963
name));
6064
fetchConnectionInfo();
6165
}
62-
if (Instant.now().isAfter(connectionInfo.getExpiration())) {
66+
if (Instant.now().isAfter(connectionInfo.getExpiration().minus(refreshBuffer))) {
6367
logger.debug(
6468
String.format(
6569
"[%s] Lazy Refresh Operation: Client certificate has expired. Starting next "

core/src/main/java/com/google/cloud/sql/core/RefreshCalculator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class RefreshCalculator {
2828
// defaultRefreshBuffer is the minimum amount of time for which a
2929
// certificate must be valid to ensure the next refresh attempt has adequate
3030
// time to complete.
31-
private static final Duration DEFAULT_REFRESH_BUFFER = Duration.ofMinutes(4);
31+
static final Duration DEFAULT_REFRESH_BUFFER = Duration.ofMinutes(4);
3232

3333
long calculateSecondsUntilNextRefresh(Instant now, Instant expiration) {
3434
Duration timeUntilExp = Duration.between(now, expiration);

core/src/test/java/com/google/cloud/sql/core/LazyRefreshStrategyTest.java

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static com.google.common.truth.Truth.assertThat;
2020
import static org.junit.Assert.assertThrows;
2121

22+
import java.time.Duration;
2223
import java.time.Instant;
2324
import java.time.temporal.ChronoUnit;
2425
import java.util.concurrent.atomic.AtomicInteger;
@@ -32,7 +33,9 @@ public void testCloudSqlInstanceDataRetrievedSuccessfully() {
3233
final ExampleData data = new ExampleData(Instant.now().plus(1, ChronoUnit.HOURS));
3334
LazyRefreshStrategy r =
3435
new LazyRefreshStrategy(
35-
"LazyRefresherTest.testCloudSqlInstanceDataRetrievedSuccessfully", () -> data);
36+
"LazyRefresherTest.testCloudSqlInstanceDataRetrievedSuccessfully",
37+
() -> data,
38+
Duration.ZERO);
3639
ConnectionInfo gotInfo = r.getConnectionInfo(TEST_TIMEOUT_MS);
3740
assertThat(gotInfo).isSameInstanceAs(data);
3841
}
@@ -44,7 +47,8 @@ public void testInstanceFailsOnConnectionError() {
4447
"LazyRefresherTest.testInstanceFailsOnConnectionError",
4548
() -> {
4649
throw new RuntimeException("always fails");
47-
});
50+
},
51+
Duration.ZERO);
4852

4953
RuntimeException ex =
5054
assertThrows(RuntimeException.class, () -> r.getConnectionInfo(TEST_TIMEOUT_MS));
@@ -62,7 +66,8 @@ public void testCloudSqlInstanceForcesRefresh() throws Exception {
6266
() -> {
6367
refreshCount.incrementAndGet();
6468
return data;
65-
});
69+
},
70+
Duration.ZERO);
6671

6772
r.getConnectionInfo(TEST_TIMEOUT_MS);
6873
assertThat(refreshCount.get()).isEqualTo(1);
@@ -96,24 +101,59 @@ public void testCloudSqlRefreshesExpiredData() throws Exception {
96101
return initialData;
97102
}
98103
return data;
99-
});
104+
},
105+
Duration.ZERO);
100106

101107
// Get the first data that is about to expire
102108
ConnectionInfo d = r.getConnectionInfo(TEST_TIMEOUT_MS);
103109
assertThat(refreshCount.get()).isEqualTo(1);
104110
assertThat(d).isSameInstanceAs(initialData);
105111

106-
waitForExpiration(initialData);
112+
waitForExpiration(initialData.getExpiration());
107113

108114
assertThat(r.getConnectionInfo(TEST_TIMEOUT_MS)).isSameInstanceAs(data);
109115
assertThat(refreshCount.get()).isEqualTo(2);
110116
}
111117

112-
private static void waitForExpiration(ExampleData initialData) throws InterruptedException {
118+
@Test
119+
public void testCloudSqlRefreshesExpiredDataWithRefreshBuffer() throws Exception {
120+
Duration buffer = Duration.ofSeconds(5);
121+
ExampleData initialData = new ExampleData(Instant.now().plus(8, ChronoUnit.SECONDS));
122+
ExampleData data = new ExampleData(Instant.now().plus(1, ChronoUnit.HOURS));
123+
124+
AtomicInteger refreshCount = new AtomicInteger();
125+
126+
LazyRefreshStrategy r =
127+
new LazyRefreshStrategy(
128+
"LazyRefresherTest.testCloudSqlRefreshesExpiredData",
129+
() -> {
130+
int c = refreshCount.getAndIncrement();
131+
if (c == 0) {
132+
return initialData;
133+
}
134+
return data;
135+
},
136+
buffer);
137+
138+
// Get the first data that is about to expire, but not yet expired
139+
ConnectionInfo d = r.getConnectionInfo(TEST_TIMEOUT_MS);
140+
assertThat(refreshCount.get()).isEqualTo(1);
141+
assertThat(d).isSameInstanceAs(initialData);
142+
143+
// Wait 5 seconds. Now the initialData expires in 3 seconds, less than the buffer.
144+
waitForExpiration(Instant.now().plus(buffer));
145+
146+
// Assert that it gets data instead of initialData.
147+
assertThat(r.getConnectionInfo(TEST_TIMEOUT_MS)).isSameInstanceAs(data);
148+
assertThat(refreshCount.get()).isEqualTo(2);
149+
}
150+
151+
private static void waitForExpiration(Instant expiration) throws InterruptedException {
113152
// Wait for the instance to expire
114-
while (!Instant.now().isAfter(initialData.getExpiration())) {
153+
while (!Instant.now().isAfter(expiration)) {
115154
Thread.sleep(10);
116155
}
156+
117157
// Sleep a few more ms to make sure that Instant.now() really is after expiration.
118158
// Fixes a date math race condition only present in Java 8.
119159
Thread.sleep(10);
@@ -135,15 +175,16 @@ public void testThatConcurrentRequestsDontCauseDuplicateRefreshAttempts() throws
135175
return initialData;
136176
}
137177
return data;
138-
});
178+
},
179+
Duration.ZERO);
139180

140181
// Get the first data that is about to expire
141182
ConnectionInfo d = r.getConnectionInfo(TEST_TIMEOUT_MS);
142183
assertThat(refreshCount.get()).isEqualTo(1);
143184
assertThat(d).isSameInstanceAs(initialData);
144185

145186
// Wait for the instance to expire
146-
waitForExpiration(initialData);
187+
waitForExpiration(initialData.getExpiration());
147188

148189
// Start multiple threads and request connection info
149190
Thread t1 = new Thread(() -> r.getConnectionInfo(TEST_TIMEOUT_MS));
@@ -166,7 +207,9 @@ public void testClosedCloudSqlInstanceDataThrowsException() {
166207
ExampleData data = new ExampleData(Instant.now().plus(1, ChronoUnit.HOURS));
167208
LazyRefreshStrategy r =
168209
new LazyRefreshStrategy(
169-
"RefresherTest.testClosedCloudSqlInstanceDataThrowsException", () -> data);
210+
"RefresherTest.testClosedCloudSqlInstanceDataThrowsException",
211+
() -> data,
212+
Duration.ZERO);
170213
r.close();
171214

172215
assertThrows(IllegalStateException.class, () -> r.getConnectionInfo(TEST_TIMEOUT_MS));

0 commit comments

Comments
 (0)