Skip to content

Commit 0241ebc

Browse files
authored
Merge pull request #303 from databricks/jprakash-db/PECO-1758
[ PECO - 1758 ] - [OSS JDBC] Make batch APIs work in JDBC
2 parents 043d82a + fe7d80b commit 0241ebc

File tree

4 files changed

+144
-7
lines changed

4 files changed

+144
-7
lines changed

src/main/java/com/databricks/jdbc/core/DatabricksPreparedStatement.java

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
public class DatabricksPreparedStatement extends DatabricksStatement implements PreparedStatement {
2323
private final String sql;
24-
private final DatabricksParameterMetaData databricksParameterMetaData;
24+
private DatabricksParameterMetaData databricksParameterMetaData;
25+
List<DatabricksParameterMetaData> databricksBatchParameterMetaData;
2526
private final boolean supportManyParameters;
2627

2728
private final int CHUNK_SIZE = 8192;
@@ -32,6 +33,7 @@ public DatabricksPreparedStatement(DatabricksConnection connection, String sql)
3233
this.supportManyParameters =
3334
connection.getSession().getConnectionContext().supportManyParameters();
3435
this.databricksParameterMetaData = new DatabricksParameterMetaData();
36+
this.databricksBatchParameterMetaData = new ArrayList<>();
3537
}
3638

3739
private void checkLength(int targetLength, int sourceLength) throws SQLException {
@@ -56,6 +58,15 @@ private void checkLength(long targetLength, long sourceLength) throws SQLExcepti
5658
}
5759
}
5860

