Skip to content

Commit 9771ab1

Browse files
committed
Add unit tests for refreshCredentialsIfRequired.
1 parent ddfa7d8 commit 9771ab1

File tree

4 files changed

+319
-38
lines changed

4 files changed

+319
-38
lines changed

cab-token-generator/java/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactory.java

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,18 @@
6060
import java.util.concurrent.ExecutionException;
6161
import javax.annotation.Nullable;
6262

63-
public final class ClientSideCredentialAccessBoundaryFactory {
63+
public class ClientSideCredentialAccessBoundaryFactory {
64+
static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(30);
65+
static final Duration DEFAULT_MINIMUM_TOKEN_LIFETIME = Duration.ofMinutes(3);
6466
private final GoogleCredentials sourceCredential;
6567
private final transient HttpTransportFactory transportFactory;
6668
private final String tokenExchangeEndpoint;
6769
private final Duration minimumTokenLifetime;
6870
private final Duration refreshMargin;
69-
private static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(30);
70-
private static final Duration DEFAULT_MINIMUM_TOKEN_LIFETIME = Duration.ofMinutes(3);
7171
private transient RefreshTask refreshTask;
7272
private final Object refreshLock = new byte[0];
7373
private volatile IntermediateCredentials intermediateCredentials = null;
74+
private final Clock clock;
7475

7576
enum RefreshType {
7677
NONE,
@@ -88,10 +89,7 @@ private ClientSideCredentialAccessBoundaryFactory(Builder builder) {
8889
builder.minimumTokenLifetime != null
8990
? builder.minimumTokenLifetime
9091
: DEFAULT_MINIMUM_TOKEN_LIFETIME;
91-
}
92-
93-
public static Builder newBuilder() {
94-
return new Builder();
92+
this.clock = builder.clock;
9593
}
9694

9795
public AccessToken generateToken(CredentialAccessBoundary accessBoundary) {
@@ -110,7 +108,8 @@ public AccessToken generateToken(CredentialAccessBoundary accessBoundary) {
110108
* @throws IOException If an error occurs during the refresh process, such as network issues,
111109
* invalid credentials, or problems with the token exchange endpoint.
112110
*/
113-
private void refreshCredentialsIfRequired() throws IOException {
111+
@VisibleForTesting
112+
void refreshCredentialsIfRequired() throws IOException {
114113
RefreshType refreshType = determineRefreshType();
115114

116115
if (refreshType == RefreshType.NONE) {
@@ -171,14 +170,14 @@ private RefreshType determineRefreshType() {
171170
return RefreshType.BLOCKING;
172171
}
173172

174-
AccessToken localAccessToken = intermediateCredentials.intermediateAccessToken;
175-
Date expirationTime = localAccessToken.getExpirationTime();
173+
AccessToken intermediateAccessToken = intermediateCredentials.intermediateAccessToken;
174+
Date expirationTime = intermediateAccessToken.getExpirationTime();
176175
if (expirationTime == null) {
177176
return RefreshType.NONE; // Token does not expire, no refresh needed.
178177
}
179178

180-
Duration remaining =
181-
Duration.ofMillis(expirationTime.getTime() - Clock.SYSTEM.currentTimeMillis());
179+
Duration remaining = Duration.ofMillis(expirationTime.getTime() - clock.currentTimeMillis());
180+
182181
if (remaining.compareTo(minimumTokenLifetime) <= 0) {
183182
// Intermediate token has expired or remaining lifetime is less than the minimum required
184183
// for CAB token generation. A blocking refresh is necessary.
@@ -208,26 +207,31 @@ private RefreshTask getOrCreateRefreshTask() {
208207
return new RefreshTask(refreshTask.task, false);
209208
}
210209

211-
// No refresh task is currently running. Create and return a new refresh task.
212210
final ListenableFutureTask<IntermediateCredentials> task =
213-
ListenableFutureTask.create(this::refreshCredentials);
214-
return new RefreshTask(task, true);
211+
ListenableFutureTask.create(this::fetchIntermediateCredentials);
212+
213+
// Store the new refresh task in the refreshTask field before returning. This ensures that
214+
// subsequent calls to this method will return the existing task while it's still in progress.
215+
refreshTask = new RefreshTask(task, true);
216+
return refreshTask;
215217
}
216218
}
217219

218220
/**
219-
* Refreshes the source credential and exchanges it for an intermediate access token using the STS
220-
* endpoint.
221+
* Fetches the credentials by refreshing the source credential and exchanging it for an
222+
* intermediate access token using the STS endpoint.
221223
*
222-
* <p>If the source credential is expired, it will be refreshed. A token exchange request is then
223-
* made to the STS endpoint.
224+
* <p>The source credential is refreshed, and a token exchange request is made to the STS endpoint
225+
* to obtain an intermediate access token and an associated access boundary session key. This
226+
* ensures the intermediate access token meets this factory's refresh margin and minimum lifetime
227+
* requirements.
224228
*
225-
* @return The refreshed {@link IntermediateCredentials} containing the intermediate access token
229+
* @return The fetched {@link IntermediateCredentials} containing the intermediate access token
226230
* and access boundary session key.
227231
* @throws IOException If an error occurs during credential refresh or token exchange.
228232
*/
229233
@VisibleForTesting
230-
IntermediateCredentials refreshCredentials() throws IOException {
234+
IntermediateCredentials fetchIntermediateCredentials() throws IOException {
231235
try {
232236
// Force a refresh on the source credentials. The intermediate token's lifetime is tied to the
233237
// source credential's expiration. The factory's refreshMargin might be different from the
@@ -306,6 +310,28 @@ private void finishRefreshTask(ListenableFuture<IntermediateCredentials> finishe
306310
}
307311
}
308312

313+
@VisibleForTesting
314+
String getAccessBoundarySessionKey() {
315+
return intermediateCredentials != null
316+
? intermediateCredentials.accessBoundarySessionKey
317+
: null;
318+
}
319+
320+
@VisibleForTesting
321+
AccessToken getIntermediateAccessToken() {
322+
return intermediateCredentials != null ? intermediateCredentials.intermediateAccessToken : null;
323+
}
324+
325+
@VisibleForTesting
326+
String getTokenExchangeEndpoint() {
327+
return tokenExchangeEndpoint;
328+
}
329+
330+
@VisibleForTesting
331+
HttpTransportFactory getTransportFactory() {
332+
return transportFactory;
333+
}
334+
309335
/**
310336
* Holds intermediate credentials obtained from the STS token exchange endpoint.
311337
*
@@ -372,13 +398,18 @@ public void run() {
372398
}
373399
}
374400

401+
public static Builder newBuilder() {
402+
return new Builder();
403+
}
404+
375405
public static class Builder {
376406
private GoogleCredentials sourceCredential;
377407
private HttpTransportFactory transportFactory;
378408
private String universeDomain;
379409
private String tokenExchangeEndpoint;
380410
private Duration minimumTokenLifetime;
381411
private Duration refreshMargin;
412+
private Clock clock = Clock.SYSTEM; // Default to system clock;
382413

383414
private Builder() {}
384415

@@ -403,14 +434,15 @@ public Builder setSourceCredential(GoogleCredentials sourceCredential) {
403434
* ensures that generated CAB tokens have a sufficient lifetime for use.
404435
*
405436
* @param minimumTokenLifetime The minimum acceptable lifetime for a generated CAB token. Must
406-
* be positive.
437+
* be greater than zero.
407438
* @return This {@code Builder} object.
408439
* @throws IllegalArgumentException if minimumTokenLifetime is negative or zero.
409440
*/
410441
@CanIgnoreReturnValue
411442
public Builder setMinimumTokenLifetime(Duration minimumTokenLifetime) {
443+
checkNotNull(minimumTokenLifetime, "Minimum token lifetime must not be null.");
412444
if (minimumTokenLifetime.isNegative() || minimumTokenLifetime.isZero()) {
413-
throw new IllegalArgumentException("Minimum token lifetime must be positive.");
445+
throw new IllegalArgumentException("Minimum token lifetime must be greater than zero.");
414446
}
415447
this.minimumTokenLifetime = minimumTokenLifetime;
416448
return this;
@@ -423,14 +455,15 @@ public Builder setMinimumTokenLifetime(Duration minimumTokenLifetime) {
423455
* time an asynchronous refresh should be initiated. If not provided, it will default to 30
424456
* minutes.
425457
*
426-
* @param refreshMargin The refresh margin. Must be positive.
458+
* @param refreshMargin The refresh margin. Must be greater than zero.
427459
* @return This {@code Builder} object.
428460
* @throws IllegalArgumentException if refreshMargin is negative or zero.
429461
*/
430462
@CanIgnoreReturnValue
431463
public Builder setRefreshMargin(Duration refreshMargin) {
464+
checkNotNull(refreshMargin, "Refresh margin must not be null.");
432465
if (refreshMargin.isNegative() || refreshMargin.isZero()) {
433-
throw new IllegalArgumentException("Refresh margin must be positive.");
466+
throw new IllegalArgumentException("Refresh margin must be greater than zero.");
434467
}
435468
this.refreshMargin = refreshMargin;
436469
return this;
@@ -460,6 +493,17 @@ public Builder setUniverseDomain(String universeDomain) {
460493
return this;
461494
}
462495

496+
/**
497+
* Set the clock for checking token expiry. Used for testing.
498+
*
499+
* @param clock the clock to use. Defaults to the system clock
500+
* @return the builder
501+
*/
502+
public Builder setClock(Clock clock) {
503+
this.clock = clock;
504+
return this;
505+
}
506+
463507
public ClientSideCredentialAccessBoundaryFactory build() {
464508
checkNotNull(sourceCredential, "Source credential must not be null.");
465509

0 commit comments

Comments
 (0)