Commit 8eb46d0 (1 parent: ea8a502)

[PECO-1565] [PECO-1563] Thrift - add other auth types + refresh token (#173)

* Add other auth types
* Refresh token
* Add more tests
* Modify naming

6 files changed: +95 −69 lines changed


src/main/java/com/databricks/jdbc/client/impl/sdk/DatabricksSdkClient.java
Lines changed: 0 additions & 1 deletion

```diff
@@ -69,7 +69,6 @@ public DatabricksSdkClient(
       ApiClient apiClient)
       throws DatabricksParsingException {
     this.connectionContext = connectionContext;
-    // Handle more auth types
     this.databricksConfig =
         new DatabricksConfig()
             .setHost(connectionContext.getHostUrl())
```

src/main/java/com/databricks/jdbc/client/impl/thrift/commons/DatabricksThriftAccessor.java
Lines changed: 15 additions & 14 deletions

```diff
@@ -34,37 +34,31 @@ public DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext)
         new DatabricksHttpTTransport(
             DatabricksHttpClient.getInstance(connectionContext),
             connectionContext.getEndpointURL());
-    // TODO : add other auth in followup PRs
-    this.databricksConfig =
-        new DatabricksConfig()
-            .setHost(connectionContext.getHostUrl())
-            .setToken(connectionContext.getToken());
+    this.databricksConfig = new OAuthAuthenticator(connectionContext).getDatabricksConfig();
     Map<String, String> authHeaders = databricksConfig.authenticate();
     transport.setCustomHeaders(authHeaders);
     TBinaryProtocol protocol = new TBinaryProtocol(transport);
     this.thriftClient = new TCLIService.Client(protocol);
   }

   @VisibleForTesting
-  public DatabricksThriftAccessor(TCLIService.Client client) {
-    this.databricksConfig = null;
+  DatabricksThriftAccessor(TCLIService.Client client, DatabricksConfig config) {
+    this.databricksConfig = config;
     this.thriftClient = client;
   }

   public TBase getThriftResponse(
       TBase request, CommandName commandName, IDatabricksStatement parentStatement)
       throws DatabricksSQLException {
     /* Todo list:
-     * 1. Poll until we get a success status
-     * 2. Test out metadata operations.
-     * 3. Add token refresh
-     * 4. Handle cloud-fetch
-     * 5. Handle compression
+     * 1. Test out metadata operations.
+     * 2. Handle compression
      */
     LOGGER.debug(
         "Fetching thrift response for request {}, CommandName {}",
         request.toString(),
         commandName.name());
+    refreshHeadersIfRequired();
     try {
       switch (commandName) {
         case OPEN_SESSION:
@@ -107,11 +101,12 @@ public TBase getThriftResponse(

   public TFetchResultsResp getResultSetResp(TOperationHandle operationHandle, String context)
       throws DatabricksHttpException {
+    refreshHeadersIfRequired();
     return getResultSetResp(
         TStatusCode.SUCCESS_STATUS, operationHandle, context, DEFAULT_ROW_LIMIT, false);
   }

-  public TFetchResultsResp getResultSetResp(
+  private TFetchResultsResp getResultSetResp(
       TStatusCode responseCode,
       TOperationHandle operationHandle,
       String context,
@@ -169,6 +164,7 @@ public DatabricksResultSet execute(
       IDatabricksSession session,
       StatementType statementType)
       throws SQLException {
+    refreshHeadersIfRequired();
     int maxRows = (parentStatement == null) ? DEFAULT_ROW_LIMIT : parentStatement.getMaxRows();
     TSparkGetDirectResults directResults =
         new TSparkGetDirectResults().setMaxBytes(DEFAULT_BYTE_LIMIT).setMaxRows(maxRows);
@@ -338,10 +334,15 @@ private TFetchResultsResp listColumns(TGetColumnsReq request)
         false);
   }

-  public TGetResultSetMetadataResp getResultSetMetadata(TOperationHandle operationHandle)
+  private TGetResultSetMetadataResp getResultSetMetadata(TOperationHandle operationHandle)
       throws TException {
     TGetResultSetMetadataReq resultSetMetadataReq =
         new TGetResultSetMetadataReq().setOperationHandle(operationHandle);
     return thriftClient.GetResultSetMetadata(resultSetMetadataReq);
   }
+
+  private void refreshHeadersIfRequired() {
+    ((DatabricksHttpTTransport) thriftClient.getInputProtocol().getTransport())
+        .setCustomHeaders(databricksConfig.authenticate());
+  }
 }
```
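The refresh logic leans on a databricks-sdk behavior visible in this diff: `DatabricksConfig.authenticate()` returns the current auth headers, re-acquiring the OAuth token when the cached one has expired. Re-applying those headers to the transport before each RPC is what makes the driver pick up a refreshed token; it is also why the `@VisibleForTesting` constructor now takes a `DatabricksConfig` instead of leaving the field null, since `refreshHeadersIfRequired()` dereferences it on every call. A minimal sketch of the pattern, where the `HeadersSink` interface is hypothetical and stands in for `DatabricksHttpTTransport`:

```java
import java.util.Map;
import com.databricks.sdk.core.DatabricksConfig;

// HeadersSink is a stand-in for DatabricksHttpTTransport, which exposes
// setCustomHeaders(Map<String, String>) per the diff above.
interface HeadersSink {
  void setCustomHeaders(Map<String, String> headers);
}

class AuthHeaderRefresher {
  private final DatabricksConfig config;
  private final HeadersSink transport;

  AuthHeaderRefresher(DatabricksConfig config, HeadersSink transport) {
    this.config = config;
    this.transport = transport;
  }

  // Called before every Thrift RPC: authenticate() is cheap while the cached
  // token is still valid and transparently re-acquires it once it expires,
  // so the transport always carries a usable Authorization header.
  void refreshHeadersIfRequired() {
    transport.setCustomHeaders(config.authenticate());
  }
}
```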

src/main/java/com/databricks/jdbc/core/OAuthAuthenticator.java
Lines changed: 26 additions & 27 deletions

```diff
@@ -14,56 +14,55 @@ public OAuthAuthenticator(IDatabricksConnectionContext connectionContext) {
   }

   public WorkspaceClient getWorkspaceClient() throws DatabricksParsingException {
+    return new WorkspaceClient(getDatabricksConfig());
+  }
+
+  public DatabricksConfig getDatabricksConfig() throws DatabricksParsingException {
     if (this.connectionContext.getAuthMech().equals(IDatabricksConnectionContext.AuthMech.PAT)) {
-      return authenticateAccessToken();
+      return createAccessTokenConfig();
     }
     // TODO(Madhav): Revisit these to set JDBC values
     else if (this.connectionContext
         .getAuthMech()
         .equals(IDatabricksConnectionContext.AuthMech.OAUTH)) {
       switch (this.connectionContext.getAuthFlow()) {
         case TOKEN_PASSTHROUGH:
-          return authenticateAccessToken();
+          return createAccessTokenConfig();
         case CLIENT_CREDENTIALS:
-          return authenticateM2M();
+          return createM2MConfig();
         case BROWSER_BASED_AUTHENTICATION:
-          return authenticateU2M();
+          return createU2MConfig();
       }
     }
-    return authenticateAccessToken();
+    return createAccessTokenConfig();
   }

-  public WorkspaceClient authenticateU2M() throws DatabricksParsingException {
+  public DatabricksConfig createU2MConfig() throws DatabricksParsingException {
     DatabricksConfig config =
         new DatabricksConfig()
             .setAuthType(DatabricksJdbcConstants.U2M_AUTH_TYPE)
-            .setHost(this.connectionContext.getHostForOAuth())
-            .setClientId(this.connectionContext.getClientId())
-            .setClientSecret(this.connectionContext.getClientSecret())
+            .setHost(connectionContext.getHostForOAuth())
+            .setClientId(connectionContext.getClientId())
+            .setClientSecret(connectionContext.getClientSecret())
             .setOAuthRedirectUrl(DatabricksJdbcConstants.U2M_AUTH_REDIRECT_URL);
     if (!config.isAzure()) {
-      // Default scope is already being set for Azure in databricks-sdk.
-      config.setScopes(this.connectionContext.getOAuthScopesForU2M());
+      config.setScopes(connectionContext.getOAuthScopesForU2M());
     }
-    return new WorkspaceClient(config);
+    return config;
   }

-  public WorkspaceClient authenticateAccessToken() throws DatabricksParsingException {
-    DatabricksConfig config =
-        new DatabricksConfig()
-            .setAuthType(DatabricksJdbcConstants.ACCESS_TOKEN_AUTH_TYPE)
-            .setHost(this.connectionContext.getHostUrl())
-            .setToken(this.connectionContext.getToken());
-    return new WorkspaceClient(config);
+  public DatabricksConfig createAccessTokenConfig() throws DatabricksParsingException {
+    return new DatabricksConfig()
+        .setAuthType(DatabricksJdbcConstants.ACCESS_TOKEN_AUTH_TYPE)
+        .setHost(connectionContext.getHostUrl())
+        .setToken(connectionContext.getToken());
   }

-  public WorkspaceClient authenticateM2M() throws DatabricksParsingException {
-    DatabricksConfig config =
-        new DatabricksConfig()
-            .setAuthType(DatabricksJdbcConstants.M2M_AUTH_TYPE)
-            .setHost(this.connectionContext.getHostForOAuth())
-            .setClientId(this.connectionContext.getClientId())
-            .setClientSecret(this.connectionContext.getClientSecret());
-    return new WorkspaceClient(config);
+  public DatabricksConfig createM2MConfig() throws DatabricksParsingException {
+    return new DatabricksConfig()
+        .setAuthType(DatabricksJdbcConstants.M2M_AUTH_TYPE)
+        .setHost(connectionContext.getHostForOAuth())
+        .setClientId(connectionContext.getClientId())
+        .setClientSecret(connectionContext.getClientSecret());
   }
 }
```
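Returning `DatabricksConfig` from the factory methods lets the Thrift and SDK paths share one auth entry point: the Thrift accessor turns the config into HTTP headers, while the SDK path wraps the same config in a `WorkspaceClient`. A hedged usage sketch, assuming an already-built `IDatabricksConnectionContext` (its construction from the JDBC URL is elided here):

```java
import java.util.Map;
import com.databricks.sdk.WorkspaceClient;
import com.databricks.sdk.core.DatabricksConfig;

// Illustration only; imports of the driver-internal types
// (IDatabricksConnectionContext, OAuthAuthenticator, etc.) are omitted.
class AuthWiringSketch {
  static void wire(IDatabricksConnectionContext connectionContext)
      throws DatabricksParsingException {
    // Auth mechanism (PAT vs. OAuth) and flow (passthrough, M2M, U2M)
    // are read from the connection context inside getDatabricksConfig().
    OAuthAuthenticator authenticator = new OAuthAuthenticator(connectionContext);

    // Thrift path (DatabricksThriftAccessor): config -> HTTP auth headers.
    DatabricksConfig config = authenticator.getDatabricksConfig();
    Map<String, String> authHeaders = config.authenticate();

    // SDK path (DatabricksSdkClient): same selection logic, wrapped in a client.
    WorkspaceClient workspaceClient = authenticator.getWorkspaceClient();
  }
}
```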
