Skip to content

Commit 727a245

Browse files
[PECOBLR-131][PECOBLR-180] Add token cache for U2M OAuth by using sdk (#783)
1 parent 907e632 commit 727a245

File tree

12 files changed

+594
-18
lines changed

12 files changed

+594
-18
lines changed

NEXT_CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
## [Unreleased]
44

55
### Added
6-
-
6+
- Support for token cache in OAuth U2M Flow using the configuration parameters: `EnableTokenCache` and `TokenCachePassPhrase`.
77

88
### Updated
99
-

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
<httpclient.version>4.5.14</httpclient.version>
5555
<commons-configuration.version>2.10.1</commons-configuration.version>
5656
<commons-io.version>2.14.0</commons-io.version>
57-
<databricks-sdk.version>0.44.0</databricks-sdk.version>
57+
<databricks-sdk.version>0.46.0</databricks-sdk.version>
5858
<maven-surefire-plugin.version>3.1.2</maven-surefire-plugin.version>
5959
<sql-logic-test.version>0.3</sql-logic-test.version>
6060
<lz4-compression.version>1.8.0</lz4-compression.version>

src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,16 @@ public int getSocketTimeout() {
850850
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.SOCKET_TIMEOUT));
851851
}
852852

853+
@Override
854+
public String getTokenCachePassPhrase() {
855+
return getParameter(DatabricksJdbcUrlParams.TOKEN_CACHE_PASS_PHRASE);
856+
}
857+
858+
@Override
859+
public boolean isTokenCacheEnabled() {
860+
return getParameter(DatabricksJdbcUrlParams.ENABLE_TOKEN_CACHE).equals("1");
861+
}
862+
853863
private static boolean nullOrEmptyString(String s) {
854864
return s == null || s.isEmpty();
855865
}

