Skip to content

Commit 47f6319

Browse files
committed
Improve concurrency handling during credential refresh.
Introduced a refresh task to manage concurrent refresh requests, preventing redundant attempts and potential race conditions. This aligns the refresh mechanism with the pattern used in OAuth2Credentials and ensures more robust credential management.
1 parent 8451d05 commit 47f6319

File tree

1 file changed

+218
-97
lines changed

1 file changed

+218
-97
lines changed

cab-token-generator/java/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactory.java

Lines changed: 218 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -47,29 +47,36 @@
4747
import com.google.auth.oauth2.StsTokenExchangeResponse;
4848
import com.google.common.annotations.VisibleForTesting;
4949
import com.google.common.base.Strings;
50-
import com.google.common.util.concurrent.SettableFuture;
50+
import com.google.common.util.concurrent.AbstractFuture;
51+
import com.google.common.util.concurrent.FutureCallback;
52+
import com.google.common.util.concurrent.Futures;
53+
import com.google.common.util.concurrent.ListenableFuture;
54+
import com.google.common.util.concurrent.ListenableFutureTask;
55+
import com.google.common.util.concurrent.MoreExecutors;
5156
import com.google.errorprone.annotations.CanIgnoreReturnValue;
5257
import java.io.IOException;
5358
import java.time.Duration;
5459
import java.util.Date;
5560
import java.util.concurrent.ExecutionException;
56-
import java.util.concurrent.ExecutorService;
57-
import java.util.concurrent.Executors;
5861
import javax.annotation.Nullable;
5962

