Skip to content

Commit 55cc9ed

Browse files
[FEATURE] Adding token cache to u2m oauth (#429)
## What changes are proposed in this pull request? **What**: Introduce a token cache to be used in browser based u2m oauth **Why**: Makes user experience better by using valid access/refresh tokens from the cache instead of requiring a new browser auth ## How is this tested? In addition to the unit tests, Manual testing: - ran example program without the cache, performed browser action and saw that a cache was created - reconnected the program and program worked without the need for browser action - changed the expiry in cache to past time and observed that token was refreshed - changed the refresh token in cache to be garbage, saw refresh failing and performed browser action
1 parent cf604c1 commit 55cc9ed

File tree

15 files changed

+826
-26
lines changed

15 files changed

+826
-26
lines changed

NEXT_CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
## Release v0.46.0
44

55
### New Features and Improvements
6-
6+
* Added `TokenCache` to `ExternalBrowserCredentialsProvider` to reduce number of authentications needed for U2M OAuth.
7+
78
### Bug Fixes
89

910
### Documentation

databricks-sdk-java/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,11 @@
9797
<artifactId>google-auth-library-oauth2-http</artifactId>
9898
<version>1.20.0</version>
9999
</dependency>
100+
<!-- Jackson JSR310 module needed to serialize/deserialize java.time classes in TokenCache -->
101+
<dependency>
102+
<groupId>com.fasterxml.jackson.datatype</groupId>
103+
<artifactId>jackson-datatype-jsr310</artifactId>
104+
<version>${jackson.version}</version>
105+
</dependency>
100106
</dependencies>
101107
</project>

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,4 +669,14 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
669669
"headerFactory"));
670670
return clone(fieldsToSkip).setHost(host);
671671
}
672+
673+
/**
674+
* Gets the default OAuth redirect URL. If one is not provided explicitly, uses
675+
* http://localhost:8080/callback
676+
*
677+
* @return The OAuth redirect URL to use
678+
*/
679+
public String getEffectiveOAuthRedirectUrl() {
680+
return redirectUrl != null ? redirectUrl : "http://localhost:8080/callback";
681+
}
672682
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,38 @@
55
import com.databricks.sdk.core.DatabricksException;
66
import com.databricks.sdk.core.HeaderFactory;
77
import java.io.IOException;
8+
import java.nio.file.Path;
9+
import java.util.Objects;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
812