src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,10 @@ public interface IDatabricksConnectionContext {
293293
* @return true if the system property trust store should be used, false otherwise
294294
*/
295295
boolean useSystemTrustStore();
296+
297+
/**
 * Returns the passphrase used for encrypting/decrypting the OAuth U2M token cache.
 *
 * @return the token cache passphrase, or {@code null} when not configured
 */
String getTokenCachePassPhrase();
299+
300+
/**
 * Returns whether token caching is enabled for OAuth authentication.
 *
 * @return {@code true} if OAuth tokens should be cached between connections
 */
boolean isTokenCacheEnabled();
296302
}
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
package com.databricks.jdbc.auth;
2+
3+
import com.databricks.jdbc.log.JdbcLogger;
4+
import com.databricks.jdbc.log.JdbcLoggerFactory;
5+
import com.databricks.sdk.core.DatabricksException;
6+
import com.databricks.sdk.core.oauth.Token;
7+
import com.databricks.sdk.core.oauth.TokenCache;
8+
import com.databricks.sdk.core.utils.SerDeUtils;
9+
import com.fasterxml.jackson.databind.ObjectMapper;
10+
import java.io.File;
11+
import java.nio.charset.StandardCharsets;
12+
import java.nio.file.Files;
13+
import java.nio.file.Path;
14+
import java.security.SecureRandom;
15+
import java.security.spec.KeySpec;
16+
import java.util.Base64;
17+
import java.util.Objects;
18+
import javax.crypto.Cipher;
19+
import javax.crypto.SecretKey;
20+
import javax.crypto.SecretKeyFactory;
21+
import javax.crypto.spec.IvParameterSpec;
22+
import javax.crypto.spec.PBEKeySpec;
23+
import javax.crypto.spec.SecretKeySpec;
24+
25+
/** A TokenCache implementation that stores tokens in encrypted files. */
26+
public class EncryptedFileTokenCache implements TokenCache {
27+
private static final JdbcLogger LOGGER =
28+
JdbcLoggerFactory.getLogger(EncryptedFileTokenCache.class);
29+
30+
// Encryption constants
31+
private static final String ALGORITHM = "AES";
32+
private static final String TRANSFORMATION = "AES/CBC/PKCS5Padding";
33+
private static final String SECRET_KEY_ALGORITHM = "PBKDF2WithHmacSHA256";
34+
private static final byte[] SALT = "DatabricksJdbcTokenCache".getBytes();
35+
private static final int ITERATION_COUNT = 65536;
36+
private static final int KEY_LENGTH = 256;
37+
private static final int IV_SIZE = 16; // 128 bits
38+
39+
private final Path cacheFile;
40+
private final ObjectMapper mapper;
41+
private final String passphrase;
42+
43+
/**
44+
* Constructs a new EncryptingFileTokenCache instance.
45+
*
46+
* @param cacheFilePath The path where the token cache will be stored
47+
* @param passphrase The passphrase used for encryption
48+
*/
49+
public EncryptedFileTokenCache(Path cacheFilePath, String passphrase) {
50+
Objects.requireNonNull(cacheFilePath, "cacheFilePath must be defined");
51+
Objects.requireNonNull(passphrase, "passphrase must be defined for encrypted token cache");
52+
53+
this.cacheFile = cacheFilePath;
54+
this.mapper = SerDeUtils.createMapper();
55+
this.passphrase = passphrase;
56+
}
57+
58+
@Override
59+
public void save(Token token) throws DatabricksException {
60+
try {
61+
Files.createDirectories(cacheFile.getParent());
62+
63+
// Serialize token to JSON
64+
String json = mapper.writeValueAsString(token);
65+
byte[] dataToWrite = json.getBytes(StandardCharsets.UTF_8);
66+
67+
// Encrypt data
68+
dataToWrite = encrypt(dataToWrite);
69+
70+
Files.write(cacheFile, dataToWrite);
71+
// Set file permissions to be readable only by the owner (equivalent to 0600)
72+
File file = cacheFile.toFile();
73+
file.setReadable(false, false);
74+
file.setReadable(true, true);
75+
file.setWritable(false, false);
76+
file.setWritable(true, true);
77+
78+
LOGGER.debug("Successfully saved encrypted token to cache: %s", cacheFile);
79+
} catch (Exception e) {
80+
throw new DatabricksException("Failed to save token cache: " + e.getMessage(), e);
81+
}
82+
}
83+
84+
@Override
85+
public Token load() {
86+
try {
87+
if (!Files.exists(cacheFile)) {
88+
LOGGER.debug("No token cache file found at: %s", cacheFile);
89+
return null;
90+
}
91+
92+
byte[] fileContent = Files.readAllBytes(cacheFile);
93+
94+
// Decrypt data
95+
byte[] decodedContent;
96+
try {
97+
decodedContent = decrypt(fileContent);
98+
} catch (Exception e) {
99+
LOGGER.debug("Failed to decrypt token cache: %s", e.getMessage());
100+
return null;
101+
}
102+
103+
// Deserialize token from JSON
104+
String json = new String(decodedContent, StandardCharsets.UTF_8);
105+
Token token = mapper.readValue(json, Token.class);
106+
LOGGER.debug("Successfully loaded encrypted token from cache: %s", cacheFile);
107+
return token;
108+
} catch (Exception e) {
109+
// If there's any issue loading the token, return null
110+
// to allow a fresh token to be obtained
111+
LOGGER.debug("Failed to load token from cache: %s", e.getMessage());
112+
return null;
113+
}
114+
}
115+
116+
/**
117+
* Generates a secret key from the passphrase using PBKDF2 with HMAC-SHA256.
118+
*
119+
* @return A SecretKey generated from the passphrase
120+
* @throws Exception If an error occurs generating the key
121+
*/
122+
private SecretKey generateSecretKey() throws Exception {
123+
SecretKeyFactory factory = SecretKeyFactory.getInstance(SECRET_KEY_ALGORITHM);
124+
KeySpec spec = new PBEKeySpec(passphrase.toCharArray(), SALT, ITERATION_COUNT, KEY_LENGTH);
125+
return new SecretKeySpec(factory.generateSecret(spec).getEncoded(), ALGORITHM);
126+
}
127+
128+
/**
129+
* Encrypts the given data using AES/CBC/PKCS5Padding encryption with a key derived from the
130+
* passphrase. The IV is generated randomly and prepended to the encrypted data.
131+
*
132+
* @param data The data to encrypt
133+
* @return The encrypted data with IV prepended
134+
* @throws Exception If an error occurs during encryption
135+
*/
136+
private byte[] encrypt(byte[] data) throws Exception {
137+
Cipher cipher = Cipher.getInstance(TRANSFORMATION);
138+
139+
// Generate a random IV
140+
SecureRandom random = new SecureRandom();
141+
byte[] iv = new byte[IV_SIZE];
142+
random.nextBytes(iv);
143+
IvParameterSpec ivSpec = new IvParameterSpec(iv);
144+
145+
cipher.init(Cipher.ENCRYPT_MODE, generateSecretKey(), ivSpec);
146+
byte[] encryptedData = cipher.doFinal(data);
147+
148+
// Combine IV and encrypted data
149+
byte[] combined = new byte[iv.length + encryptedData.length];
150+
System.arraycopy(iv, 0, combined, 0, iv.length);
151+
System.arraycopy(encryptedData, 0, combined, iv.length, encryptedData.length);
152+
153+
return Base64.getEncoder().encode(combined);
154+
}
155+
156+
/**
157+
* Decrypts the given encrypted data using AES/CBC/PKCS5Padding decryption with a key derived from
158+
* the passphrase. The IV is extracted from the beginning of the encrypted data.
159+
*
160+
* @param encryptedData The encrypted data with IV prepended, Base64 encoded
161+
* @return The decrypted data
162+
* @throws Exception If an error occurs during decryption
163+
*/
164+
private byte[] decrypt(byte[] encryptedData) throws Exception {
165+
byte[] decodedData = Base64.getDecoder().decode(encryptedData);
166+
167+
// Extract IV
168+
byte[] iv = new byte[IV_SIZE];
169+
byte[] actualData = new byte[decodedData.length - IV_SIZE];
170+
System.arraycopy(decodedData, 0, iv, 0, IV_SIZE);
171+
System.arraycopy(decodedData, IV_SIZE, actualData, 0, actualData.length);
172+
173+
Cipher cipher = Cipher.getInstance(TRANSFORMATION);
174+
IvParameterSpec ivSpec = new IvParameterSpec(iv);
175+
cipher.init(Cipher.DECRYPT_MODE, generateSecretKey(), ivSpec);
176+
177+
return cipher.doFinal(actualData);
178+
}
179+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package com.databricks.jdbc.auth;
2+
3+
import com.databricks.jdbc.log.JdbcLogger;
4+
import com.databricks.jdbc.log.JdbcLoggerFactory;
5+
import com.databricks.sdk.core.oauth.Token;
6+
import com.databricks.sdk.core.oauth.TokenCache;
7+
8+
/**
9+
* A no-operation implementation of TokenCache that does nothing. Used when token caching is
10+
* explicitly disabled.
11+
*/
12+
public class NoOpTokenCache implements TokenCache {
13+
private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(NoOpTokenCache.class);
14+
15+
@Override
16+
public void save(Token token) {
17+
LOGGER.debug("Token caching is disabled, skipping save operation");
18+
}
19+
20+
@Override
21+
public Token load() {
22+
LOGGER.debug("Token caching is disabled, skipping load operation");
23+
return null;
24+
}
25+
}

src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ public enum DatabricksJdbcUrlParams {
120120
"DefaultStringColumnLength",
121121
"Maximum number of characters that can be contained in STRING columns",
122122
"255"),
123-
SOCKET_TIMEOUT("socketTimeout", "Socket timeout in seconds", "900");
123+
SOCKET_TIMEOUT("socketTimeout", "Socket timeout in seconds", "900"),
124+
TOKEN_CACHE_PASS_PHRASE("TokenCachePassPhrase", "Pass phrase to use for OAuth U2M Token Cache"),
125+
ENABLE_TOKEN_CACHE("EnableTokenCache", "Enable caching OAuth tokens", "1");
124126

125127
private final String paramName;
126128
private final String defaultValue;

src/main/java/com/databricks/jdbc/dbclient/impl/common/ClientConfigurator.java

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
import static com.databricks.jdbc.common.util.DatabricksAuthUtil.initializeConfigWithToken;
55

66
import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
7-
import com.databricks.jdbc.auth.AzureMSICredentialProvider;
8-
import com.databricks.jdbc.auth.OAuthRefreshCredentialsProvider;
9-
import com.databricks.jdbc.auth.PrivateKeyClientCredentialProvider;
7+
import com.databricks.jdbc.auth.*;
108
import com.databricks.jdbc.common.AuthMech;
119
import com.databricks.jdbc.common.DatabricksJdbcConstants;
1210
import com.databricks.jdbc.common.util.DriverUtil;
@@ -20,9 +18,13 @@
2018
import com.databricks.sdk.core.DatabricksException;
2119
import com.databricks.sdk.core.ProxyConfig;
2220
import com.databricks.sdk.core.commons.CommonsHttpClient;
21+
import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider;
22+
import com.databricks.sdk.core.oauth.TokenCache;
2323
import com.databricks.sdk.core.utils.Cloud;
2424
import java.io.IOException;
2525
import java.net.ServerSocket;
26+
import java.nio.file.Path;
27+
import java.nio.file.Paths;
2628
import java.util.ArrayList;
2729
import java.util.Arrays;
2830
import java.util.List;
@@ -52,6 +54,53 @@ public ClientConfigurator(IDatabricksConnectionContext connectionContext) {
5254
this.databricksConfig.resolve();
5355
}
5456

57+
/**
58+
* Returns the path for the token cache file based on host, client ID, and scopes. This creates a
59+
* unique cache path using a hash of these parameters.
60+
*
61+
* @param host The host URL
62+
* @param clientId The OAuth client ID
63+
* @param scopes The OAuth scopes
64+
* @return The path for the token cache file
65+
*/
66+
public static Path getTokenCachePath(String host, String clientId, List<String> scopes) {
67+
String userHome = System.getProperty("user.home");
68+
Path homeDir = Paths.get(userHome);
69+
Path databricksDir = homeDir.resolve(".config/databricks-jdbc/oauth");
70+
71+
// Create a unique string identifier from the combination of parameters
72+
String uniqueIdentifier = createUniqueIdentifier(host, clientId, scopes);
73+
74+
String filename = "token-cache-" + uniqueIdentifier;
75+
76+
return databricksDir.resolve(filename);
77+
}
78+
79+
/**
80+
* Creates a unique identifier string from the given parameters. Uses a hash function to create a
81+
* compact representation.
82+
*
83+
* @param host The host URL
84+
* @param clientId The OAuth client ID
85+
* @param scopes The OAuth scopes
86+
* @return A unique identifier string
87+
*/
88+
private static String createUniqueIdentifier(String host, String clientId, List<String> scopes) {
89+
// Normalize inputs to handle null values
90+
host = (host != null) ? host : EMPTY_STRING;
91+
clientId = (clientId != null) ? clientId : EMPTY_STRING;
92+
scopes = (scopes != null) ? scopes : List.of();
93+
94+
// Combine all parameters
95+
String combined = host + URL_DELIMITER + clientId + URL_DELIMITER + String.join(COMMA, scopes);
96+
97+
// Create a hash from the combined string
98+
int hash = combined.hashCode();
99+
100+
// Convert to a positive hexadecimal string
101+
return Integer.toHexString(hash & 0x7FFFFFFF);
102+
}
103+
55104
/**
56105
* Setup the SSL configuration in the httpClientBuilder.
57106
*
@@ -136,10 +185,13 @@ public void setupU2MConfig() throws DatabricksParsingException {
136185
int redirectPort = findAvailablePort(connectionContext.getOAuth2RedirectUrlPorts());
137186
String redirectUrl = String.format("http://localhost:%d", redirectPort);
138187

188+
String host = connectionContext.getHostForOAuth();
189+
String clientId = connectionContext.getClientId();
190+
139191
databricksConfig
140192
.setAuthType(DatabricksJdbcConstants.U2M_AUTH_TYPE)
141-
.setHost(connectionContext.getHostForOAuth())
142-
.setClientId(connectionContext.getClientId())
193+
.setHost(host)
194+
.setClientId(clientId)
143195
.setClientSecret(connectionContext.getClientSecret())
144196
.setOAuthRedirectUrl(redirectUrl);
145197

@@ -148,6 +200,21 @@ public void setupU2MConfig() throws DatabricksParsingException {
148200
if (!databricksConfig.isAzure()) {
149201
databricksConfig.setScopes(connectionContext.getOAuthScopesForU2M());
150202
}
203+
204+
TokenCache tokenCache;
205+
if (connectionContext.isTokenCacheEnabled()) {
206+
if (connectionContext.getTokenCachePassPhrase() == null) {
207+
LOGGER.error("No token cache passphrase configured");
208+
throw new DatabricksException("No token cache passphrase configured");
209+
}
210+
Path tokenCachePath = getTokenCachePath(host, clientId, databricksConfig.getScopes());
211+
tokenCache =
212+
new EncryptedFileTokenCache(tokenCachePath, connectionContext.getTokenCachePassPhrase());
213+
} else {
214+
tokenCache = new NoOpTokenCache();
215+
}
216+
CredentialsProvider provider = new ExternalBrowserCredentialsProvider(tokenCache);
217+
databricksConfig.setCredentialsProvider(provider).setAuthType(provider.authType());
151218
}
152219

153220
/**
@@ -229,14 +296,13 @@ public void resetAccessTokenInConfig(String newAccessToken) {
229296

230297
/** Setup the OAuth U2M refresh token authentication settings in the databricks config. */
public void setupU2MRefreshConfig() throws DatabricksParsingException {
  // Host, client ID and secret are applied to the config first; the provider is constructed
  // afterwards — presumably because it reads these values from databricksConfig. TODO confirm.
  databricksConfig
      .setHost(connectionContext.getHostForOAuth())
      .setClientId(connectionContext.getClientId())
      .setClientSecret(connectionContext.getClientSecret());
  CredentialsProvider provider =
      new OAuthRefreshCredentialsProvider(connectionContext, databricksConfig);
  // provider.authType() selects the oauth-refresh auth type
  databricksConfig.setAuthType(provider.authType()).setCredentialsProvider(provider);
}
241307

242308
/** Setup the OAuth M2M authentication settings in the databricks config. */

0 commit comments

Comments
 (0)