Skip to content

Commit eb4b793

Browse files
committed
feat: Implement refreshCredentialsIfRequired for intermediate token refresh
Implement `refreshCredentialsIfRequired`, called by `generateToken()`, to handle token refresh. It uses `refreshMargin` and `minimumTokenLifetime` to decide on synchronous or asynchronous refresh
1 parent 5d32642 commit eb4b793

File tree

1 file changed

+169
-17
lines changed

1 file changed

+169
-17
lines changed

cab-token-generator/java/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactory.java

Lines changed: 169 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import static com.google.auth.oauth2.OAuth2Utils.TOKEN_EXCHANGE_URL_FORMAT;
3636
import static com.google.common.base.Preconditions.checkNotNull;
3737

38+
import com.google.api.client.util.Clock;
3839
import com.google.auth.Credentials;
3940
import com.google.auth.http.HttpTransportFactory;
4041
import com.google.auth.oauth2.AccessToken;
@@ -45,41 +46,65 @@
4546
import com.google.auth.oauth2.StsTokenExchangeRequest;
4647
import com.google.auth.oauth2.StsTokenExchangeResponse;
4748
import com.google.common.base.Strings;
49+
import com.google.common.util.concurrent.SettableFuture;
4850
import com.google.errorprone.annotations.CanIgnoreReturnValue;
4951
import java.io.IOException;
52+
import java.time.Duration;
53+
import java.util.Date;
54+
import java.util.concurrent.ExecutionException;
55+
import java.util.concurrent.ExecutorService;
56+
import java.util.concurrent.Executors;
57+
import javax.annotation.Nullable;
5058

