Skip to content

Commit 90e5019

Browse files
authored
Add interpolator flag for parameters (#293)
* Add sql interpolator behind a flag
* fmt
* fmt
* Address comments
* Remove debug statement
1 parent 978878b commit 90e5019

File tree

11 files changed

+315
-48
lines changed

11 files changed

+315
-48
lines changed

src/main/java/com/databricks/jdbc/client/impl/helper/MetadataResultConstants.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ public class MetadataResultConstants {
251251
IS_AUTO_INCREMENT_COLUMN,
252252
USER_DATA_TYPE_COLUMN,
253253
IS_GENERATED_COLUMN);
254-
public static String NULL_STRING = "null";
254+
public static String NULL_STRING = "NULL";
255255

256256
public static List<ResultColumn> TYPE_INFO_COLUMNS =
257257
List.of(
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
package com.databricks.jdbc.commons.util;

import static com.databricks.jdbc.client.impl.helper.MetadataResultConstants.NULL_STRING;

import com.databricks.jdbc.core.DatabricksValidationException;
import com.databricks.jdbc.core.ImmutableSqlParameter;
import java.util.Map;

/**
 * Utility that inlines bound parameter values directly into a SQL string, replacing each '?'
 * placeholder with a SQL literal. Used when the connection is configured to support more
 * parameters than the server-side parameter API accepts.
 *
 * <p>NOTE(review): placeholder detection is purely lexical — a '?' inside a quoted string literal
 * or comment is still counted and substituted. Callers must ensure '?' appears only as a
 * parameter marker in the SQL they pass in.
 */
public class SQLInterpolator {

  private SQLInterpolator() {
    // Static utility; not instantiable.
  }

  /** Escapes single quotes per the SQL standard by doubling them (' -> ''); null passes through. */
  private static String escapeApostrophes(String input) {
    if (input == null) return null;
    return input.replace("'", "''");
  }

  /**
   * Renders one parameter as SQL literal text: {@code NULL} for absent/null values, a quoted and
   * escaped literal for strings, and {@code toString()} for everything else (numbers, booleans).
   */
  private static String formatObject(ImmutableSqlParameter object) {
    if (object == null || object.value() == null) {
      return NULL_STRING;
    } else if (object.value() instanceof String) {
      return "'" + escapeApostrophes((String) object.value()) + "'";
    } else {
      return object.value().toString();
    }
  }

  /** Counts '?' placeholder characters in the given SQL text. */
  private static int countPlaceholders(String sql) {
    int count = 0;
    for (char c : sql.toCharArray()) {
      if (c == '?') {
        count++;
      }
    }
    return count;
  }

  /**
   * Interpolates the given SQL string by replacing placeholders with the provided parameters.
   *
   * <p>This method splits the SQL string by placeholders (question marks) and replaces each
   * placeholder with the corresponding parameter from the provided map. The map keys are 1-based
   * indexes, aligning with the SQL parameter positions.
   *
   * @param sql the SQL string containing placeholders ('?') to be replaced.
   * @param params a map of parameters where the key is the 1-based index of the placeholder in the
   *     SQL string, and the value is the corresponding {@link ImmutableSqlParameter}.
   * @return the interpolated SQL string with placeholders replaced by the corresponding parameters.
   * @throws DatabricksValidationException if the number of placeholders in the SQL string does not
   *     match the number of parameters provided in the map.
   */
  public static String interpolateSQL(String sql, Map<Integer, ImmutableSqlParameter> params)
      throws DatabricksValidationException {
    // Split with limit -1 so trailing empty segments are preserved. With the default limit,
    // String#split discards trailing empty strings, so SQL ending in '?' (e.g. "VALUES (?, ?)"
    // collapsed to "...??" or a bare "?") would silently lose its final placeholder(s).
    String[] parts = sql.split("\\?", -1);
    if (countPlaceholders(sql) != params.size()) {
      throw new DatabricksValidationException(
          "Parameter count does not match. Provide equal number of parameters as placeholders. SQL "
              + sql);
    }
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < parts.length; i++) {
      sb.append(parts[i]);
      if (i < params.size()) {
        sb.append(formatObject(params.get(i + 1))); // because we have 1 based index in params
      }
    }
    return sb.toString();
  }
}

src/main/java/com/databricks/jdbc/core/DatabricksPreparedStatement.java

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.databricks.jdbc.core;
22

3+
import static com.databricks.jdbc.commons.util.SQLInterpolator.interpolateSQL;
34
import static com.databricks.jdbc.core.DatabricksTypeUtil.*;
45
import static com.databricks.jdbc.driver.DatabricksJdbcConstants.*;
56

@@ -21,12 +22,15 @@
2122
public class DatabricksPreparedStatement extends DatabricksStatement implements PreparedStatement {
2223
private final String sql;
2324
private final DatabricksParameterMetaData databricksParameterMetaData;
25+
private final boolean supportManyParameters;
2426

2527
private final int CHUNK_SIZE = 8192;
2628

2729
public DatabricksPreparedStatement(DatabricksConnection connection, String sql) {
2830
super(connection);
2931
this.sql = sql;
32+
this.supportManyParameters =
33+
connection.getSession().getConnectionContext().supportManyParameters();
3034
this.databricksParameterMetaData = new DatabricksParameterMetaData();
3135
}
3236

@@ -72,17 +76,14 @@ private byte[] readByteStream(InputStream x, int length) throws SQLException {
7276

7377
@Override
7478
public ResultSet executeQuery() throws SQLException {
75-
7679
LoggingUtil.log(LogLevel.DEBUG, "public ResultSet executeQuery()");
77-
return executeInternal(
78-
sql, this.databricksParameterMetaData.getParameterBindings(), StatementType.QUERY);
80+
return interpolateIfRequiredAndExecute(StatementType.QUERY);
7981
}
8082

8183
@Override
8284
public int executeUpdate() throws SQLException {
8385
LoggingUtil.log(LogLevel.DEBUG, "public int executeUpdate()");
84-
executeInternal(
85-
sql, this.databricksParameterMetaData.getParameterBindings(), StatementType.UPDATE);
86+
interpolateIfRequiredAndExecute(StatementType.UPDATE);
8687
return (int) resultSet.getUpdateCount();
8788
}
8889

@@ -263,7 +264,7 @@ private void setObject(int parameterIndex, Object x, String databricksType) {
263264
public boolean execute() throws SQLException {
264265
LoggingUtil.log(LogLevel.DEBUG, "public boolean execute()");
265266
checkIfClosed();
266-
executeInternal(sql, databricksParameterMetaData.getParameterBindings(), StatementType.SQL);
267+
interpolateIfRequiredAndExecute(StatementType.SQL);
267268
return shouldReturnResultSet(sql);
268269
}
269270

@@ -646,4 +647,17 @@ public boolean execute(String sql, int[] columnIndexes) throws SQLException {
646647
public boolean execute(String sql, String[] columnNames) throws SQLException {
647648
throw new DatabricksSQLException("Method not supported in PreparedStatement");
648649
}
650+
651+
private DatabricksResultSet interpolateIfRequiredAndExecute(StatementType statementType)
652+
throws SQLException {
653+
String interpolatedSql =
654+
this.supportManyParameters
655+
? interpolateSQL(sql, this.databricksParameterMetaData.getParameterBindings())
656+
: sql;
657+
Map<Integer, ImmutableSqlParameter> paramMap =
658+
this.supportManyParameters
659+
? new HashMap<>()
660+
: this.databricksParameterMetaData.getParameterBindings();
661+
return executeInternal(interpolatedSql, paramMap, statementType);
662+
}
649663
}

src/main/java/com/databricks/jdbc/core/DatabricksResultSetMetaData.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.databricks.jdbc.core;
22

3+
import static com.databricks.jdbc.client.impl.helper.MetadataResultConstants.NULL_STRING;
34
import static com.databricks.jdbc.client.impl.thrift.commons.DatabricksThriftHelper.getTypeFromTypeDesc;
45
import static com.databricks.jdbc.driver.DatabricksJdbcConstants.VOLUME_OPERATION_STATUS_COLUMN_NAME;
56

@@ -29,7 +30,6 @@ public class DatabricksResultSetMetaData implements ResultSetMetaData {
2930
private final long totalRows;
3031
private Long chunkCount;
3132
private static final String DEFAULT_CATALOGUE_NAME = "Spark";
32-
private static final String NULL_STRING = "null";
3333

3434
// TODO: Add handling for Arrow stream results
3535

src/main/java/com/databricks/jdbc/core/DatabricksStatement.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ DatabricksResultSet executeInternal(
431431
String stackTraceMessage =
432432
format(
433433
"DatabricksResultSet executeInternal(String sql = %s,Map<Integer, ImmutableSqlParameter> params = {%s}, StatementType statementType = {%s})",
434-
sql, params.toString(), statementType.toString());
434+
sql, params, statementType);
435435
LoggingUtil.log(LogLevel.DEBUG, stackTraceMessage);
436436
CompletableFuture<DatabricksResultSet> futureResultSet =
437437
getFutureResult(sql, params, statementType);

src/main/java/com/databricks/jdbc/driver/DatabricksConnectionContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,4 +531,9 @@ public int getIdleHttpConnectionExpiry() {
531531
return Integer.parseInt(
532532
getParameter(IDLE_HTTP_CONNECTION_EXPIRY, DEFAULT_IDLE_HTTP_CONNECTION_EXPIRY));
533533
}
534+
535+
@Override
536+
public boolean supportManyParameters() {
537+
return getParameter(SUPPORT_MANY_PARAMETERS, "0").equals("1");
538+
}
534539
}

src/main/java/com/databricks/jdbc/driver/DatabricksJdbcConstants.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ public enum FakeServiceType {
216216
static final String RATE_LIMIT_RETRY_TIMEOUT = "RateLimitRetryTimeout";
217217
public static final String DEFAULT_RATE_LIMIT_RETRY_TIMEOUT = "120";
218218
static final String IDLE_HTTP_CONNECTION_EXPIRY = "IdleHttpConnectionExpiry";
219+
static final String SUPPORT_MANY_PARAMETERS = "supportManyParameters";
219220
public static final String DEFAULT_IDLE_HTTP_CONNECTION_EXPIRY = "60";
220221
public static final String CLOUD_FETCH_THREAD_POOL_SIZE = "cloudFetchThreadPoolSize";
221222
public static final int CLOUD_FETCH_THREAD_POOL_SIZE_DEFAULT = 16;

src/main/java/com/databricks/jdbc/driver/IDatabricksConnectionContext.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,6 @@ public static AuthMech parseAuthMech(String authMech) {
144144
int getRateLimitRetryTimeout();
145145

146146
int getIdleHttpConnectionExpiry();
147+
148+
boolean supportManyParameters();
147149
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package com.databricks.jdbc.commons;
2+
3+
import static com.databricks.jdbc.TestConstants.TEST_STRING;
4+
import static com.databricks.jdbc.core.DatabricksPreparedStatementTest.getSqlParam;
5+
import static org.junit.jupiter.api.Assertions.assertEquals;
6+
import static org.junit.jupiter.api.Assertions.assertThrows;
7+
8+
import com.databricks.jdbc.commons.util.SQLInterpolator;
9+
import com.databricks.jdbc.core.DatabricksTypeUtil;
10+
import com.databricks.jdbc.core.DatabricksValidationException;
11+
import com.databricks.jdbc.core.ImmutableSqlParameter;
12+
import java.util.HashMap;
13+
import java.util.Map;
14+
import org.junit.jupiter.api.Test;
15+
16+
public class SQLInterpolatorTest {
17+
18+
@Test
19+
public void testInterpolateSQLWithStrings() throws DatabricksValidationException {
20+
String sql = "SELECT * FROM users WHERE name = ? AND city = ?";
21+
Map<Integer, ImmutableSqlParameter> params = new HashMap<>();
22+
params.put(1, getSqlParam(1, "Alice", DatabricksTypeUtil.STRING));
23+
params.put(2, getSqlParam(2, "Wonderland", DatabricksTypeUtil.STRING));
24+
String expected = "SELECT * FROM users WHERE name = 'Alice' AND city = 'Wonderland'";
25+
assertEquals(expected, SQLInterpolator.interpolateSQL(sql, params));
26+
}
27+
28+
@Test
29+
public void testInterpolateSQLWithMixedTypes() throws DatabricksValidationException {
30+
String sql = "INSERT INTO sales (id, amount, active) VALUES (?, ?, ?)";
31+
Map<Integer, ImmutableSqlParameter> params = new HashMap<>();
32+
params.put(1, getSqlParam(1, 101, DatabricksTypeUtil.INT));
33+
params.put(2, getSqlParam(2, 19.95, DatabricksTypeUtil.FLOAT));
34+
params.put(3, getSqlParam(3, true, DatabricksTypeUtil.BOOLEAN));
35+
String expected = "INSERT INTO sales (id, amount, active) VALUES (101, 19.95, true)";
36+
assertEquals(expected, SQLInterpolator.interpolateSQL(sql, params));
37+
}
38+
39+
@Test
40+
public void testInterpolateSQLWithNullValues() throws DatabricksValidationException {
41+
String sql = "UPDATE products SET price = ? WHERE id = ?";
42+
Map<Integer, ImmutableSqlParameter> params = new HashMap<>();
43+
params.put(1, getSqlParam(1, null, DatabricksTypeUtil.NULL));
44+
params.put(2, getSqlParam(2, 200, DatabricksTypeUtil.INT));
45+
String expected = "UPDATE products SET price = NULL WHERE id = 200";
46+
assertEquals(expected, SQLInterpolator.interpolateSQL(sql, params));
47+
}
48+
49+
@Test
50+
public void testParameterMismatch() {
51+
String sql = "DELETE FROM log WHERE date = ?";
52+
Map<Integer, ImmutableSqlParameter> params = new HashMap<>(); // no parameters added
53+
assertThrows(
54+
DatabricksValidationException.class,
55+
() -> {
56+
SQLInterpolator.interpolateSQL(sql, params);
57+
});
58+
}
59+
60+
@Test
61+
public void testExtraParameters() {
62+
String sql = "SELECT * FROM clients WHERE client_id = ?";
63+
Map<Integer, ImmutableSqlParameter> params = new HashMap<>();
64+
params.put(1, getSqlParam(1, 300, DatabricksTypeUtil.INT));
65+
params.put(2, getSqlParam(2, TEST_STRING, DatabricksTypeUtil.STRING)); // extra parameter
66+
assertThrows(
67+
DatabricksValidationException.class,
68+
() -> {
69+
SQLInterpolator.interpolateSQL(sql, params);
70+
});
71+
}
72+
73+
@Test
74+
public void testEscapedValues() throws DatabricksValidationException {
75+
String sql = "UPDATE products SET price = ? WHERE id = ?";
76+
Map<Integer, ImmutableSqlParameter> params = new HashMap<>();
77+
params.put(1, getSqlParam(1, "O'Reilly", DatabricksTypeUtil.STRING));
78+
params.put(2, getSqlParam(2, 200, DatabricksTypeUtil.INT));
79+
String expected = "UPDATE products SET price = 'O''Reilly' WHERE id = 200";
80+
assertEquals(expected, SQLInterpolator.interpolateSQL(sql, params));
81+
}
82+
}

0 commit comments

Comments
 (0)