@@ -34,37 +34,31 @@ public DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext)
3434 new DatabricksHttpTTransport (
3535 DatabricksHttpClient .getInstance (connectionContext ),
3636 connectionContext .getEndpointURL ());
37- // TODO : add other auth in followup PRs
38- this .databricksConfig =
39- new DatabricksConfig ()
40- .setHost (connectionContext .getHostUrl ())
41- .setToken (connectionContext .getToken ());
37+ this .databricksConfig = new OAuthAuthenticator (connectionContext ).getDatabricksConfig ();
4238 Map <String , String > authHeaders = databricksConfig .authenticate ();
4339 transport .setCustomHeaders (authHeaders );
4440 TBinaryProtocol protocol = new TBinaryProtocol (transport );
4541 this .thriftClient = new TCLIService .Client (protocol );
4642 }
4743
4844 @ VisibleForTesting
49- public DatabricksThriftAccessor (TCLIService .Client client ) {
50- this .databricksConfig = null ;
45+ DatabricksThriftAccessor (TCLIService .Client client , DatabricksConfig config ) {
46+ this .databricksConfig = config ;
5147 this .thriftClient = client ;
5248 }
5349
5450 public TBase getThriftResponse (
5551 TBase request , CommandName commandName , IDatabricksStatement parentStatement )
5652 throws DatabricksSQLException {
5753 /*Todo list :
58- * 1. Poll until we get a success status
59- * 2. Test out metadata operations.
60- * 3. Add token refresh
61- * 4. Handle cloud-fetch
62- * 5. Handle compression
54+ * 1. Test out metadata operations.
55+ * 2. Handle compression
6356 * */
6457 LOGGER .debug (
6558 "Fetching thrift response for request {}, CommandName {}" ,
6659 request .toString (),
6760 commandName .name ());
61+ refreshHeadersIfRequired ();
6862 try {
6963 switch (commandName ) {
7064 case OPEN_SESSION :
@@ -107,11 +101,12 @@ public TBase getThriftResponse(
107101
108102 public TFetchResultsResp getResultSetResp (TOperationHandle operationHandle , String context )
109103 throws DatabricksHttpException {
104+ refreshHeadersIfRequired ();
110105 return getResultSetResp (
111106 TStatusCode .SUCCESS_STATUS , operationHandle , context , DEFAULT_ROW_LIMIT , false );
112107 }
113108
114- public TFetchResultsResp getResultSetResp (
109+ private TFetchResultsResp getResultSetResp (
115110 TStatusCode responseCode ,
116111 TOperationHandle operationHandle ,
117112 String context ,
@@ -169,6 +164,7 @@ public DatabricksResultSet execute(
169164 IDatabricksSession session ,
170165 StatementType statementType )
171166 throws SQLException {
167+ refreshHeadersIfRequired ();
172168 int maxRows = (parentStatement == null ) ? DEFAULT_ROW_LIMIT : parentStatement .getMaxRows ();
173169 TSparkGetDirectResults directResults =
174170 new TSparkGetDirectResults ().setMaxBytes (DEFAULT_BYTE_LIMIT ).setMaxRows (maxRows );
@@ -338,10 +334,15 @@ private TFetchResultsResp listColumns(TGetColumnsReq request)
338334 false );
339335 }
340336
341- public TGetResultSetMetadataResp getResultSetMetadata (TOperationHandle operationHandle )
337+ private TGetResultSetMetadataResp getResultSetMetadata (TOperationHandle operationHandle )
342338 throws TException {
343339 TGetResultSetMetadataReq resultSetMetadataReq =
344340 new TGetResultSetMetadataReq ().setOperationHandle (operationHandle );
345341 return thriftClient .GetResultSetMetadata (resultSetMetadataReq );
346342 }
343+
344+ private void refreshHeadersIfRequired () {
345+ ((DatabricksHttpTTransport ) thriftClient .getInputProtocol ().getTransport ())
346+ .setCustomHeaders (databricksConfig .authenticate ());
347+ }
347348}
0 commit comments