Skip to content

Commit 2524b31

Browse files
Add batching support for PreparedStatement
Cherry-pick of trinodb/trino@13658e2 and trinodb/trino@26f688f Co-authored-by: ebyhr
1 parent f3c6b7c commit 2524b31

File tree

5 files changed

+205
-12
lines changed

5 files changed

+205
-12
lines changed

presto-jdbc/pom.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@
123123
<scope>test</scope>
124124
</dependency>
125125

126+
<dependency>
127+
<groupId>com.facebook.presto</groupId>
128+
<artifactId>presto-memory</artifactId>
129+
<scope>test</scope>
130+
</dependency>
131+
126132
<dependency>
127133
<groupId>com.facebook.presto</groupId>
128134
<artifactId>presto-main-base</artifactId>
@@ -218,11 +224,13 @@
218224
<artifactId>jjwt-api</artifactId>
219225
<scope>test</scope>
220226
</dependency>
227+
221228
<dependency>
222229
<groupId>io.jsonwebtoken</groupId>
223230
<artifactId>jjwt-impl</artifactId>
224231
<scope>test</scope>
225232
</dependency>
233+
226234
<dependency>
227235
<groupId>io.jsonwebtoken</groupId>
228236
<artifactId>jjwt-jackson</artifactId>

presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,8 +1144,7 @@ public boolean insertsAreDetected(int type)
11441144
public boolean supportsBatchUpdates()
11451145
throws SQLException
11461146
{
1147-
// TODO: support batch updates
1148-
return false;
1147+
return true;
11491148
}
11501149

11511150
@Override

presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoPreparedStatement.java

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.math.BigDecimal;
2424
import java.net.URL;
2525
import java.sql.Array;
26+
import java.sql.BatchUpdateException;
2627
import java.sql.Blob;
2728
import java.sql.Clob;
2829
import java.sql.Date;
@@ -42,6 +43,7 @@
4243
import java.sql.Timestamp;
4344
import java.sql.Types;
4445
import java.util.ArrayList;
46+
import java.util.Arrays;
4547
import java.util.Calendar;
4648
import java.util.HashMap;
4749
import java.util.List;
@@ -72,9 +74,11 @@ public class PrestoPreparedStatement
7274
implements PreparedStatement
7375
{
7476
private final Map<Integer, String> parameters = new HashMap<>();
77+
private final List<List<String>> batchValues = new ArrayList<>();
7578
private final String statementName;
7679
private final String originalSql;
7780
private boolean isClosed;
81+
private boolean isBatch;
7882

7983
PrestoPreparedStatement(PrestoConnection connection, String statementName, String sql)
8084
throws SQLException
@@ -101,7 +105,8 @@ public void close()
101105
public ResultSet executeQuery()
102106
throws SQLException
103107
{
104-
if (!super.execute(getExecuteSql())) {
108+
requireNonBatchStatement();
109+
if (!super.execute(getExecuteSql(statementName, toValues(parameters)))) {
105110
throw new SQLException("Prepared SQL statement is not a query: " + originalSql);
106111
}
107112
return getResultSet();
@@ -111,14 +116,16 @@ public ResultSet executeQuery()
111116
public int executeUpdate()
112117
throws SQLException
113118
{
119+
requireNonBatchStatement();
114120
return Ints.saturatedCast(executeLargeUpdate());
115121
}
116122

117123
@Override
118124
public long executeLargeUpdate()
119125
throws SQLException
120126
{
121-
if (super.execute(getExecuteSql())) {
127+
requireNonBatchStatement();
128+
if (super.execute(getExecuteSql(statementName, toValues(parameters)))) {
122129
throw new SQLException("Prepared SQL is not an update statement: " + originalSql);
123130
}
124131
return getLargeUpdateCount();
@@ -128,7 +135,8 @@ public long executeLargeUpdate()
128135
public boolean execute()
129136
throws SQLException
130137
{
131-
return super.execute(getExecuteSql());
138+
requireNonBatchStatement();
139+
return super.execute(getExecuteSql(statementName, toValues(parameters)));
132140
}
133141

134142
@Override
@@ -430,7 +438,41 @@ else if (x instanceof Timestamp) {
430438
public void addBatch()
431439
throws SQLException
432440
{
433-
throw new NotImplementedException("PreparedStatement", "addBatch");
441+
checkOpen();
442+
batchValues.add(toValues(parameters));
443+
isBatch = true;
444+
}
445+
446+
@Override
447+
public void clearBatch()
448+
throws SQLException
449+
{
450+
checkOpen();
451+
batchValues.clear();
452+
isBatch = false;
453+
}
454+
455+
@Override
456+
public int[] executeBatch()
457+
throws SQLException
458+
{
459+
try {
460+
int[] batchUpdateCounts = new int[batchValues.size()];
461+
for (int i = 0; i < batchValues.size(); i++) {
462+
try {
463+
super.execute(getExecuteSql(statementName, batchValues.get(i)));
464+
batchUpdateCounts[i] = getUpdateCount();
465+
}
466+
catch (SQLException e) {
467+
long[] updateCounts = Arrays.stream(batchUpdateCounts).mapToLong(j -> j).toArray();
468+
throw new BatchUpdateException(e.getMessage(), e.getSQLState(), e.getErrorCode(), updateCounts, e.getCause());
469+
}
470+
}
471+
return batchUpdateCounts;
472+
}
473+
finally {
474+
clearBatch();
475+
}
434476
}
435477

436478
@Override
@@ -759,27 +801,34 @@ private void setParameter(int parameterIndex, String value)
759801
parameters.put(parameterIndex - 1, value);
760802
}
761803

762-
private void formatParametersTo(StringBuilder builder)
804+
private static List<String> toValues(Map<Integer, String> parameters)
763805
throws SQLException
764806
{
765-
List<String> values = new ArrayList<>();
807+
ImmutableList.Builder<String> values = ImmutableList.builder();
766808
for (int index = 0; index < parameters.size(); index++) {
767809
if (!parameters.containsKey(index)) {
768810
throw new SQLException("No value specified for parameter " + (index + 1));
769811
}
770812
values.add(parameters.get(index));
771813
}
772-
Joiner.on(", ").appendTo(builder, values);
814+
return values.build();
773815
}
774816

775-
private String getExecuteSql()
817+
private void requireNonBatchStatement()
776818
throws SQLException
819+
{
820+
if (isBatch) {
821+
throw new SQLException("Batch prepared statement must be executed using executeBatch method");
822+
}
823+
}
824+
825+
private static String getExecuteSql(String statementName, List<String> values)
777826
{
778827
StringBuilder sql = new StringBuilder();
779828
sql.append("EXECUTE ").append(statementName);
780-
if (!parameters.isEmpty()) {
829+
if (!values.isEmpty()) {
781830
sql.append(" USING ");
782-
formatParametersTo(sql);
831+
Joiner.on(", ").appendTo(sql, values);
783832
}
784833
return sql.toString();
785834
}

presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcPreparedStatement.java

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import com.facebook.airlift.log.Logging;
1717
import com.facebook.presto.plugin.blackhole.BlackHolePlugin;
18+
import com.facebook.presto.plugin.memory.MemoryPlugin;
1819
import com.facebook.presto.server.testing.TestingPrestoServer;
1920
import org.testng.annotations.AfterClass;
2021
import org.testng.annotations.BeforeClass;
@@ -40,10 +41,13 @@
4041

4142
import static com.facebook.presto.jdbc.TestPrestoDriver.closeQuietly;
4243
import static com.facebook.presto.jdbc.TestPrestoDriver.waitForNodeRefresh;
44+
import static com.facebook.presto.jdbc.TestingJdbcUtils.list;
45+
import static com.facebook.presto.jdbc.TestingJdbcUtils.readRows;
4346
import static com.google.common.base.Strings.repeat;
4447
import static com.google.common.primitives.Ints.asList;
4548
import static java.lang.String.format;
4649
import static java.nio.charset.StandardCharsets.UTF_8;
50+
import static org.assertj.core.api.Assertions.assertThat;
4751
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4852
import static org.testng.Assert.assertEquals;
4953
import static org.testng.Assert.assertFalse;
@@ -61,7 +65,9 @@ public void setup()
6165
Logging.initialize();
6266
server = new TestingPrestoServer();
6367
server.installPlugin(new BlackHolePlugin());
68+
server.installPlugin(new MemoryPlugin());
6469
server.createCatalog("blackhole", "blackhole");
70+
server.createCatalog("memory", "memory");
6571
waitForNodeRefresh(server);
6672

6773
try (Connection connection = createConnection();
@@ -636,6 +642,88 @@ public void testInvalidConversions()
636642
assertInvalidConversion((ps, i) -> ps.setObject(i, "abc", Types.SMALLINT), "Cannot convert instance of java.lang.String to SQL type " + Types.SMALLINT);
637643
}
638644

645+
@Test
646+
public void testExecuteBatch()
647+
throws Exception
648+
{
649+
try (Connection connection = createConnection("memory", "default")) {
650+
try (Statement statement = connection.createStatement()) {
651+
statement.execute("CREATE TABLE test_execute_batch(c_int integer)");
652+
}
653+
654+
try (PreparedStatement preparedStatement = connection.prepareStatement(
655+
"INSERT INTO test_execute_batch VALUES (?)")) {
656+
// Run executeBatch before addBatch
657+
assertEquals(preparedStatement.executeBatch(), new int[] {});
658+
659+
for (int i = 0; i < 3; i++) {
660+
preparedStatement.setInt(1, i);
661+
preparedStatement.addBatch();
662+
}
663+
assertEquals(preparedStatement.executeBatch(), new int[] {1, 1, 1});
664+
665+
try (Statement statement = connection.createStatement()) {
666+
ResultSet resultSet = statement.executeQuery("SELECT c_int FROM test_execute_batch");
667+
assertThat(readRows(resultSet))
668+
.containsExactlyInAnyOrder(
669+
list(0),
670+
list(1),
671+
list(2));
672+
}
673+
674+
// Make sure the above executeBatch cleared existing batch
675+
assertEquals(preparedStatement.executeBatch(), new int[] {});
676+
677+
// clearBatch removes added batch and cancel batch mode
678+
preparedStatement.setBoolean(1, true);
679+
preparedStatement.clearBatch();
680+
assertEquals(preparedStatement.executeBatch(), new int[] {});
681+
682+
preparedStatement.setInt(1, 1);
683+
assertEquals(preparedStatement.executeUpdate(), 1);
684+
}
685+
686+
try (Statement statement = connection.createStatement()) {
687+
statement.execute("DROP TABLE test_execute_batch");
688+
}
689+
}
690+
}
691+
692+
@Test
693+
public void testInvalidExecuteBatch()
694+
throws Exception
695+
{
696+
try (Connection connection = createConnection("blackhole", "blackhole")) {
697+
try (Statement statement = connection.createStatement()) {
698+
statement.execute("CREATE TABLE test_invalid_execute_batch(c_int integer)");
699+
}
700+
701+
try (PreparedStatement statement = connection.prepareStatement(
702+
"INSERT INTO test_invalid_execute_batch VALUES (?)")) {
703+
statement.setInt(1, 1);
704+
statement.addBatch();
705+
706+
String message = "Batch prepared statement must be executed using executeBatch method";
707+
assertThatThrownBy(statement::executeQuery)
708+
.isInstanceOf(SQLException.class)
709+
.hasMessage(message);
710+
assertThatThrownBy(statement::executeUpdate)
711+
.isInstanceOf(SQLException.class)
712+
.hasMessage(message);
713+
assertThatThrownBy(statement::executeLargeUpdate)
714+
.isInstanceOf(SQLException.class)
715+
.hasMessage(message);
716+
assertThatThrownBy(statement::execute)
717+
.isInstanceOf(SQLException.class)
718+
.hasMessage(message);
719+
}
720+
721+
try (Statement statement = connection.createStatement()) {
722+
statement.execute("DROP TABLE test_invalid_execute_batch");
723+
}
724+
}
725+
}
726+
639727
private void assertInvalidConversion(Binder binder, String message)
640728
{
641729
assertThatThrownBy(() -> assertParameter(null, Types.NULL, binder))
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.jdbc;
15+
16+
import com.google.common.collect.ImmutableList;
17+
18+
import java.sql.ResultSet;
19+
import java.sql.SQLException;
20+
import java.util.ArrayList;
21+
import java.util.List;
22+
23+
import static java.util.Arrays.asList;
24+
25+
public class TestingJdbcUtils
26+
{
27+
private TestingJdbcUtils() {}
28+
29+
public static List<List<Object>> readRows(ResultSet rs)
30+
throws SQLException
31+
{
32+
ImmutableList.Builder<List<Object>> rows = ImmutableList.builder();
33+
int columnCount = rs.getMetaData().getColumnCount();
34+
while (rs.next()) {
35+
List<Object> row = new ArrayList<>();
36+
for (int i = 1; i <= columnCount; i++) {
37+
row.add(rs.getObject(i));
38+
}
39+
rows.add(row);
40+
}
41+
return rows.build();
42+
}
43+
44+
@SafeVarargs
45+
public static <T> List<T> list(T... elements)
46+
{
47+
return asList(elements);
48+
}
49+
}

0 commit comments

Comments
 (0)