Skip to content

Commit f46c957

Browse files
authored
[PECO-2136] Changes to fix async execution (#603)
* Adding changes for Fixing async execution * Add tests * Fixes as per comments * Add tests
1 parent f63269e commit f46c957

File tree

9 files changed

+218
-43
lines changed

9 files changed

+218
-43
lines changed

src/main/java/com/databricks/jdbc/api/impl/DatabricksResultSet.java

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,9 @@
1818
import com.databricks.jdbc.log.JdbcLogger;
1919
import com.databricks.jdbc.log.JdbcLoggerFactory;
2020
import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp;
21-
import com.databricks.jdbc.model.client.thrift.generated.TStatus;
2221
import com.databricks.jdbc.model.core.ColumnMetadata;
2322
import com.databricks.jdbc.model.core.ResultData;
2423
import com.databricks.jdbc.model.core.ResultManifest;
25-
import com.databricks.sdk.service.sql.StatementState;
2624
import com.databricks.sdk.service.sql.StatementStatus;
2725
import com.google.common.annotations.VisibleForTesting;
2826
import java.io.InputStream;
@@ -103,25 +101,14 @@ public DatabricksResultSet(
103101

104102
// Constructor for thrift result set
105103
public DatabricksResultSet(
106-
TStatus statementStatus,
104+
StatementStatus statementStatus,
107105
StatementId statementId,
108106
TFetchResultsResp resultsResp,
109107
StatementType statementType,
110108
IDatabricksStatementInternal parentStatement,
111109
IDatabricksSession session)
112110
throws SQLException {
113-
switch (statementStatus.getStatusCode()) {
114-
case SUCCESS_STATUS:
115-
case SUCCESS_WITH_INFO_STATUS:
116-
this.statementStatus = new StatementStatus().setState(StatementState.SUCCEEDED);
117-
break;
118-
case STILL_EXECUTING_STATUS:
119-
this.statementStatus = new StatementStatus().setState(StatementState.RUNNING);
120-
break;
121-
default:
122-
this.statementStatus = new StatementStatus().setState(StatementState.FAILED);
123-
}
124-
111+
this.statementStatus = statementStatus;
125112
this.statementId = statementId;
126113
if (resultsResp != null) {
127114
this.executionResult =

src/main/java/com/databricks/jdbc/common/util/DatabricksThriftUtil.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import com.databricks.jdbc.model.client.thrift.generated.*;
1515
import com.databricks.jdbc.model.core.ExternalLink;
1616
import com.databricks.sdk.service.sql.ColumnInfoTypeName;
17+
import com.databricks.sdk.service.sql.StatementState;
18+
import com.databricks.sdk.service.sql.StatementStatus;
1719
import java.nio.ByteBuffer;
1820
import java.util.*;
1921
import java.util.stream.Collectors;
@@ -98,6 +100,71 @@ public static List<List<Object>> extractValuesColumnar(List<TColumn> columnList)
98100
.collect(Collectors.toList());
99101
}
100102

103+
/** Returns statement status for given operation status response */
104+
public static StatementStatus getStatementStatus(TGetOperationStatusResp resp) {
105+
StatementState state = null;
106+
switch (resp.getOperationState()) {
107+
case INITIALIZED_STATE:
108+
case PENDING_STATE:
109+
state = StatementState.PENDING;
110+
break;
111+
112+
case RUNNING_STATE:
113+
state = StatementState.RUNNING;
114+
break;
115+
116+
case FINISHED_STATE:
117+
state = StatementState.SUCCEEDED;
118+
break;
119+
120+
case ERROR_STATE:
121+
case TIMEDOUT_STATE:
122+
// TODO: Also set the sql_state and error message
123+
state = StatementState.FAILED;
124+
break;
125+
126+
case CLOSED_STATE:
127+
state = StatementState.CLOSED;
128+
break;
129+
130+
case CANCELED_STATE:
131+
state = StatementState.CANCELED;
132+
break;
133+
134+
case UKNOWN_STATE:
135+
state = StatementState.FAILED;
136+
}
137+
138+
return new StatementStatus().setState(state);
139+
}
140+
141+
/** Returns statement status for given status response */
142+
public static StatementStatus getAsyncStatus(TStatus status) {
143+
StatementStatus statementStatus = new StatementStatus();
144+
StatementState state = null;
145+
146+
switch (status.getStatusCode()) {
147+
// For async mode, success would just mean that statement was successfully submitted
148+
// actual status should be checked using GetOperationStatus
149+
case SUCCESS_STATUS:
150+
case SUCCESS_WITH_INFO_STATUS:
151+
case STILL_EXECUTING_STATUS:
152+
state = StatementState.RUNNING;
153+
break;
154+
155+
case INVALID_HANDLE_STATUS:
156+
case ERROR_STATUS:
157+
// TODO: set sql_state in case of error
158+
state = StatementState.FAILED;
159+
break;
160+
161+
default:
162+
state = StatementState.FAILED;
163+
}
164+
165+
return new StatementStatus().setState(state);
166+
}
167+
101168
private static Object getObjectInColumn(TColumn column, int index) {
102169
if (column == null) {
103170
return NULL_STRING;

src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import com.databricks.jdbc.log.JdbcLoggerFactory;
2121
import com.databricks.jdbc.model.client.thrift.generated.*;
2222
import com.databricks.sdk.core.DatabricksConfig;
23+
import com.databricks.sdk.service.sql.StatementState;
24+
import com.databricks.sdk.service.sql.StatementStatus;
2325
import com.google.common.annotations.VisibleForTesting;
2426
import java.sql.SQLException;
2527
import java.util.Arrays;
@@ -265,6 +267,14 @@ DatabricksResultSet execute(
265267
maxRows,
266268
true);
267269
}
270+
271+
StatementId statementId = new StatementId(response.getOperationHandle().operationId);
272+
if (parentStatement != null) {
273+
parentStatement.setStatementId(statementId);
274+
}
275+
StatementStatus statementStatus = getStatementStatus(statusResp);
276+
return new DatabricksResultSet(
277+
statementStatus, statementId, resultSet, statementType, parentStatement, session);
268278
} catch (TException e) {
269279
String errorMessage =
270280
String.format(
@@ -273,12 +283,6 @@ DatabricksResultSet execute(
273283
LOGGER.error(e, errorMessage);
274284
throw new DatabricksHttpException(errorMessage, e);
275285
}
276-
StatementId statementId = new StatementId(response.getOperationHandle().operationId);
277-
if (parentStatement != null) {
278-
parentStatement.setStatementId(statementId);
279-
}
280-
return new DatabricksResultSet(
281-
response.getStatus(), statementId, resultSet, statementType, parentStatement, session);
282286
}
283287

284288
DatabricksResultSet executeAsync(
@@ -314,8 +318,9 @@ DatabricksResultSet executeAsync(
314318
if (parentStatement != null) {
315319
parentStatement.setStatementId(statementId);
316320
}
321+
StatementStatus statementStatus = getAsyncStatus(response.getStatus());
317322
return new DatabricksResultSet(
318-
response.getStatus(), statementId, null, statementType, parentStatement, session);
323+
statementStatus, statementId, null, statementType, parentStatement, session);
319324
}
320325

321326
DatabricksResultSet getStatementResult(
@@ -329,15 +334,21 @@ DatabricksResultSet getStatementResult(
329334
.setOperationHandle(operationHandle)
330335
.setGetProgressUpdate(false);
331336
TGetOperationStatusResp response;
332-
TStatusCode statusCode;
333337
TFetchResultsResp resultSet = null;
338+
StatementId statementId = new StatementId(operationHandle.getOperationId());
334339
try {
335340
response = getThriftClient().GetOperationStatus(request);
336-
statusCode = response.getStatus().getStatusCode();
337-
if (statusCode == TStatusCode.SUCCESS_STATUS
338-
|| statusCode == TStatusCode.SUCCESS_WITH_INFO_STATUS) {
341+
TOperationState operationState = response.getOperationState();
342+
if (operationState == TOperationState.FINISHED_STATE) {
339343
resultSet =
340344
getResultSetResp(response.getStatus(), operationHandle, response.toString(), -1, true);
345+
return new DatabricksResultSet(
346+
new StatementStatus().setState(StatementState.SUCCEEDED),
347+
statementId,
348+
resultSet,
349+
StatementType.SQL,
350+
parentStatement,
351+
session);
341352
}
342353
} catch (TException e) {
343354
String errorMessage =
@@ -347,9 +358,9 @@ DatabricksResultSet getStatementResult(
347358
LOGGER.error(e, errorMessage);
348359
throw new DatabricksHttpException(errorMessage, e);
349360
}
350-
StatementId statementId = new StatementId(operationHandle.getOperationId());
361+
StatementStatus executionStatus = getStatementStatus(response);
351362
return new DatabricksResultSet(
352-
response.getStatus(), statementId, resultSet, StatementType.SQL, parentStatement, session);
363+
executionStatus, statementId, resultSet, StatementType.SQL, parentStatement, session);
353364
}
354365

355366
void resetAccessToken(String newAccessToken) {
@@ -528,7 +539,10 @@ private void checkOperationStatusForErrors(TGetOperationStatusResp statusResp)
528539
if (statusResp != null
529540
&& statusResp.isSetOperationState()
530541
&& isErrorOperationState(statusResp.getOperationState())) {
531-
throw new DatabricksSQLException("Operation state erroneous");
542+
String errorMsg =
543+
String.format("Operation failed with error: %s", statusResp.getErrorMessage());
544+
LOGGER.error(errorMsg);
545+
throw new DatabricksSQLException(errorMsg, statusResp.getSqlState());
532546
}
533547
}
534548

src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ public DatabricksResultSet executeStatementAsync(
167167
.setSessionHandle(session.getSessionInfo().sessionHandle())
168168
.setCanDecompressLZ4Result(true)
169169
.setCanReadArrowResult(this.connectionContext.shouldEnableArrow())
170+
.setRunAsync(true)
170171
.setCanDownloadResult(true);
171172
return thriftAccessor.executeAsync(request, parentStatement, session, StatementType.SQL);
172173
}

src/test/java/com/databricks/client/jdbc/DriverTest.java

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import com.databricks.jdbc.api.impl.volume.DatabricksVolumeClientFactory;
1010
import com.databricks.jdbc.common.DatabricksJdbcConstants;
1111
import com.databricks.jdbc.exception.DatabricksSQLException;
12+
import com.databricks.sdk.service.sql.StatementState;
1213
import java.io.File;
1314
import java.io.FileInputStream;
1415
import java.math.BigDecimal;
@@ -437,26 +438,51 @@ void testAllPurposeClusters_async() throws Exception {
437438
System.out.println("Connection established...... con1");
438439
Statement s = con.createStatement();
439440
IDatabricksStatement ids = s.unwrap(IDatabricksStatement.class);
440-
ResultSet rs = ids.executeAsync("SELECT * from RANGE(10)");
441+
long initialTime = System.currentTimeMillis();
442+
String sql =
443+
"CREATE TABLE TMP_P2P_EKKO_EKPO_ASYNC8 AS ("
444+
+ " SELECT * FROM ("
445+
+ " SELECT * FROM ("
446+
+ " SELECT t1.*"
447+
+ " FROM main.streaming.random_large_table t1"
448+
+ " INNER JOIN main.streaming.random_large_table t2"
449+
+ " ON t1.prompt = t2.prompt"
450+
+ " ) nested_t1"
451+
+ " ) nested_t2"
452+
+ ")";
453+
ResultSet rs = ids.executeAsync(sql);
441454
System.out.println(
442-
"1Status of async execution " + rs.unwrap(IDatabricksResultSet.class).getStatementStatus());
443-
444-
ResultSet rs3 = s.unwrap(IDatabricksStatement.class).getExecutionResult();
445-
System.out.println(
446-
"2Status of async execution "
447-
+ rs3.unwrap(IDatabricksResultSet.class).getStatementStatus());
448-
455+
"Status of async execution " + rs.unwrap(IDatabricksResultSet.class).getStatementStatus());
456+
System.out.println("Time taken: " + (System.currentTimeMillis() - initialTime));
449457
System.out.println("StatementId " + rs.unwrap(IDatabricksResultSet.class).getStatementId());
450458

459+
int count = 1;
460+
StatementState state = rs.unwrap(IDatabricksResultSet.class).getStatementStatus().getState();
461+
while (state != StatementState.SUCCEEDED && state != StatementState.FAILED) {
462+
Thread.sleep(1000);
463+
rs = s.unwrap(IDatabricksStatement.class).getExecutionResult();
464+
state = rs.unwrap(IDatabricksResultSet.class).getStatementStatus().getState();
465+
System.out.println(
466+
"Status of async execution "
467+
+ state
468+
+ " attempt "
469+
+ count++
470+
+ " time taken "
471+
+ (System.currentTimeMillis() - initialTime));
472+
}
473+
451474
Connection con2 = DriverManager.getConnection(jdbcUrl, "token", "token");
452475
System.out.println("Connection established......con2");
453476
IDatabricksConnection idc = con2.unwrap(IDatabricksConnection.class);
454477
Statement stm = idc.getStatement(rs.unwrap(IDatabricksResultSet.class).getStatementId());
455478
ResultSet rs2 = stm.unwrap(IDatabricksStatement.class).getExecutionResult();
479+
456480
System.out.println(
457-
"3Status of async execution "
458-
+ rs2.unwrap(IDatabricksResultSet.class).getStatementStatus());
481+
"Status of async execution using con2 "
482+
+ rs2.unwrap(IDatabricksResultSet.class).getStatementStatus().getState());
483+
459484
stm.cancel();
485+
stm.execute("DROP TABLE TMP_P2P_EKKO_EKPO_ASYNC8");
460486
System.out.println("Statement cancelled using con2");
461487
s.close();
462488
System.out.println("Statement cancelled using con1");

src/test/java/com/databricks/jdbc/api/impl/DatabricksResultSetTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ private DatabricksResultSet getThriftResultSetMetadata() throws SQLException {
6868
when(fetchResultsResp.getResults()).thenReturn(rowSet);
6969
when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadataResp);
7070
return new DatabricksResultSet(
71-
new TStatus().setStatusCode(TStatusCode.SUCCESS_STATUS),
71+
new StatementStatus().setState(StatementState.SUCCEEDED),
7272
THRIFT_STATEMENT_ID,
7373
fetchResultsResp,
7474
StatementType.METADATA,

src/test/java/com/databricks/jdbc/common/util/DatabricksThriftUtilTest.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import com.databricks.jdbc.exception.DatabricksSQLException;
1313
import com.databricks.jdbc.model.client.thrift.generated.*;
1414
import com.databricks.sdk.service.sql.ColumnInfoTypeName;
15+
import com.databricks.sdk.service.sql.StatementState;
1516
import java.nio.ByteBuffer;
1617
import java.util.Collections;
1718
import java.util.List;
@@ -227,4 +228,80 @@ public void testGetTypeFromTypeDesc(TTypeId type, ColumnInfoTypeName typeName) {
227228
public void testCheckDirectResultsForErrorStatus(TSparkDirectResults response) {
228229
assertDoesNotThrow(() -> checkDirectResultsForErrorStatus(response, TEST_STRING));
229230
}
231+
232+
@Test
233+
public void testGetStatementStatus() throws Exception {
234+
assertEquals(
235+
StatementState.PENDING,
236+
DatabricksThriftUtil.getStatementStatus(
237+
new TGetOperationStatusResp().setOperationState(TOperationState.INITIALIZED_STATE))
238+
.getState());
239+
assertEquals(
240+
StatementState.PENDING,
241+
DatabricksThriftUtil.getStatementStatus(
242+
new TGetOperationStatusResp().setOperationState(TOperationState.PENDING_STATE))
243+
.getState());
244+
assertEquals(
245+
StatementState.SUCCEEDED,
246+
DatabricksThriftUtil.getStatementStatus(
247+
new TGetOperationStatusResp().setOperationState(TOperationState.FINISHED_STATE))
248+
.getState());
249+
assertEquals(
250+
StatementState.RUNNING,
251+
DatabricksThriftUtil.getStatementStatus(
252+
new TGetOperationStatusResp().setOperationState(TOperationState.RUNNING_STATE))
253+
.getState());
254+
assertEquals(
255+
StatementState.FAILED,
256+
DatabricksThriftUtil.getStatementStatus(
257+
new TGetOperationStatusResp().setOperationState(TOperationState.ERROR_STATE))
258+
.getState());
259+
assertEquals(
260+
StatementState.FAILED,
261+
DatabricksThriftUtil.getStatementStatus(
262+
new TGetOperationStatusResp().setOperationState(TOperationState.TIMEDOUT_STATE))
263+
.getState());
264+
assertEquals(
265+
StatementState.FAILED,
266+
DatabricksThriftUtil.getStatementStatus(
267+
new TGetOperationStatusResp().setOperationState(TOperationState.UKNOWN_STATE))
268+
.getState());
269+
assertEquals(
270+
StatementState.CLOSED,
271+
DatabricksThriftUtil.getStatementStatus(
272+
new TGetOperationStatusResp().setOperationState(TOperationState.CLOSED_STATE))
273+
.getState());
274+
assertEquals(
275+
StatementState.CANCELED,
276+
DatabricksThriftUtil.getStatementStatus(
277+
new TGetOperationStatusResp().setOperationState(TOperationState.CANCELED_STATE))
278+
.getState());
279+
}
280+
281+
@Test
282+
public void testGetStatementStatusForAsync() throws Exception {
283+
assertEquals(
284+
StatementState.RUNNING,
285+
DatabricksThriftUtil.getAsyncStatus(new TStatus().setStatusCode(TStatusCode.SUCCESS_STATUS))
286+
.getState());
287+
assertEquals(
288+
StatementState.RUNNING,
289+
DatabricksThriftUtil.getAsyncStatus(
290+
new TStatus().setStatusCode(TStatusCode.SUCCESS_WITH_INFO_STATUS))
291+
.getState());
292+
assertEquals(
293+
StatementState.RUNNING,
294+
DatabricksThriftUtil.getAsyncStatus(
295+
new TStatus().setStatusCode(TStatusCode.STILL_EXECUTING_STATUS))
296+
.getState());
297+
assertEquals(
298+
StatementState.FAILED,
299+
DatabricksThriftUtil.getAsyncStatus(
300+
new TStatus().setStatusCode(TStatusCode.INVALID_HANDLE_STATUS))
301+
.getState());
302+
assertEquals(
303+
StatementState.FAILED,
304+
DatabricksThriftUtil.getAsyncStatus(new TStatus().setStatusCode(TStatusCode.ERROR_STATUS))
305+
.getState());
306+
}
230307
}

0 commit comments

Comments
 (0)