Skip to content

Commit 01d5645

Browse files
added credential provider OAuthRefresh only (#419)
added credential provider for OAuthRefresh only
1 parent 7765156 commit 01d5645

11 files changed

+427
-99
lines changed

src/main/java/com/databricks/jdbc/api/IDatabricksConnectionContext.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,6 @@ public static AuthMech parseAuthMech(String authMech) {
182182

183183
/** Returns the OAuth2 authentication scope used in the request. */
184184
String getAuthScope();
185+
186+
String getOAuthRefreshToken();
185187
}

src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,4 +618,9 @@ public String getOAuthDiscoveryURL() {
618618
public String getAuthScope() {
619619
return getParameter(AUTH_SCOPE, ALL_APIS_SCOPE);
620620
}
621+
622+
@Override
623+
public String getOAuthRefreshToken() {
624+
return getParameter(OAUTH_REFRESH_TOKEN);
625+
}
621626
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package com.databricks.jdbc.auth;
2+
3+
public class AuthConstants {
4+
public static final String GRANT_TYPE_REFRESH_TOKEN_KEY = "refresh_token";
5+
public static final String GRANT_TYPE_KEY = "grant_type";
6+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package com.databricks.jdbc.auth;
2+
3+
import com.databricks.jdbc.api.IDatabricksConnectionContext;
4+
import com.databricks.jdbc.common.LogLevel;
5+
import com.databricks.jdbc.common.util.LoggingUtil;
6+
import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClient;
7+
import com.databricks.jdbc.exception.DatabricksHttpException;
8+
import com.databricks.sdk.core.DatabricksException;
9+
import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints;
10+
import com.fasterxml.jackson.databind.ObjectMapper;
11+
import java.io.IOException;
12+
import java.net.URISyntaxException;
13+
import org.apache.http.client.methods.CloseableHttpResponse;
14+
import org.apache.http.client.methods.HttpGet;
15+
import org.apache.http.client.utils.URIBuilder;
16+
17+
public class AuthUtils {
18+
public static String getTokenEndpoint(IDatabricksConnectionContext context) {
19+
String tokenUrl;
20+
if (context.getTokenEndpoint() != null) {
21+
tokenUrl = context.getTokenEndpoint();
22+
} else if (context.isOAuthDiscoveryModeEnabled()) {
23+
try {
24+
tokenUrl = getTokenEndpointFromDiscoveryEndpoint(context);
25+
} catch (DatabricksException e) {
26+
String exceptionMessage = "Failed to get token endpoint from discovery endpoint";
27+
LoggingUtil.log(LogLevel.ERROR, exceptionMessage);
28+
throw new DatabricksException(exceptionMessage, e);
29+
}
30+
} else {
31+
try {
32+
tokenUrl =
33+
new URIBuilder()
34+
.setHost(context.getHostForOAuth())
35+
.setScheme("https")
36+
.setPathSegments("oidc", "v1", "token")
37+
.build()
38+
.toString();
39+
} catch (URISyntaxException e) {
40+
String exceptionMessage = "Failed to build token url";
41+
LoggingUtil.log(LogLevel.ERROR, exceptionMessage);
42+
throw new DatabricksException(exceptionMessage, e);
43+
}
44+
}
45+
return tokenUrl;
46+
}
47+
48+
/*
49+
* TODO : The following will be removed once SDK changes are merged
50+
* https://github.com/databricks/databricks-sdk-java/pull/336
51+
* */
52+
private static String getTokenEndpointFromDiscoveryEndpoint(
53+
IDatabricksConnectionContext connectionContext) throws DatabricksException {
54+
if (connectionContext.getOAuthDiscoveryURL() == null) {
55+
String exceptionMessage =
56+
"If discovery mode is enabled, we also need the discovery URL to be set.";
57+
LoggingUtil.log(LogLevel.ERROR, exceptionMessage);
58+
throw new DatabricksException(exceptionMessage);
59+
}
60+
try {
61+
URIBuilder uriBuilder = new URIBuilder(connectionContext.getOAuthDiscoveryURL());
62+
DatabricksHttpClient httpClient = DatabricksHttpClient.getInstance(connectionContext);
63+
HttpGet getRequest = new HttpGet(uriBuilder.build());
64+
try (CloseableHttpResponse response = httpClient.execute(getRequest)) {
65+
if (response.getStatusLine().getStatusCode() != 200) {
66+
String exceptionMessage =
67+
"Error while calling discovery endpoint to fetch token endpoint. Response: "
68+
+ response;
69+
LoggingUtil.log(LogLevel.DEBUG, exceptionMessage);
70+
throw new DatabricksHttpException(exceptionMessage);
71+
}
72+
OpenIDConnectEndpoints openIDConnectEndpoints =
73+
new ObjectMapper()
74+
.readValue(response.getEntity().getContent(), OpenIDConnectEndpoints.class);
75+
return openIDConnectEndpoints.getTokenEndpoint();
76+
}
77+
} catch (URISyntaxException | DatabricksHttpException | IOException e) {
78+
String exceptionMessage = "Failed to get token endpoint from discovery endpoint";
79+
LoggingUtil.log(LogLevel.ERROR, exceptionMessage);
80+
throw new DatabricksException(exceptionMessage, e);
81+
}
82+
}
83+
}

src/main/java/com/databricks/jdbc/auth/OAuthAuthenticator.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import com.databricks.jdbc.common.DatabricksJdbcConstants;
55
import com.databricks.jdbc.exception.DatabricksParsingException;
66
import com.databricks.sdk.WorkspaceClient;
7+
import com.databricks.sdk.core.CredentialsProvider;
78
import com.databricks.sdk.core.DatabricksConfig;
89

910
public class OAuthAuthenticator {
@@ -31,7 +32,11 @@ else if (this.connectionContext
3132
.equals(IDatabricksConnectionContext.AuthMech.OAUTH)) {
3233
switch (this.connectionContext.getAuthFlow()) {
3334
case TOKEN_PASSTHROUGH:
34-
setupAccessTokenConfig(databricksConfig);
35+
if (connectionContext.getOAuthRefreshToken() != null) {
36+
setupU2MRefreshConfig(databricksConfig);
37+
} else {
38+
setupAccessTokenConfig(databricksConfig);
39+
}
3540
break;
3641
case CLIENT_CREDENTIALS:
3742
setupM2MConfig(databricksConfig);
@@ -65,6 +70,17 @@ public void setupAccessTokenConfig(DatabricksConfig databricksConfig)
6570
.setToken(connectionContext.getToken());
6671
}
6772

73+
public void setupU2MRefreshConfig(DatabricksConfig databricksConfig)
74+
throws DatabricksParsingException {
75+
CredentialsProvider provider = new OAuthRefreshCredentialsProvider(connectionContext);
76+
databricksConfig
77+
.setHost(connectionContext.getHostForOAuth())
78+
.setAuthType(provider.authType())
79+
.setCredentialsProvider(provider)
80+
.setClientId(connectionContext.getClientId())
81+
.setClientSecret(connectionContext.getClientSecret());
82+
}
83+
6884
public void setupM2MConfig(DatabricksConfig databricksConfig) throws DatabricksParsingException {
6985
databricksConfig
7086
.setAuthType(DatabricksJdbcConstants.M2M_AUTH_TYPE)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package com.databricks.jdbc.auth;
2+
3+
import static com.databricks.jdbc.auth.AuthConstants.*;
4+
5+
import com.databricks.jdbc.api.IDatabricksConnectionContext;
6+
import com.databricks.jdbc.common.DatabricksJdbcConstants;
7+
import com.databricks.jdbc.common.LogLevel;
8+
import com.databricks.jdbc.common.util.LoggingUtil;
9+
import com.databricks.jdbc.exception.DatabricksParsingException;
10+
import com.databricks.sdk.core.CredentialsProvider;
11+
import com.databricks.sdk.core.DatabricksConfig;
12+
import com.databricks.sdk.core.DatabricksException;
13+
import com.databricks.sdk.core.HeaderFactory;
14+
import com.databricks.sdk.core.http.HttpClient;
15+
import com.databricks.sdk.core.oauth.AuthParameterPosition;
16+
import com.databricks.sdk.core.oauth.RefreshableTokenSource;
17+
import com.databricks.sdk.core.oauth.Token;
18+
import java.time.LocalDateTime;
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
import org.apache.http.HttpHeaders;
22+
23+
public class OAuthRefreshCredentialsProvider extends RefreshableTokenSource
24+
implements CredentialsProvider {
25+
IDatabricksConnectionContext context;
26+
private HttpClient hc;
27+
private final String tokenUrl;
28+
private final String clientId;
29+
private final String clientSecret;
30+
31+
public OAuthRefreshCredentialsProvider(IDatabricksConnectionContext context) {
32+
this.context = context;
33+
this.tokenUrl = AuthUtils.getTokenEndpoint(context);
34+
try {
35+
this.clientId = context.getClientId();
36+
} catch (DatabricksParsingException e) {
37+
String exceptionMessage = "Failed to parse client id";
38+
LoggingUtil.log(LogLevel.ERROR, exceptionMessage);
39+
throw new DatabricksException(exceptionMessage, e);
40+
}
41+
this.clientSecret = context.getClientSecret();
42+
// Create an expired dummy token object with the refresh token to use
43+
this.token =
44+
new Token(
45+
DatabricksJdbcConstants.EMPTY_STRING,
46+
DatabricksJdbcConstants.EMPTY_STRING,
47+
context.getOAuthRefreshToken(),
48+
LocalDateTime.now().minusMinutes(1));
49+
}
50+
51+
@Override
52+
public String authType() {
53+
return "oauth-refresh";
54+
}
55+
56+
@Override
57+
public HeaderFactory configure(DatabricksConfig databricksConfig) {
58+
if (this.hc == null) {
59+
this.hc = databricksConfig.getHttpClient();
60+
}
61+
return () -> {
62+
Map<String, String> headers = new HashMap<>();
63+
// An example header looks like: "Authorization: Bearer <access-token>"
64+
headers.put(
65+
HttpHeaders.AUTHORIZATION, getToken().getTokenType() + " " + getToken().getAccessToken());
66+
return headers;
67+
};
68+
}
69+
70+
@Override
71+
protected Token refresh() {
72+
if (this.token == null) {
73+
String exceptionMessage = "oauth2: token is not set";
74+
LoggingUtil.log(LogLevel.ERROR, exceptionMessage);
75+
throw new DatabricksException(exceptionMessage);
76+
}
77+
String refreshToken = this.token.getRefreshToken();
78+
if (refreshToken == null) {
79+
String exceptionMessage = "oauth2: token expired and refresh token is not set";
80+
LoggingUtil.log(LogLevel.ERROR, exceptionMessage);
81+
throw new DatabricksException(exceptionMessage);
82+
}
83+
84+
Map<String, String> params = new HashMap<>();
85+
params.put(GRANT_TYPE_KEY, GRANT_TYPE_REFRESH_TOKEN_KEY);
86+
params.put(GRANT_TYPE_REFRESH_TOKEN_KEY, refreshToken);
87+
Map<String, String> headers = new HashMap<>();
88+
return retrieveToken(
89+
hc, clientId, clientSecret, tokenUrl, params, headers, AuthParameterPosition.BODY);
90+
}
91+
}
Lines changed: 5 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,17 @@
11
package com.databricks.jdbc.auth;
22

33
import com.databricks.jdbc.api.IDatabricksConnectionContext;
4-
import com.databricks.jdbc.common.LogLevel;
5-
import com.databricks.jdbc.common.util.LoggingUtil;
64
import com.databricks.jdbc.dbclient.IDatabricksHttpClient;
75
import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClient;
8-
import com.databricks.jdbc.exception.DatabricksHttpException;
96
import com.databricks.sdk.core.CredentialsProvider;
107
import com.databricks.sdk.core.DatabricksConfig;
118
import com.databricks.sdk.core.HeaderFactory;
12-
import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints;
139
import com.databricks.sdk.core.oauth.Token;
14-
import com.fasterxml.jackson.databind.ObjectMapper;
1510
import com.google.common.annotations.VisibleForTesting;
16-
import java.io.IOException;
17-
import java.net.URISyntaxException;
1811
import java.util.Collections;
1912
import java.util.HashMap;
2013
import java.util.Map;
21-
import org.apache.http.client.methods.CloseableHttpResponse;
22-
import org.apache.http.client.methods.HttpGet;
23-
import org.apache.http.client.utils.URIBuilder;
14+
import org.apache.http.HttpHeaders;
2415

2516
public class PrivateKeyClientCredentialProvider implements CredentialsProvider {
2617

@@ -30,14 +21,8 @@ public class PrivateKeyClientCredentialProvider implements CredentialsProvider {
3021
IDatabricksHttpClient httpClient;
3122

3223
public PrivateKeyClientCredentialProvider(IDatabricksConnectionContext connectionContext) {
33-
this(connectionContext, DatabricksHttpClient.getInstance(connectionContext));
34-
}
35-
36-
@VisibleForTesting
37-
PrivateKeyClientCredentialProvider(
38-
IDatabricksConnectionContext connectionContext, IDatabricksHttpClient httpClient) {
3924
this.connectionContext = connectionContext;
40-
this.httpClient = httpClient;
25+
this.httpClient = DatabricksHttpClient.getInstance(connectionContext);
4126
}
4227

4328
@Override
@@ -47,17 +32,7 @@ public String authType() {
4732

4833
@VisibleForTesting
4934
JwtPrivateKeyClientCredentials getClientCredentialObject(DatabricksConfig config) {
50-
tokenEndpoint = connectionContext.getTokenEndpoint();
51-
if (tokenEndpoint == null && connectionContext.isOAuthDiscoveryModeEnabled()) {
52-
updateOidcFromDiscoveryEndpoint();
53-
}
54-
if (tokenEndpoint == null) {
55-
try {
56-
tokenEndpoint = config.getOidcEndpoints().getTokenEndpoint();
57-
} catch (IOException e) {
58-
LoggingUtil.log(LogLevel.ERROR, "Unable to set default token endpoint with error " + e);
59-
}
60-
}
35+
tokenEndpoint = AuthUtils.getTokenEndpoint(connectionContext);
6136
return new JwtPrivateKeyClientCredentials.Builder()
6237
.withHttpClient(this.httpClient)
6338
.withClientId(config.getClientId())
@@ -76,41 +51,9 @@ public HeaderFactory configure(DatabricksConfig config) {
7651
return () -> {
7752
Token token = clientCredentials.getToken();
7853
Map<String, String> headers = new HashMap<>();
79-
headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken());
80-
headers.put("Content-Type", "application/x-www-form-urlencoded");
54+
headers.put(HttpHeaders.AUTHORIZATION, token.getTokenType() + " " + token.getAccessToken());
55+
headers.put(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded");
8156
return headers;
8257
};
8358
}
84-
85-
/*
86-
* TODO : The following will be removed once SDK changes are merged
87-
* https://github.com/databricks/databricks-sdk-java/pull/336
88-
* */
89-
private void updateOidcFromDiscoveryEndpoint() {
90-
if (connectionContext.getOAuthDiscoveryURL() == null) {
91-
LoggingUtil.log(
92-
LogLevel.ERROR, "If discovery mode is enabled, we also need to put discovery URL");
93-
// not throwing exception as the interface does not support it
94-
return;
95-
}
96-
try {
97-
URIBuilder uriBuilder = new URIBuilder(connectionContext.getOAuthDiscoveryURL());
98-
HttpGet getRequest = new HttpGet(uriBuilder.build());
99-
CloseableHttpResponse response = this.httpClient.execute(getRequest);
100-
if (response.getStatusLine().getStatusCode() != 200) {
101-
LoggingUtil.log(
102-
LogLevel.DEBUG,
103-
"Error while calling discovery endpoint to fetch token endpoint. Response: "
104-
+ response);
105-
}
106-
OpenIDConnectEndpoints openIDConnectEndpoints =
107-
new ObjectMapper()
108-
.readValue(response.getEntity().getContent(), OpenIDConnectEndpoints.class);
109-
tokenEndpoint = openIDConnectEndpoints.getTokenEndpoint();
110-
} catch (URISyntaxException | DatabricksHttpException | IOException e) {
111-
LoggingUtil.log(
112-
LogLevel.ERROR,
113-
"Unable to retrieve token and auth endpoint from discovery endpoint. Error " + e);
114-
}
115-
}
11659
}

src/main/java/com/databricks/jdbc/common/DatabricksJdbcConstants.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ public final class DatabricksJdbcConstants {
9595

9696
public static final String AUTH_FLOW = "auth_flow";
9797

98+
public static final String OAUTH_REFRESH_TOKEN = "OAuthRefreshToken";
99+
98100
/** Only used when AUTH_MECH = 3 */
99101
public static final String PWD = "pwd";
100102

0 commit comments

Comments
 (0)