61+
private void checkIfBatchOperation() throws DatabricksSQLException {
62+
if (!this.databricksBatchParameterMetaData.isEmpty()) {
63+
String errorMessage =
64+
"Batch must either be executed with executeBatch() or cleared with clearBatch()";
65+
LoggingUtil.log(LogLevel.ERROR, errorMessage);
66+
throw new DatabricksSQLException(errorMessage);
67+
}
68+
}
69+
5970
private byte[] readByteStream(InputStream x, int length) throws SQLException {
6071
if (x == null) {
6172
String errorMessage = "InputStream cannot be null";
@@ -77,16 +88,38 @@ private byte[] readByteStream(InputStream x, int length) throws SQLException {
7788
@Override
7889
public ResultSet executeQuery() throws SQLException {
7990
LoggingUtil.log(LogLevel.DEBUG, "public ResultSet executeQuery()");
91+
checkIfBatchOperation();
8092
return interpolateIfRequiredAndExecute(StatementType.QUERY);
8193
}
8294

8395
@Override
8496
public int executeUpdate() throws SQLException {
8597
LoggingUtil.log(LogLevel.DEBUG, "public int executeUpdate()");
98+
checkIfBatchOperation();
8699
interpolateIfRequiredAndExecute(StatementType.UPDATE);
87100
return (int) resultSet.getUpdateCount();
88101
}
89102

103+
@Override
104+
public int[] executeBatch() {
105+
LoggingUtil.log(LogLevel.DEBUG, "public int executeBatch()");
106+
int[] updateCount = new int[databricksBatchParameterMetaData.size()];
107+
108+
for (int i = 0; i < databricksBatchParameterMetaData.size(); i++) {
109+
DatabricksParameterMetaData databricksParameterMetaData =
110+
databricksBatchParameterMetaData.get(i);
111+
try {
112+
executeInternal(
113+
sql, databricksParameterMetaData.getParameterBindings(), StatementType.UPDATE, false);
114+
updateCount[i] = (int) resultSet.getUpdateCount();
115+
} catch (SQLException e) {
116+
LoggingUtil.log(LogLevel.ERROR, e.getMessage());
117+
updateCount[i] = -1;
118+
}
119+
}
120+
return updateCount;
121+
}
122+
90123
@Override
91124
public void setNull(int parameterIndex, int sqlType) throws SQLException {
92125
LoggingUtil.log(LogLevel.DEBUG, "public void setNull(int parameterIndex, int sqlType)");
@@ -264,15 +297,24 @@ private void setObject(int parameterIndex, Object x, String databricksType) {
264297
public boolean execute() throws SQLException {
265298
LoggingUtil.log(LogLevel.DEBUG, "public boolean execute()");
266299
checkIfClosed();
300+
checkIfBatchOperation();
267301
interpolateIfRequiredAndExecute(StatementType.SQL);
268302
return shouldReturnResultSet(sql);
269303
}
270304

271305
@Override
272-
public void addBatch() throws SQLException {
306+
public void addBatch() {
273307
LoggingUtil.log(LogLevel.DEBUG, "public void addBatch()");
274-
throw new UnsupportedOperationException(
275-
"Not implemented in DatabricksPreparedStatement - addBatch()");
308+
this.databricksBatchParameterMetaData.add(databricksParameterMetaData);
309+
this.databricksParameterMetaData = new DatabricksParameterMetaData();
310+
}
311+
312+
@Override
313+
public void clearBatch() throws DatabricksSQLException {
314+
LoggingUtil.log(LogLevel.DEBUG, "public void clearBatch()");
315+
checkIfClosed();
316+
this.databricksParameterMetaData = new DatabricksParameterMetaData();
317+
this.databricksBatchParameterMetaData = new ArrayList<>();
276318
}
277319

278320
@Override

src/main/java/com/databricks/jdbc/core/DatabricksStatement.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,10 @@ public void handleResultSetClose(IDatabricksResultSet resultSet) throws SQLExcep
429429
}
430430

431431
DatabricksResultSet executeInternal(
432-
String sql, Map<Integer, ImmutableSqlParameter> params, StatementType statementType)
432+
String sql,
433+
Map<Integer, ImmutableSqlParameter> params,
434+
StatementType statementType,
435+
boolean closeStatement)
433436
throws SQLException {
434437
String stackTraceMessage =
435438
format(
@@ -444,7 +447,9 @@ DatabricksResultSet executeInternal(
444447
? futureResultSet.get() // Wait indefinitely when timeout is 0
445448
: futureResultSet.get(timeoutInSeconds, TimeUnit.SECONDS);
446449
} catch (TimeoutException e) {
447-
this.close(); // Close the statement
450+
if (closeStatement) {
451+
this.close(); // Close the statement
452+
}
448453
futureResultSet.cancel(true); // Cancel execution run
449454
throw new DatabricksTimeoutException(
450455
"Statement execution timed-out. " + stackTraceMessage, e);
@@ -466,6 +471,12 @@ DatabricksResultSet executeInternal(
466471
return resultSet;
467472
}
468473

474+
DatabricksResultSet executeInternal(
475+
String sql, Map<Integer, ImmutableSqlParameter> params, StatementType statementType)
476+
throws SQLException {
477+
return executeInternal(sql, params, statementType, true);
478+
}
479+
469480
// Todo : Add timeout tests in the subsequent PR
470481
CompletableFuture<DatabricksResultSet> getFutureResult(
471482
String sql, Map<Integer, ImmutableSqlParameter> params, StatementType statementType) {

src/test/java/com/databricks/jdbc/core/DatabricksPreparedStatementTest.java

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ public class DatabricksPreparedStatementTest {
3333
private static final String WAREHOUSE_ID = "erg6767gg";
3434
private static final String STATEMENT =
3535
"SELECT * FROM orders WHERE user_id = ? AND shard = ? AND region_code = ? AND namespace = ?";
36+
private static final String BATCH_STATEMENT =
37+
"INSERT INTO orders (user_id, shard, region_code, namespace) VALUES (?, ?, ?, ?)";
3638
private static final String JDBC_URL =
3739
"jdbc:databricks://adb-565757575.18.azuredatabricks.net:4423/default;transportMode=http;ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/erg6767gg;";
3840
private static final String JDBC_URL_WITH_MANY_PARAMETERS =
@@ -147,6 +149,43 @@ public void testExecuteUpdateStatement() throws Exception {
147149
assertTrue(statement.isClosed());
148150
}
149151

152+
@Test
153+
public void testExecuteBatchStatement() throws Exception {
154+
IDatabricksConnectionContext connectionContext =
155+
DatabricksConnectionContext.parse(JDBC_URL, new Properties());
156+
DatabricksConnection connection = new DatabricksConnection(connectionContext, client);
157+
DatabricksPreparedStatement statement =
158+
new DatabricksPreparedStatement(connection, BATCH_STATEMENT);
159+
160+
// Setting to execute a batch of 4 statements
161+
for (int i = 1; i <= 4; i++) {
162+
statement.setLong(1, (long) 100);
163+
statement.setShort(2, (short) 10);
164+
statement.setByte(3, (byte) 15);
165+
statement.setString(4, "value");
166+
statement.addBatch();
167+
}
168+
169+
when(client.executeStatement(
170+
eq(BATCH_STATEMENT),
171+
eq(new Warehouse(WAREHOUSE_ID)),
172+
any(HashMap.class),
173+
eq(StatementType.UPDATE),
174+
any(IDatabricksSession.class),
175+
eq(statement)))
176+
.thenReturn(resultSet);
177+
178+
when(resultSet.getUpdateCount()).thenReturn(1L);
179+
180+
int[] expectedCountsResult = {1, 1, 1, 1};
181+
int[] updateCounts = statement.executeBatch();
182+
183+
assertArrayEquals(expectedCountsResult, updateCounts);
184+
assertFalse(statement.isClosed());
185+
statement.close();
186+
assertTrue(statement.isClosed());
187+
}
188+
150189
public static ImmutableSqlParameter getSqlParam(
151190
int parameterIndex, Object x, String databricksType) {
152191
return ImmutableSqlParameter.builder()
@@ -412,7 +451,6 @@ void testUnsupportedMethods() throws DatabricksSQLException {
412451
assertThrows(
413452
UnsupportedOperationException.class, () -> preparedStatement.setTime(1, null, null));
414453
assertThrows(UnsupportedOperationException.class, () -> preparedStatement.setBytes(1, null));
415-
assertThrows(UnsupportedOperationException.class, () -> preparedStatement.addBatch());
416454
assertThrows(
417455
SQLFeatureNotSupportedException.class, () -> preparedStatement.setObject(1, null, null));
418456
assertThrows(

src/test/java/com/databricks/jdbc/local/DriverTester.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,4 +250,50 @@ void testAllPurposeClusters_errorHandling() throws Exception {
250250
con.close();
251251
System.out.println("Connection closed successfully......");
252252
}
253+
254+
@Test
255+
void testSimbaBatchFunction() throws Exception {
256+
257+
String jdbcUrl =
258+
"jdbc:databricks://e2-dogfood.staging.cloud.databricks.com:443/default;transportMode=http;ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/dd43ee29fedd958d;";
259+
Connection con = DriverManager.getConnection(jdbcUrl, "[email protected]", "xx");
260+
System.out.println("Connection established......");
261+
262+
//
263+
// Batch Statement Testing
264+
//
265+
String sqlStatement =
266+
"INSERT INTO ___________________first.`jprakash-test`.diamonds (carat, cut, color, clarity) VALUES (?, ?, ?, ?)";
267+
PreparedStatement pstmt = con.prepareStatement(sqlStatement);
268+
for (int i = 1; i <= 3; i++) {
269+
pstmt.setFloat(1, 0.23f);
270+
pstmt.setString(2, "OK");
271+
pstmt.setString(3, "E");
272+
pstmt.setString(4, "SI2");
273+
pstmt.addBatch();
274+
}
275+
276+
pstmt.setString(1, "Shaama");
277+
pstmt.setString(2, "Bad");
278+
pstmt.setString(3, "F");
279+
pstmt.setString(4, "SI6");
280+
pstmt.addBatch();
281+
282+
for (int i = 1; i <= 3; i++) {
283+
pstmt.setFloat(1, 0.23f);
284+
pstmt.setString(2, "Bad");
285+
pstmt.setString(3, "F");
286+
pstmt.setString(4, "SI6");
287+
pstmt.addBatch();
288+
}
289+
290+
// Execute the batch
291+
int[] updateCounts = pstmt.executeBatch();
292+
293+
// Process the update counts
294+
for (int count : updateCounts) {
295+
System.out.println("Update count: " + count);
296+
}
297+
con.close();
298+
}
253299
}

0 commit comments

Comments (0)