5159
public final class ClientSideCredentialAccessBoundaryFactory {
5260
private final GoogleCredentials sourceCredential;
5361
private final transient HttpTransportFactory transportFactory;
5462
private final String tokenExchangeEndpoint;
5563
private String accessBoundarySessionKey;
5664
private AccessToken intermediateAccessToken;
65+
private final Duration minimumTokenLifetime;
66+
private final Duration refreshMargin;
67+
private static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(30);
68+
private static final Duration DEFAULT_MINIMUM_TOKEN_LIFETIME = Duration.ofMinutes(3);
69+
private final Object refreshLock = new Object[0]; // Lock for refresh operations
70+
@Nullable private SettableFuture<Void> currentRefreshFuture;
71+
private final ExecutorService backgroundExecutor = Executors.newSingleThreadExecutor();
5772

5873
private ClientSideCredentialAccessBoundaryFactory(Builder builder) {
5974
this.transportFactory = builder.transportFactory;
6075
this.sourceCredential = builder.sourceCredential;
6176
this.tokenExchangeEndpoint = builder.tokenExchangeEndpoint;
77+
this.refreshMargin =
78+
builder.refreshMargin != null ? builder.refreshMargin : DEFAULT_REFRESH_MARGIN;
79+
this.minimumTokenLifetime =
80+
builder.minimumTokenLifetime != null
81+
? builder.minimumTokenLifetime
82+
: DEFAULT_MINIMUM_TOKEN_LIFETIME;
6283
}
6384

6485
/**
65-
* Refreshes the source credential and exchanges it for an intermediary access token using the STS
86+
* Refreshes the source credential and exchanges it for an intermediate access token using the STS
6687
* endpoint.
6788
*
6889
* <p>If the source credential is expired, it will be refreshed. A token exchange request is then
69-
* made to the STS endpoint. The resulting intermediary access token and access boundary session
70-
* key are stored. The intermediary access token's expiration time is determined as follows:
90+
* made to the STS endpoint. The resulting intermediate access token and access boundary session
91+
* key are stored. The intermediate access token's expiration time is determined as follows:
7192
*
7293
* <ol>
7394
* <li>If the STS response includes `expires_in`, that value is used.
7495
* <li>Otherwise, if the source credential has an expiration time, that value is used.
75-
* <li>Otherwise, the intermediary token will have no expiration time.
96+
* <li>Otherwise, the intermediate token will have no expiration time.
7697
* </ol>
7798
*
7899
* @throws IOException If an error occurs during credential refresh or token exchange.
79100
*/
80101
private void refreshCredentials() throws IOException {
81102
try {
82-
this.sourceCredential.refreshIfExpired();
103+
// Force a refresh on the source credentials. The intermediate token's lifetime is tied to the
104+
// source credential's expiration. The factory's refreshMargin might be different from the
105+
// refreshMargin on source credentials. This ensures the intermediate access token
106+
// meets this factory's refresh margin and minimum lifetime requirements.
107+
sourceCredential.refresh();
83108
} catch (IOException e) {
84109
throw new IOException("Unable to refresh the provided source credential.", e);
85110
}
@@ -101,25 +126,108 @@ private void refreshCredentials() throws IOException {
101126
.build();
102127

103128
StsTokenExchangeResponse response = handler.exchangeToken();
104-
this.accessBoundarySessionKey = response.getAccessBoundarySessionKey();
105-
this.intermediateAccessToken = response.getAccessToken();
106129

107-
// The STS endpoint will only return the expiration time for the intermediary token
130+
synchronized (refreshLock) {
131+
this.accessBoundarySessionKey = response.getAccessBoundarySessionKey();
132+
this.intermediateAccessToken = getTokenFromResponse(response, sourceAccessToken);
133+
}
134+
}
135+
136+
private static AccessToken getTokenFromResponse(
137+
StsTokenExchangeResponse response, AccessToken sourceAccessToken) {
138+
AccessToken intermediateToken = response.getAccessToken();
139+
140+
// The STS endpoint will only return the expiration time for the intermediate token
108141
// if the original access token represents a service account.
109-
// The intermediary token's expiration time will always match the source credential expiration.
142+
// The intermediate token's expiration time will always match the source credential
143+
// expiration.
110144
// When no expires_in is returned, we can copy the source credential's expiration time.
111-
if (response.getAccessToken().getExpirationTime() == null) {
112-
if (sourceAccessToken.getExpirationTime() != null) {
113-
this.intermediateAccessToken =
114-
new AccessToken(
115-
response.getAccessToken().getTokenValue(), sourceAccessToken.getExpirationTime());
145+
if (intermediateToken.getExpirationTime() == null
146+
&& sourceAccessToken.getExpirationTime() != null) {
147+
return new AccessToken(
148+
intermediateToken.getTokenValue(), sourceAccessToken.getExpirationTime());
149+
}
150+
return intermediateToken; // Return original if no modification needed
151+
}
152+
153+
private void startAsynchronousRefresh() {
154+
// Obtain the lock before checking or modifying currentRefreshFuture to prevent race conditions.
155+
synchronized (refreshLock) {
156+
// Only start an asynchronous refresh if one is not already in progress.
157+
if (currentRefreshFuture == null || currentRefreshFuture.isDone()) {
158+
SettableFuture<Void> future = SettableFuture.create();
159+
currentRefreshFuture = future;
160+
backgroundExecutor.execute(
161+
() -> {
162+
try {
163+
refreshCredentials();
164+
future.set(null); // Signal successful completion.
165+
} catch (Throwable t) {
166+
future.setException(t); // Set the exception if refresh fails.
167+
} finally {
168+
currentRefreshFuture = null;
169+
}
170+
});
116171
}
117172
}
118173
}
119174

120-
private void refreshCredentialsIfRequired() {
121-
// TODO(negarb): Implement refreshCredentialsIfRequired
122-
throw new UnsupportedOperationException("refreshCredentialsIfRequired is not yet implemented.");
175+
private void blockingRefresh() throws IOException {
176+
// Obtain the lock before checking the currentRefreshFuture to prevent race conditions.
177+
synchronized (refreshLock) {
178+
if (currentRefreshFuture != null && !currentRefreshFuture.isDone()) {
179+
try {
180+
currentRefreshFuture.get(); // Wait for the asynchronous refresh to complete.
181+
} catch (InterruptedException e) {
182+
Thread.currentThread().interrupt(); // Restore the interrupt status
183+
throw new IOException("Interrupted while waiting for asynchronous refresh.", e);
184+
} catch (ExecutionException e) {
185+
Throwable cause = e.getCause(); // Unwrap the underlying cause
186+
if (cause instanceof IOException) {
187+
throw (IOException) cause;
188+
} else {
189+
throw new IOException("Asynchronous refresh failed.", cause);
190+
}
191+
}
192+
} else {
193+
// No asynchronous refresh is running, perform a synchronous refresh.
194+
refreshCredentials();
195+
}
196+
}
197+
}
198+
199+
/**
200+
* Refreshes the intermediate access token and access boundary session key if required.
201+
*
202+
* <p>This method checks the expiration time of the current intermediate access token and
203+
* initiates a refresh if necessary. The refresh process also refreshes the underlying source
204+
* credentials.
205+
*
206+
* @throws IOException If an error occurs during the refresh process, such as network issues,
207+
* invalid credentials, or problems with the token exchange endpoint.
208+
*/
209+
private void refreshCredentialsIfRequired() throws IOException {
210+
AccessToken localAccessToken = intermediateAccessToken;
211+
if (localAccessToken != null) {
212+
Date expirationTime = localAccessToken.getExpirationTime();
213+
if (expirationTime == null) {
214+
return; // Token does not expire, no refresh needed.
215+
}
216+
217+
Duration remaining =
218+
Duration.ofMillis(expirationTime.getTime() - Clock.SYSTEM.currentTimeMillis());
219+
if (remaining.compareTo(minimumTokenLifetime) <= 0) {
220+
// Intermediate token has expired or remaining lifetime is less than the minimum required
221+
// for CAB token generation. Perform a synchronous refresh immediately.
222+
blockingRefresh();
223+
} else if (remaining.compareTo(refreshMargin) <= 0) {
224+
// The token is nearing expiration, start an asynchronous refresh in the background.
225+
startAsynchronousRefresh();
226+
}
227+
} else {
228+
// No intermediate access token exists; a synchronous refresh must be performed.
229+
blockingRefresh();
230+
}
123231
}
124232

125233
public AccessToken generateToken(CredentialAccessBoundary accessBoundary) {
@@ -136,6 +244,8 @@ public static class Builder {
136244
private HttpTransportFactory transportFactory;
137245
private String universeDomain;
138246
private String tokenExchangeEndpoint;
247+
private Duration minimumTokenLifetime;
248+
private Duration refreshMargin;
139249

140250
private Builder() {}
141251

@@ -151,6 +261,48 @@ public Builder setSourceCredential(GoogleCredentials sourceCredential) {
151261
return this;
152262
}
153263

264+
/**
265+
* Sets the minimum acceptable lifetime for a generated CAB token.
266+
*
267+
* <p>This value determines the minimum remaining lifetime required on the intermediate token
268+
* before a CAB token can be generated. If the intermediate token's remaining lifetime is less
269+
* than this value, CAB token generation will be blocked and a refresh will be initiated. This
270+
* ensures that generated CAB tokens have a sufficient lifetime for use.
271+
*
272+
* @param minimumTokenLifetime The minimum acceptable lifetime for a generated CAB token. Must
273+
* be positive.
274+
* @return This {@code Builder} object.
275+
* @throws IllegalArgumentException if minimumTokenLifetime is negative or zero.
276+
*/
277+
@CanIgnoreReturnValue
278+
public Builder setMinimumTokenLifetime(Duration minimumTokenLifetime) {
279+
if (minimumTokenLifetime.isNegative() || minimumTokenLifetime.isZero()) {
280+
throw new IllegalArgumentException("Minimum token lifetime must be positive.");
281+
}
282+
this.minimumTokenLifetime = minimumTokenLifetime;
283+
return this;
284+
}
285+
286+
/**
287+
* Sets the refresh margin for the intermediate access token.
288+
*
289+
* <p>This duration specifies how far in advance of the intermediate access token's expiration
290+
* time an asynchronous refresh should be initiated. If not provided, it will default to 30
291+
* minutes.
292+
*
293+
* @param refreshMargin The refresh margin. Must be positive.
294+
* @return This {@code Builder} object.
295+
* @throws IllegalArgumentException if refreshMargin is negative or zero.
296+
*/
297+
@CanIgnoreReturnValue
298+
public Builder setRefreshMargin(Duration refreshMargin) {
299+
if (refreshMargin.isNegative() || refreshMargin.isZero()) {
300+
throw new IllegalArgumentException("Refresh margin must be positive.");
301+
}
302+
this.refreshMargin = refreshMargin;
303+
return this;
304+
}
305+
154306
/**
155307
* Sets the HTTP transport factory.
156308
*

0 commit comments

Comments
 (0)