913
/**
1014
* A {@code CredentialsProvider} which implements the Authorization Code + PKCE flow by opening a
11-
* browser for the user to authorize the application.
15+
* browser for the user to authorize the application. Uses a specified TokenCache or creates a
16+
* default one if none is provided.
1217
*/
1318
public class ExternalBrowserCredentialsProvider implements CredentialsProvider {
19+
private static final Logger LOGGER =
20+
LoggerFactory.getLogger(ExternalBrowserCredentialsProvider.class);
21+
22+
private TokenCache tokenCache;
23+
24+
/**
25+
* Creates a new ExternalBrowserCredentialsProvider with the specified TokenCache.
26+
*
27+
* @param tokenCache the TokenCache to use for caching tokens
28+
*/
29+
public ExternalBrowserCredentialsProvider(TokenCache tokenCache) {
30+
this.tokenCache = tokenCache;
31+
}
32+
33+
/**
34+
* Creates a new ExternalBrowserCredentialsProvider with a default TokenCache. A FileTokenCache
35+
* will be created when credentials are configured.
36+
*/
37+
public ExternalBrowserCredentialsProvider() {
38+
this(null);
39+
}
1440

1541
@Override
1642
public String authType() {
@@ -19,16 +45,87 @@ public String authType() {
1945

2046
@Override
2147
public HeaderFactory configure(DatabricksConfig config) {
22-
if (config.getHost() == null || config.getAuthType() != "external-browser") {
48+
if (config.getHost() == null || !Objects.equals(config.getAuthType(), "external-browser")) {
2349
return null;
2450
}
51+
52+
// Use the utility class to resolve client ID and client secret
53+
String clientId = OAuthClientUtils.resolveClientId(config);
54+
String clientSecret = OAuthClientUtils.resolveClientSecret(config);
55+
2556
try {
26-
OAuthClient client = new OAuthClient(config);
27-
Consent consent = client.initiateConsent();
28-
SessionCredentials creds = consent.launchExternalBrowser();
29-
return creds.configure(config);
57+
if (tokenCache == null) {
58+
// Create a default FileTokenCache based on config
59+
Path cachePath =
60+
TokenCacheUtils.getCacheFilePath(config.getHost(), clientId, config.getScopes());
61+
tokenCache = new FileTokenCache(cachePath);
62+
}
63+
64+
// First try to use the cached token if available (will return null if disabled)
65+
Token cachedToken = tokenCache.load();
66+
if (cachedToken != null && cachedToken.getRefreshToken() != null) {
67+
LOGGER.debug("Found cached token for {}:{}", config.getHost(), clientId);
68+
69+
try {
70+
// Create SessionCredentials with the cached token and try to refresh if needed
71+
SessionCredentials cachedCreds =
72+
new SessionCredentials.Builder()
73+
.withToken(cachedToken)
74+
.withHttpClient(config.getHttpClient())
75+
.withClientId(clientId)
76+
.withClientSecret(clientSecret)
77+
.withTokenUrl(config.getOidcEndpoints().getTokenEndpoint())
78+
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
79+
.withTokenCache(tokenCache)
80+
.build();
81+
82+
LOGGER.debug("Using cached token, will immediately refresh");
83+
cachedCreds.token = cachedCreds.refresh();
84+
return cachedCreds.configure(config);
85+
} catch (Exception e) {
86+
// If token refresh fails, log and continue to browser auth
87+
LOGGER.info("Token refresh failed: {}, falling back to browser auth", e.getMessage());
88+
}
89+
}
90+
91+
// If no cached token or refresh failed, perform browser auth
92+
SessionCredentials credentials =
93+
performBrowserAuth(config, clientId, clientSecret, tokenCache);
94+
tokenCache.save(credentials.getToken());
95+
return credentials.configure(config);
3096
} catch (IOException | DatabricksException e) {
97+
LOGGER.error("Failed to authenticate: {}", e.getMessage());
3198
return null;
3299
}
33100
}
101+
102+
SessionCredentials performBrowserAuth(
103+
DatabricksConfig config, String clientId, String clientSecret, TokenCache tokenCache)
104+
throws IOException {
105+
LOGGER.debug("Performing browser authentication");
106+
OAuthClient client =
107+
new OAuthClient.Builder()
108+
.withHttpClient(config.getHttpClient())
109+
.withClientId(clientId)
110+
.withClientSecret(clientSecret)
111+
.withHost(config.getHost())
112+
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
113+
.withScopes(config.getScopes())
114+
.build();
115+
Consent consent = client.initiateConsent();
116+
117+
// Use the existing browser flow to get credentials
118+
SessionCredentials credentials = consent.launchExternalBrowser();
119+
120+
// Create a new SessionCredentials with the same token but with our token cache
121+
return new SessionCredentials.Builder()
122+
.withToken(credentials.getToken())
123+
.withHttpClient(config.getHttpClient())
124+
.withClientId(config.getClientId())
125+
.withClientSecret(config.getClientSecret())
126+
.withTokenUrl(config.getOidcEndpoints().getTokenEndpoint())
127+
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
128+
.withTokenCache(tokenCache)
129+
.build();
130+
}
34131
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.utils.SerDeUtils;
4+
import com.fasterxml.jackson.databind.ObjectMapper;
5+
import java.io.File;
6+
import java.nio.charset.StandardCharsets;
7+
import java.nio.file.Files;
8+
import java.nio.file.Path;
9+
import java.util.Objects;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
12+
13+
/** A TokenCache implementation that stores tokens as plain files. */
14+
public class FileTokenCache implements TokenCache {
15+
private static final Logger LOGGER = LoggerFactory.getLogger(FileTokenCache.class);
16+
17+
private final Path cacheFile;
18+
private final ObjectMapper mapper;
19+
20+
/**
21+
* Constructs a new SimpleFileTokenCache instance.
22+
*
23+
* @param cacheFilePath The path where the token cache will be stored
24+
*/
25+
public FileTokenCache(Path cacheFilePath) {
26+
Objects.requireNonNull(cacheFilePath, "cacheFilePath must be defined");
27+
28+
this.cacheFile = cacheFilePath;
29+
this.mapper = SerDeUtils.createMapper();
30+
}
31+
32+
@Override
33+
public void save(Token token) {
34+
try {
35+
Files.createDirectories(cacheFile.getParent());
36+
37+
// Serialize token to JSON
38+
String json = mapper.writeValueAsString(token);
39+
byte[] dataToWrite = json.getBytes(StandardCharsets.UTF_8);
40+
41+
Files.write(cacheFile, dataToWrite);
42+
// Set file permissions to be readable only by the owner (equivalent to 0600)
43+
File file = cacheFile.toFile();
44+
file.setReadable(false, false);
45+
file.setReadable(true, true);
46+
file.setWritable(false, false);
47+
file.setWritable(true, true);
48+
49+
LOGGER.debug("Successfully saved token to cache: {}", cacheFile);
50+
} catch (Exception e) {
51+
LOGGER.warn("Failed to save token to cache: {}", cacheFile, e);
52+
}
53+
}
54+
55+
@Override
56+
public Token load() {
57+
try {
58+
if (!Files.exists(cacheFile)) {
59+
LOGGER.debug("No token cache file found at: {}", cacheFile);
60+
return null;
61+
}
62+
63+
byte[] fileContent = Files.readAllBytes(cacheFile);
64+
65+
// Deserialize token from JSON
66+
String json = new String(fileContent, StandardCharsets.UTF_8);
67+
Token token = mapper.readValue(json, Token.class);
68+
LOGGER.debug("Successfully loaded token from cache: {}", cacheFile);
69+
return token;
70+
} catch (Exception e) {
71+
// If there's any issue loading the token, return null
72+
// to allow a fresh token to be obtained
73+
LOGGER.warn("Failed to load token from cache: {}", e.getMessage());
74+
return null;
75+
}
76+
}
77+
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClient.java

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,6 @@ public OAuthClient build() throws IOException {
8585
private final boolean isAws;
8686
private final boolean isAzure;
8787

88-
public OAuthClient(DatabricksConfig config) throws IOException {
89-
this(
90-
new Builder()
91-
.withHttpClient(config.getHttpClient())
92-
.withClientId(config.getClientId())
93-
.withClientSecret(config.getClientSecret())
94-
.withHost(config.getHost())
95-
.withRedirectUrl(
96-
config.getOAuthRedirectUrl() != null
97-
? config.getOAuthRedirectUrl()
98-
: "http://localhost:8080/callback")
99-
.withScopes(config.getScopes()));
100-
}
101-
10288
private OAuthClient(Builder b) throws IOException {
10389
this.clientId = Objects.requireNonNull(b.clientId);
10490
this.clientSecret = b.clientSecret;
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.DatabricksConfig;
4+
5+
/** Utility methods for OAuth client credentials resolution. */
6+
public class OAuthClientUtils {
7+
8+
/** Default client ID to use when no client ID is specified. */
9+
private static final String DEFAULT_CLIENT_ID = "databricks-cli";
10+
11+
/**
12+
* Resolves the OAuth client ID from the configuration. Prioritizes regular OAuth client ID, then
13+
* Azure client ID, and falls back to default client ID.
14+
*
15+
* @param config The Databricks configuration
16+
* @return The resolved client ID
17+
*/
18+
public static String resolveClientId(DatabricksConfig config) {
19+
if (config.getClientId() != null) {
20+
return config.getClientId();
21+
} else if (config.getAzureClientId() != null) {
22+
return config.getAzureClientId();
23+
}
24+
return DEFAULT_CLIENT_ID;
25+
}
26+
27+
/**
28+
* Resolves the OAuth client secret from the configuration. Prioritizes regular OAuth client
29+
* secret, then Azure client secret.
30+
*
31+
* @param config The Databricks configuration
32+
* @return The resolved client secret, or null if not present
33+
*/
34+
public static String resolveClientSecret(DatabricksConfig config) {
35+
if (config.getClientSecret() != null) {
36+
return config.getClientSecret();
37+
} else if (config.getAzureClientSecret() != null) {
38+
return config.getAzureClientSecret();
39+
}
40+
return null;
41+
}
42+
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import java.util.HashMap;
1010
import java.util.Map;
1111
import org.apache.http.HttpHeaders;
12+
import org.slf4j.Logger;
13+
import org.slf4j.LoggerFactory;
1214

1315
/**
1416
* An implementation of RefreshableTokenSource implementing the refresh_token OAuth grant type.
@@ -20,6 +22,7 @@
2022
public class SessionCredentials extends RefreshableTokenSource
2123
implements CredentialsProvider, Serializable {
2224
private static final long serialVersionUID = 3083941540130596650L;
25+
private static final Logger LOGGER = LoggerFactory.getLogger(SessionCredentials.class);
2326

2427
@Override
2528
public String authType() {
@@ -43,6 +46,7 @@ static class Builder {
4346
private String redirectUrl;
4447
private String clientId;
4548
private String clientSecret;
49+
private TokenCache tokenCache;
4650

4751
public Builder withHttpClient(HttpClient hc) {
4852
this.hc = hc;
@@ -74,6 +78,11 @@ public Builder withClientSecret(String clientSecret) {
7478
return this;
7579
}
7680

81+
public Builder withTokenCache(TokenCache tokenCache) {
82+
this.tokenCache = tokenCache;
83+
return this;
84+
}
85+
7786
public SessionCredentials build() {
7887
return new SessionCredentials(this);
7988
}
@@ -84,6 +93,7 @@ public SessionCredentials build() {
8493
private final String redirectUrl;
8594
private final String clientId;
8695
private final String clientSecret;
96+
private final TokenCache tokenCache;
8797

8898
private SessionCredentials(Builder b) {
8999
super(b.token);
@@ -92,6 +102,7 @@ private SessionCredentials(Builder b) {
92102
this.redirectUrl = b.redirectUrl;
93103
this.clientId = b.clientId;
94104
this.clientSecret = b.clientSecret;
105+
this.tokenCache = b.tokenCache;
95106
}
96107

97108
@Override
@@ -113,7 +124,15 @@ protected Token refresh() {
113124
// cross-origin requests
114125
headers.put("Origin", redirectUrl);
115126
}
116-
return retrieveToken(
117-
hc, clientId, clientSecret, tokenUrl, params, headers, AuthParameterPosition.BODY);
127+
Token newToken =
128+
retrieveToken(
129+
hc, clientId, clientSecret, tokenUrl, params, headers, AuthParameterPosition.BODY);
130+
131+
// Save the refreshed token directly to cache
132+
if (tokenCache != null) {
133+
tokenCache.save(newToken);
134+
LOGGER.debug("Saved refreshed token to cache");
135+
}
136+
return newToken;
118137
}
119138
}

0 commit comments

Comments
 (0)