6063
public final class ClientSideCredentialAccessBoundaryFactory {
6164
private final GoogleCredentials sourceCredential;
6265
private final transient HttpTransportFactory transportFactory;
6366
private final String tokenExchangeEndpoint;
64-
private String accessBoundarySessionKey;
65-
private AccessToken intermediateAccessToken;
6667
private final Duration minimumTokenLifetime;
6768
private final Duration refreshMargin;
6869
private static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(30);
6970
private static final Duration DEFAULT_MINIMUM_TOKEN_LIFETIME = Duration.ofMinutes(3);
70-
private final Object refreshLock = new Object[0]; // Lock for refresh operations
71-
@Nullable private SettableFuture<Void> currentRefreshFuture;
72-
private final ExecutorService backgroundExecutor = Executors.newSingleThreadExecutor();
71+
private transient RefreshTask refreshTask;
72+
private final Object refreshLock = new byte[0];
73+
private volatile IntermediateCredentials intermediateCredentials = null;
74+
75+
enum RefreshType {
76+
NONE,
77+
ASYNC,
78+
BLOCKING
79+
}
7380

7481
private ClientSideCredentialAccessBoundaryFactory(Builder builder) {
7582
this.transportFactory = builder.transportFactory;
@@ -83,24 +90,140 @@ private ClientSideCredentialAccessBoundaryFactory(Builder builder) {
8390
: DEFAULT_MINIMUM_TOKEN_LIFETIME;
8491
}
8592

93+
public static Builder newBuilder() {
94+
return new Builder();
95+
}
96+
97+
public AccessToken generateToken(CredentialAccessBoundary accessBoundary) {
98+
// TODO(negarb/jiahuah): Implement generateToken
99+
// Note: This method will call refreshCredentialsIfRequired().
100+
throw new UnsupportedOperationException("generateToken is not yet implemented.");
101+
}
102+
103+
/**
104+
* Refreshes the intermediate access token and access boundary session key if required.
105+
*
106+
* <p>This method checks the expiration time of the current intermediate access token and
107+
* initiates a refresh if necessary. The refresh process also refreshes the underlying source
108+
* credentials.
109+
*
110+
* @throws IOException If an error occurs during the refresh process, such as network issues,
111+
* invalid credentials, or problems with the token exchange endpoint.
112+
*/
113+
private void refreshCredentialsIfRequired() throws IOException {
114+
RefreshType refreshType = determineRefreshType();
115+
116+
if (refreshType == RefreshType.NONE) {
117+
return; // No refresh needed, token is still valid.
118+
}
119+
120+
// If a refresh is required, create or retrieve the refresh task.
121+
RefreshTask refreshTask = getOrCreateRefreshTask();
122+
123+
// Handle the refresh based on the determined refresh type.
124+
switch (refreshType) {
125+
case BLOCKING:
126+
if (refreshTask.isNew) {
127+
// Execute the new refresh task synchronously on a direct executor.
128+
// This blocks until the refresh is complete.
129+
MoreExecutors.directExecutor().execute(refreshTask.task);
130+
} else {
131+
// A refresh is already in progress, wait for it to complete.
132+
try {
133+
refreshTask.task.get();
134+
} catch (InterruptedException e) {
135+
// Restore the interrupted status and throw an exception.
136+
Thread.currentThread().interrupt();
137+
throw new IOException(
138+
"Interrupted while asynchronously refreshing the intermediate credentials", e);
139+
} catch (ExecutionException e) {
140+
// Unwrap the underlying cause of the execution exception.
141+
Throwable cause = e.getCause();
142+
if (cause instanceof IOException) {
143+
throw (IOException) cause;
144+
} else if (cause instanceof RuntimeException) {
145+
throw (RuntimeException) cause;
146+
} else {
147+
// Wrap other exceptions in an IOException.
148+
throw new IOException("Unexpected error refreshing intermediate credentials", cause);
149+
}
150+
}
151+
}
152+
break;
153+
case ASYNC:
154+
if (refreshTask.isNew) {
155+
// Start a new background thread for the refresh task.
156+
// This allows the current thread to continue without blocking.
157+
new Thread(refreshTask.task).start();
158+
} // (No else needed - if not new, another thread is handling the refresh)
159+
break;
160+
}
161+
}
162+
163+
private RefreshType determineRefreshType() {
164+
if (intermediateCredentials == null
165+
|| intermediateCredentials.intermediateAccessToken == null) {
166+
// A blocking refresh is needed if the intermediate access token doesn't exist.
167+
return RefreshType.BLOCKING;
168+
}
169+
170+
AccessToken localAccessToken = intermediateCredentials.intermediateAccessToken;
171+
Date expirationTime = localAccessToken.getExpirationTime();
172+
if (expirationTime == null) {
173+
return RefreshType.NONE; // Token does not expire, no refresh needed.
174+
}
175+
176+
Duration remaining =
177+
Duration.ofMillis(expirationTime.getTime() - Clock.SYSTEM.currentTimeMillis());
178+
if (remaining.compareTo(minimumTokenLifetime) <= 0) {
179+
// Intermediate token has expired or remaining lifetime is less than the minimum required
180+
// for CAB token generation. A blocking refresh is necessary.
181+
return RefreshType.BLOCKING;
182+
} else if (remaining.compareTo(refreshMargin) <= 0) {
183+
// The token is nearing expiration, an async refresh is needed.
184+
return RefreshType.ASYNC;
185+
}
186+
// Token is still fresh, no refresh needed.
187+
return RefreshType.NONE;
188+
}
189+
190+
/**
191+
* Atomically creates a single flight refresh task.
192+
*
193+
* <p>Only a single refresh task can be scheduled at a time. If there is an existing task, it will
194+
* be returned for subsequent invocations. However, if a new task is created, it is the
195+
* responsibility of the caller to execute it. The task will clear the single flight slot upon
196+
* completion.
197+
*/
198+
private RefreshTask getOrCreateRefreshTask() {
199+
synchronized (refreshLock) {
200+
if (refreshTask != null) {
201+
// An existing refresh task is already in progress. Return a NEW RefreshTask instance with
202+
// the existing task, but set isNew to false. This indicates to the caller that a new
203+
// refresh task was NOT created.
204+
return new RefreshTask(refreshTask.task, false);
205+
}
206+
207+
// No refresh task is currently running. Create and return a new refresh task.
208+
final ListenableFutureTask<IntermediateCredentials> task =
209+
ListenableFutureTask.create(this::refreshCredentials);
210+
return new RefreshTask(task, true);
211+
}
212+
}
213+
86214
/**
87215
* Refreshes the source credential and exchanges it for an intermediate access token using the STS
88216
* endpoint.
89217
*
90218
* <p>If the source credential is expired, it will be refreshed. A token exchange request is then
91-
* made to the STS endpoint. The resulting intermediate access token and access boundary session
92-
* key are stored. The intermediate access token's expiration time is determined as follows:
93-
*
94-
* <ol>
95-
* <li>If the STS response includes `expires_in`, that value is used.
96-
* <li>Otherwise, if the source credential has an expiration time, that value is used.
97-
* <li>Otherwise, the intermediate token will have no expiration time.
98-
* </ol>
219+
* made to the STS endpoint.
99220
*
221+
* @return The refreshed {@link IntermediateCredentials} containing the intermediate access token
222+
* and access boundary session key.
100223
* @throws IOException If an error occurs during credential refresh or token exchange.
101224
*/
102225
@VisibleForTesting
103-
void refreshCredentials() throws IOException {
226+
IntermediateCredentials refreshCredentials() throws IOException {
104227
try {
105228
// Force a refresh on the source credentials. The intermediate token's lifetime is tied to the
106229
// source credential's expiration. The factory's refreshMargin might be different from the
@@ -128,13 +251,18 @@ void refreshCredentials() throws IOException {
128251
.build();
129252

130253
StsTokenExchangeResponse response = handler.exchangeToken();
131-
132-
synchronized (refreshLock) {
133-
this.accessBoundarySessionKey = response.getAccessBoundarySessionKey();
134-
this.intermediateAccessToken = getTokenFromResponse(response, sourceAccessToken);
135-
}
254+
return new IntermediateCredentials(
255+
getTokenFromResponse(response, sourceAccessToken), response.getAccessBoundarySessionKey());
136256
}
137257

258+
/**
259+
* Extracts the access token from the STS exchange response and sets the appropriate expiration
260+
* time.
261+
*
262+
* @param response The STS token exchange response.
263+
* @param sourceAccessToken The original access token used for the exchange.
264+
* @return The intermediate access token.
265+
*/
138266
private static AccessToken getTokenFromResponse(
139267
StsTokenExchangeResponse response, AccessToken sourceAccessToken) {
140268
AccessToken intermediateToken = response.getAccessToken();
@@ -152,93 +280,83 @@ private static AccessToken getTokenFromResponse(
152280
return intermediateToken; // Return original if no modification needed
153281
}
154282

155-
private void startAsynchronousRefresh() {
156-
// Obtain the lock before checking or modifying currentRefreshFuture to prevent race conditions.
157-
synchronized (refreshLock) {
158-
// Only start an asynchronous refresh if one is not already in progress.
159-
if (currentRefreshFuture == null || currentRefreshFuture.isDone()) {
160-
SettableFuture<Void> future = SettableFuture.create();
161-
currentRefreshFuture = future;
162-
backgroundExecutor.execute(
163-
() -> {
164-
try {
165-
refreshCredentials();
166-
future.set(null); // Signal successful completion.
167-
} catch (Throwable t) {
168-
future.setException(t); // Set the exception if refresh fails.
169-
} finally {
170-
currentRefreshFuture = null;
171-
}
172-
});
173-
}
174-
}
175-
}
176-
177-
private void blockingRefresh() throws IOException {
178-
// Obtain the lock before checking the currentRefreshFuture to prevent race conditions.
283+
/**
284+
* Completes the refresh task by storing the results and clearing the single flight slot.
285+
*
286+
* <p>This method is called when a refresh task finishes. It stores the refreshed credentials if
287+
* successful. The single-flight "slot" is cleared, allowing subsequent refresh attempts. Any
288+
* exceptions during the refresh are caught and suppressed to prevent indefinite blocking of
289+
* subsequent refresh attempts.
290+
*/
291+
private void finishRefreshTask(ListenableFuture<IntermediateCredentials> finishedTask) {
179292
synchronized (refreshLock) {
180-
if (currentRefreshFuture != null && !currentRefreshFuture.isDone()) {
181-
try {
182-
currentRefreshFuture.get(); // Wait for the asynchronous refresh to complete.
183-
} catch (InterruptedException e) {
184-
Thread.currentThread().interrupt(); // Restore the interrupt status
185-
throw new IOException("Interrupted while waiting for asynchronous refresh.", e);
186-
} catch (ExecutionException e) {
187-
Throwable cause = e.getCause(); // Unwrap the underlying cause
188-
if (cause instanceof IOException) {
189-
throw (IOException) cause;
190-
} else {
191-
throw new IOException("Asynchronous refresh failed.", cause);
192-
}
293+
try {
294+
this.intermediateCredentials = Futures.getDone(finishedTask);
295+
} catch (Exception e) {
296+
// noop
297+
} finally {
298+
if (this.refreshTask != null && this.refreshTask.task == finishedTask) {
299+
this.refreshTask = null;
193300
}
194-
} else {
195-
// No asynchronous refresh is running, perform a synchronous refresh.
196-
refreshCredentials();
197301
}
198302
}
199303
}
200304

201305
/**
202-
* Refreshes the intermediate access token and access boundary session key if required.
306+
* Holds intermediate credentials obtained from the STS token exchange endpoint.
203307
*
204-
* <p>This method checks the expiration time of the current intermediate access token and
205-
* initiates a refresh if necessary. The refresh process also refreshes the underlying source
206-
* credentials.
207-
*
208-
* @throws IOException If an error occurs during the refresh process, such as network issues,
209-
* invalid credentials, or problems with the token exchange endpoint.
308+
* <p>These credentials include an intermediate access token and an access boundary session key.
210309
*/
211-
private void refreshCredentialsIfRequired() throws IOException {
212-
AccessToken localAccessToken = intermediateAccessToken;
213-
if (localAccessToken != null) {
214-
Date expirationTime = localAccessToken.getExpirationTime();
215-
if (expirationTime == null) {
216-
return; // Token does not expire, no refresh needed.
217-
}
310+
private static class IntermediateCredentials {
311+
private final AccessToken intermediateAccessToken;
312+
private final String accessBoundarySessionKey;
218313

219-
Duration remaining =
220-
Duration.ofMillis(expirationTime.getTime() - Clock.SYSTEM.currentTimeMillis());
221-
if (remaining.compareTo(minimumTokenLifetime) <= 0) {
222-
// Intermediate token has expired or remaining lifetime is less than the minimum required
223-
// for CAB token generation. Perform a synchronous refresh immediately.
224-
blockingRefresh();
225-
} else if (remaining.compareTo(refreshMargin) <= 0) {
226-
// The token is nearing expiration, start an asynchronous refresh in the background.
227-
startAsynchronousRefresh();
228-
}
229-
} else {
230-
// No intermediate access token exists; a synchronous refresh must be performed.
231-
blockingRefresh();
314+
IntermediateCredentials(AccessToken accessToken, String accessBoundarySessionKey) {
315+
this.intermediateAccessToken = accessToken;
316+
this.accessBoundarySessionKey = accessBoundarySessionKey;
232317
}
233318
}
234319

235-
public AccessToken generateToken(CredentialAccessBoundary accessBoundary) {
236-
// TODO(negarb/jiahuah): Implement generateToken
237-
throw new UnsupportedOperationException("generateToken is not yet implemented.");
238-
}
320+
/**
321+
* Represents a task for refreshing intermediate credentials, ensuring that only one refresh
322+
* operation is in progress at a time.
323+
*
324+
* <p>The {@code isNew} flag indicates whether this is a newly initiated refresh operation or an
325+
* existing one already in progress. This distinction is used to prevent redundant refreshes.
326+
*/
327+
class RefreshTask extends AbstractFuture<IntermediateCredentials> implements Runnable {
328+
private final ListenableFutureTask<IntermediateCredentials> task;
329+
final boolean isNew;
330+
331+
RefreshTask(ListenableFutureTask<IntermediateCredentials> task, boolean isNew) {
332+
this.task = task;
333+
this.isNew = isNew;
334+
335+
// Add listener to update factory's credentials when the task completes.
336+
task.addListener(() -> finishRefreshTask(task), MoreExecutors.directExecutor());
337+
338+
// Add callback to set the result or exception based on the outcome.
339+
Futures.addCallback(
340+
task,
341+
new FutureCallback<IntermediateCredentials>() {
342+
@Override
343+
public void onSuccess(IntermediateCredentials result) {
344+
RefreshTask.this.set(result);
345+
}
346+
347+
@Override
348+
public void onFailure(@Nullable Throwable t) {
349+
RefreshTask.this.setException(
350+
t != null ? t : new IOException("Refresh failed with null Throwable."));
351+
}
352+
},
353+
MoreExecutors.directExecutor());
354+
}
239355

240-
public static Builder newBuilder() {
241-
return new Builder();
356+
@Override
357+
public void run() {
358+
task.run();
359+
}
242360
}
243361

244362
public static class Builder {
@@ -363,12 +481,15 @@ public ClientSideCredentialAccessBoundaryFactory build() {
363481

364482
@VisibleForTesting
365483
String getAccessBoundarySessionKey() {
366-
return accessBoundarySessionKey;
484+
485+
return intermediateCredentials != null
486+
? intermediateCredentials.accessBoundarySessionKey
487+
: null;
367488
}
368489

369490
@VisibleForTesting
370491
AccessToken getIntermediateAccessToken() {
371-
return intermediateAccessToken;
492+
return intermediateCredentials != null ? intermediateCredentials.intermediateAccessToken : null;
372493
}
373494

374495
@VisibleForTesting

0 commit comments

Comments
 (0)