Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package com.google.auth.credentialaccessboundary;

import static com.google.auth.oauth2.OAuth2Credentials.getFromServiceLoader;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.auth.Credentials;
import com.google.auth.http.HttpTransportFactory;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.OAuth2Utils;
import com.google.auth.oauth2.StsRequestHandler;
import com.google.auth.oauth2.StsTokenExchangeRequest;
import com.google.auth.oauth2.StsTokenExchangeResponse;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.IOException;

public final class ClientSideCredentialAccessBoundaryFactory {
private final GoogleCredentials sourceCredential;
private final transient HttpTransportFactory transportFactory;
private final String tokenExchangeEndpoint;
private String acceessBoundarySessionKey;
private AccessToken intermediaryAccessToken;

private ClientSideCredentialAccessBoundaryFactory(Builder builder) {
this.transportFactory =
firstNonNull(
builder.transportFactory,
getFromServiceLoader(HttpTransportFactory.class, OAuth2Utils.HTTP_TRANSPORT_FACTORY));
this.sourceCredential = checkNotNull(builder.sourceCredential);

// Default to GDU when not supplied.
String universeDomain;
if (builder.universeDomain == null || builder.universeDomain.trim().isEmpty()) {
universeDomain = Credentials.GOOGLE_DEFAULT_UNIVERSE;
} else {
universeDomain = builder.universeDomain;
}

// Ensure source credential's universe domain matches.
try {
if (!universeDomain.equals(sourceCredential.getUniverseDomain())) {
throw new IllegalArgumentException(
"The client side access boundary credential's universe domain must be the same as the source "
+ "credential.");
}
} catch (IOException e) {
// Throwing an IOException would be a breaking change, so wrap it here.
throw new IllegalStateException(
"Error occurred when attempting to retrieve source credential universe domain.", e);
}
String TOKEN_EXCHANGE_URL_FORMAT = "https://sts.{universe_domain}/v1/token";
this.tokenExchangeEndpoint =
TOKEN_EXCHANGE_URL_FORMAT.replace("{universe_domain}", universeDomain);
}

public void fetchCredentials() throws IOException {
try {
this.sourceCredential.refreshIfExpired();
} catch (IOException e) {
throw new IOException("Unable to refresh the provided source credential.", e);
}

AccessToken sourceAccessToken = sourceCredential.getAccessToken();
if (sourceAccessToken == null || sourceAccessToken.getTokenValue() == null) {
throw new IOException("The source credential does not have an access token.");
}

StsTokenExchangeRequest request =
StsTokenExchangeRequest.newBuilder(
sourceAccessToken.getTokenValue(), OAuth2Utils.TOKEN_TYPE_ACCESS_TOKEN)
.setRequestTokenType(OAuth2Utils.TOKEN_TYPE_ACCESS_BOUNDARY_INTERMEDIARY_TOKEN)
.build();

StsRequestHandler handler =
StsRequestHandler.newBuilder(
tokenExchangeEndpoint, request, transportFactory.create().createRequestFactory())
.build();

StsTokenExchangeResponse response = handler.exchangeToken();
this.acceessBoundarySessionKey = response.getAccessBoundarySessionKey();
this.intermediaryAccessToken = response.getAccessToken();

// The STS endpoint will only return the expiration time for the intermediary token
// if the original access token represents a service account.
// The intermediary token's expiration time will always match the source credential expiration.
// When no expires_in is returned, we can copy the source credential's expiration time.
if (response.getAccessToken().getExpirationTime() == null) {
if (sourceAccessToken.getExpirationTime() != null) {
this.intermediaryAccessToken =
new AccessToken(
response.getAccessToken().getTokenValue(), sourceAccessToken.getExpirationTime());
}
}
}

public static Builder newBuilder() {
return new Builder();
}

public static class Builder {
private GoogleCredentials sourceCredential;
private HttpTransportFactory transportFactory;
private String universeDomain;

private Builder() {}

/**
* Sets the required source credential used to acquire the intermediary credential.
*
* @param sourceCredential the {@code GoogleCredentials} to set
* @return this {@code Builder} object
*/
public Builder setSourceCredential(GoogleCredentials sourceCredential) {
this.sourceCredential = sourceCredential;
return this;
}

/**
* Sets the HTTP transport factory.
*
* @param transportFactory the {@code HttpTransportFactory} to set
* @return this {@code Builder} object
*/
@CanIgnoreReturnValue
public Builder setHttpTransportFactory(HttpTransportFactory transportFactory) {
this.transportFactory = transportFactory;
return this;
}

/**
* Sets the optional universe domain.
*
* @param universeDomain the universe domain to set
* @return this {@code Builder} object
*/
@CanIgnoreReturnValue
public Builder setUniverseDomain(String universeDomain) {
this.universeDomain = universeDomain;
return this;
}

public ClientSideCredentialAccessBoundaryFactory build() {
return new ClientSideCredentialAccessBoundaryFactory(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public final class CredentialAccessBoundary {
/**
* Internal method that returns the JSON string representation of the credential access boundary.
*/
String toJson() {
public String toJson() {
List<GenericJson> rules = new ArrayList<>();
for (AccessBoundaryRule rule : accessBoundaryRules) {
GenericJson ruleJson = new GenericJson();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ protected static <T> T newInstance(String className) throws IOException, ClassNo
}
}

protected static <T> T getFromServiceLoader(Class<? extends T> clazz, T defaultInstance) {
public static <T> T getFromServiceLoader(Class<? extends T> clazz, T defaultInstance) {
return Iterables.getFirst(ServiceLoader.load(clazz), defaultInstance);
}

Expand Down
11 changes: 8 additions & 3 deletions oauth2_http/java/com/google/auth/oauth2/OAuth2Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,15 @@
import java.util.Set;

/** Internal utilities for the com.google.auth.oauth2 namespace. */
class OAuth2Utils {
public class OAuth2Utils {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it makes sense for this to be public @lqiu96

Copy link
Member

@lqiu96 lqiu96 Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One option for this to see if we can keep this package-private and move some of these constants to STSRequestHandler. It looks like TOKEN_TYPE_ACCESS_BOUNDARY_INTERMEDIARY_TOKEN and TOKEN_EXCHANGE_URL_FORMAT are new additions could live in the individual classes (i.e StsREquestHandler or DownscopedCredentials).

TOKEN_TYPE_ACCESS_TOKEN is used in a few places, but if it makes sense we can move them around. I don't know if TOKEN_TYPE_ACCESS_TOKEN` applies to non-CAB use cases/ if we rather keep it in this Utils class.

Otherwise, I think I'm ok with making it public given the module constraints we have.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to leave this comment open and keep it as-is for now so I can merge this PR and unblock #1571. I'll address this comment in the next PR.


static final String SIGNATURE_ALGORITHM = "SHA256withRSA";

static final String TOKEN_TYPE_ACCESS_TOKEN = "urn:ietf:params:oauth:token-type:access_token";
public static final String TOKEN_TYPE_ACCESS_TOKEN =
"urn:ietf:params:oauth:token-type:access_token";
static final String TOKEN_TYPE_TOKEN_EXCHANGE = "urn:ietf:params:oauth:token-type:token-exchange";
public static final String TOKEN_TYPE_ACCESS_BOUNDARY_INTERMEDIARY_TOKEN =
"urn:ietf:params:oauth:token-type:access_boundary_intermediary_token";
static final String GRANT_TYPE_JWT_BEARER = "urn:ietf:params:oauth:grant-type:jwt-bearer";

// generateIdToken endpoint is to be formatted with universe domain and client email
Expand All @@ -93,7 +97,8 @@ class OAuth2Utils {

static final HttpTransport HTTP_TRANSPORT = new NetHttpTransport();

static final HttpTransportFactory HTTP_TRANSPORT_FACTORY = new DefaultHttpTransportFactory();
public static final HttpTransportFactory HTTP_TRANSPORT_FACTORY =
new DefaultHttpTransportFactory();

static final JsonFactory JSON_FACTORY = GsonFactory.getDefaultInstance();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
import javax.annotation.Nullable;

/** Implements the OAuth 2.0 token exchange based on https://tools.ietf.org/html/rfc8693. */
final class StsRequestHandler {
public final class StsRequestHandler {
private static final String TOKEN_EXCHANGE_GRANT_TYPE =
"urn:ietf:params:oauth:grant-type:token-exchange";
private static final String PARSE_ERROR_PREFIX = "Error parsing token response.";
Expand Down Expand Up @@ -175,6 +175,11 @@ private StsTokenExchangeResponse buildResponse(GenericData responseData) throws
String scope = OAuth2Utils.validateString(responseData, "scope", PARSE_ERROR_PREFIX);
builder.setScopes(Arrays.asList(scope.trim().split("\\s+")));
}
if (responseData.containsKey("access_boundary_session_key")) {
builder.setAccessBoundarySessionKey(
OAuth2Utils.validateString(
responseData, "access_boundary_session_key", PARSE_ERROR_PREFIX));
}
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
* Defines an OAuth 2.0 token exchange request. Based on
* https://tools.ietf.org/html/rfc8693#section-2.1.
*/
final class StsTokenExchangeRequest {
public final class StsTokenExchangeRequest {
private static final String GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange";

private final String subjectToken;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,24 @@
* Defines an OAuth 2.0 token exchange successful response. Based on
* https://tools.ietf.org/html/rfc8693#section-2.2.1.
*/
final class StsTokenExchangeResponse {
public final class StsTokenExchangeResponse {
private final AccessToken accessToken;
private final String issuedTokenType;
private final String tokenType;

@Nullable private final Long expiresInSeconds;
@Nullable private final String refreshToken;
@Nullable private final List<String> scopes;
@Nullable private final String accessBoundarySessionKey;

private StsTokenExchangeResponse(
String accessToken,
String issuedTokenType,
String tokenType,
@Nullable Long expiresInSeconds,
@Nullable String refreshToken,
@Nullable List<String> scopes) {
@Nullable List<String> scopes,
@Nullable String accessBoundarySessionKey) {
checkNotNull(accessToken);

this.expiresInSeconds = expiresInSeconds;
Expand All @@ -71,6 +73,7 @@ private StsTokenExchangeResponse(
this.tokenType = checkNotNull(tokenType);
this.refreshToken = refreshToken;
this.scopes = scopes;
this.accessBoundarySessionKey = accessBoundarySessionKey;
}

public static Builder newBuilder(String accessToken, String issuedTokenType, String tokenType) {
Expand Down Expand Up @@ -107,6 +110,11 @@ public List<String> getScopes() {
return new ArrayList<>(scopes);
}

@Nullable
public String getAccessBoundarySessionKey() {
return accessBoundarySessionKey;
}

public static class Builder {
private final String accessToken;
private final String issuedTokenType;
Expand All @@ -115,6 +123,7 @@ public static class Builder {
@Nullable private Long expiresInSeconds;
@Nullable private String refreshToken;
@Nullable private List<String> scopes;
@Nullable private String accessBoundarySessionKey;

private Builder(String accessToken, String issuedTokenType, String tokenType) {
this.accessToken = accessToken;
Expand Down Expand Up @@ -142,9 +151,22 @@ public StsTokenExchangeResponse.Builder setScopes(List<String> scopes) {
return this;
}

@CanIgnoreReturnValue
public StsTokenExchangeResponse.Builder setAccessBoundarySessionKey(
String accessBoundarySessionKey) {
this.accessBoundarySessionKey = accessBoundarySessionKey;
return this;
}

public StsTokenExchangeResponse build() {
return new StsTokenExchangeResponse(
accessToken, issuedTokenType, tokenType, expiresInSeconds, refreshToken, scopes);
accessToken,
issuedTokenType,
tokenType,
expiresInSeconds,
refreshToken,
scopes,
accessBoundarySessionKey);
}
}
}