diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java index 4e3a20d1..0e5045e9 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java @@ -238,6 +238,11 @@ public void close() { } } + @Unstable + public void cancel(String queryId) { + getExecutor().cancel(queryId); + } + @Override public boolean isClosed() { return closed.get(); diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java index 90140f71..c07df2b7 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.protobuf.ByteString; +import com.salesforce.datacloud.jdbc.core.listener.AsyncQueryStatusListener; import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; import com.salesforce.datacloud.jdbc.util.ArrowUtils; import com.salesforce.datacloud.jdbc.util.Constants; @@ -53,6 +54,7 @@ import java.util.Calendar; import java.util.Map; import java.util.TimeZone; +import lombok.SneakyThrows; import lombok.experimental.UtilityClass; import lombok.extern.slf4j.Slf4j; import lombok.val; @@ -86,18 +88,19 @@ private void setParameter(int parameterIndex, int sqlType, T value) throws S @Override public ResultSet executeQuery(String sql) throws SQLException { - this.sql = sql; - return executeQuery(); + throw new DataCloudJDBCException( + "Per the JDBC specification this method cannot be called on a PreparedStatement, use DataCloudPreparedStatement::executeQuery() instead."); } @Override public boolean execute(String sql) throws SQLException { - resultSet = executeQuery(sql); - return true; + throw new DataCloudJDBCException( + "Per the JDBC specification this method cannot be called on a PreparedStatement, use DataCloudPreparedStatement::execute() instead."); } @Override - public ResultSet executeQuery() throws SQLException { + @SneakyThrows + protected HyperGrpcClientExecutor getQueryExecutor() { final byte[] encodedRow; try { encodedRow = ArrowUtils.toArrowByteArray(parameterManager.getParameters(), calendar); @@ -105,14 +108,26 @@ public ResultSet executeQuery() throws SQLException { throw new DataCloudJDBCException("Failed to encode parameters on prepared statement", e); } - val queryParamBuilder = QueryParam.newBuilder() + val preparedQueryParams = QueryParam.newBuilder() .setParamStyle(QueryParam.ParameterStyle.QUESTION_MARK) .setArrowParameters(QueryParameterArrow.newBuilder() .setData(ByteString.copyFrom(encodedRow)) .build()) .build(); - val client = getQueryExecutor(queryParamBuilder); + return getQueryExecutor(preparedQueryParams); + } + + @Override + public boolean execute() throws SQLException { + val client = getQueryExecutor(); + listener = AsyncQueryStatusListener.of(sql, client); + return true; + } + + @Override + public ResultSet executeQuery() throws SQLException { + val client = getQueryExecutor(); val timeout = Duration.ofSeconds(getQueryTimeout()); val useSync = optional(this.dataCloudConnection.getProperties(), Constants.FORCE_SYNC) @@ -246,12 +261,6 @@ public void setObject(int parameterIndex, Object x) throws SQLException { } } - @Override - public boolean execute() throws SQLException { - resultSet = executeQuery(); - return true; - } - @Override public void addBatch() throws SQLException { throw new DataCloudJDBCException(BATCH_EXECUTION_IS_NOT_SUPPORTED, SqlErrorCodes.FEATURE_NOT_SUPPORTED); diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudResultSet.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudResultSet.java index 98e62b23..5c6d7d44 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudResultSet.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudResultSet.java @@ -15,6 +15,7 @@ */ package com.salesforce.datacloud.jdbc.core; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; import java.sql.ResultSet; public interface DataCloudResultSet extends ResultSet { @@ -22,5 +23,5 @@ public interface DataCloudResultSet extends ResultSet { String getStatus(); - boolean isReady(); + boolean isReady() throws DataCloudJDBCException; } diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java index 1b734926..806f16cd 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java @@ -38,7 +38,7 @@ import salesforce.cdp.hyperdb.v1.QueryParam; @Slf4j -public class DataCloudStatement implements Statement { +public class DataCloudStatement implements Statement, AutoCloseable { protected ResultSet resultSet; protected static final String NOT_SUPPORTED_IN_DATACLOUD_QUERY = "Write is not supported in Data Cloud query"; @@ -96,14 +96,15 @@ public String getQueryId() throws SQLException { return listener.getQueryId(); } - public boolean isReady() { + public boolean isReady() throws DataCloudJDBCException { return listener.isReady(); } @Override public boolean execute(String sql) throws SQLException { log.debug("Entering execute"); - this.resultSet = executeQuery(sql); + val client = getQueryExecutor(); + listener = AsyncQueryStatusListener.of(sql, client); return true; } @@ -211,7 +212,16 @@ public void setQueryTimeout(int seconds) { } @Override - public void cancel() {} + public void cancel() throws SQLException { + if (listener == null) { + log.warn("There was no in-progress query registered with this statement to cancel"); + return; + } + + val queryId = getQueryId(); + val executor = dataCloudConnection.getExecutor(); + executor.cancel(queryId); + } @Override public SQLWarning getWarnings() { diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java index 13985d2e..3d8fea1f 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java @@ -41,6 +41,7 @@ import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import salesforce.cdp.hyperdb.v1.CancelQueryParam; import salesforce.cdp.hyperdb.v1.ExecuteQueryResponse; import salesforce.cdp.hyperdb.v1.HyperServiceGrpc; import salesforce.cdp.hyperdb.v1.OutputFormat; @@ -135,11 +136,6 @@ public Iterator getQueryInfo(String queryId) { return getStub(queryId).getQueryInfo(param); } - public Iterator getQueryInfoStreaming(String queryId) { - val param = getQueryInfoParamStreaming(queryId); - return getStub(queryId).getQueryInfo(param); - } - @Unstable public Stream getQueryStatus(String queryId) { val iterator = getQueryInfo(queryId); @@ -149,6 +145,12 @@ public Stream getQueryStatus(String queryId) { .map(Optional::get); } + public void cancel(String queryId) { + val request = CancelQueryParam.newBuilder().setQueryId(queryId).build(); + val stub = getStub(queryId); + stub.cancelQuery(request); + } + public Iterator getQueryResult(String queryId, long offset, long limit, boolean omitSchema) { val rowRange = ResultRange.newBuilder().setRowOffset(offset).setRowLimit(limit).setByteLimit(1024); @@ -196,10 +198,6 @@ private QueryResultParam getQueryResultParam(String queryId, long chunkId, boole } private QueryInfoParam getQueryInfoParam(String queryId) { - return QueryInfoParam.newBuilder().setQueryId(queryId).build(); - } - - private QueryInfoParam getQueryInfoParamStreaming(String queryId) { return QueryInfoParam.newBuilder() .setQueryId(queryId) .setStreaming(true) diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java index 337b0fe9..408b4a13 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java @@ -16,6 +16,7 @@ package com.salesforce.datacloud.jdbc.core; import com.salesforce.datacloud.jdbc.core.listener.QueryStatusListener; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; import com.salesforce.datacloud.jdbc.exception.QueryExceptionHandler; import com.salesforce.datacloud.jdbc.util.ArrowUtils; import com.salesforce.datacloud.jdbc.util.StreamUtilities; @@ -115,7 +116,7 @@ public String getStatus() { } @Override - public boolean isReady() { + public boolean isReady() throws DataCloudJDBCException { return listener.isReady(); } diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListener.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListener.java index 9d4bf8f1..796151a4 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListener.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListener.java @@ -18,6 +18,7 @@ import static com.salesforce.datacloud.jdbc.util.ThrowingSupplier.rethrowLongSupplier; import static com.salesforce.datacloud.jdbc.util.ThrowingSupplier.rethrowSupplier; +import com.salesforce.datacloud.jdbc.core.DataCloudQueryStatus; import com.salesforce.datacloud.jdbc.core.DataCloudResultSet; import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; import com.salesforce.datacloud.jdbc.core.StreamingResultSet; @@ -30,6 +31,7 @@ import java.time.Instant; import java.util.Iterator; import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import java.util.function.UnaryOperator; import java.util.stream.LongStream; @@ -37,12 +39,10 @@ import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; -import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import lombok.val; import salesforce.cdp.hyperdb.v1.ExecuteQueryResponse; import salesforce.cdp.hyperdb.v1.QueryResult; -import salesforce.cdp.hyperdb.v1.QueryStatus; @Slf4j @AllArgsConstructor(access = AccessLevel.PRIVATE) @@ -61,8 +61,6 @@ public class AdaptiveQueryStatusListener implements QueryStatusListener { private final AdaptiveQueryStatusPoller headPoller; - private final AsyncQueryStatusPoller tailPoller; - public static AdaptiveQueryStatusListener of(String query, HyperGrpcClientExecutor client, Duration timeout) throws SQLException { try { @@ -70,13 +68,7 @@ public static AdaptiveQueryStatusListener of(String query, HyperGrpcClientExecut val queryId = response.next().getQueryInfo().getQueryStatus().getQueryId(); return new AdaptiveQueryStatusListener( - queryId, - query, - client, - timeout, - response, - new AdaptiveQueryStatusPoller(queryId, client), - new AsyncQueryStatusPoller(queryId, client)); + queryId, query, client, timeout, response, new AdaptiveQueryStatusPoller(queryId, client)); } catch (StatusRuntimeException ex) { throw QueryExceptionHandler.createQueryException(query, ex); } @@ -89,12 +81,11 @@ public boolean isReady() { @Override public String getStatus() { - val poller = headPoller.pollChunkCount() > 1 ? tailPoller : headPoller; - return Optional.of(poller) - .map(QueryStatusPoller::pollQueryStatus) - .map(QueryStatus::getCompletionStatus) + return client.getQueryStatus(queryId) + .map(DataCloudQueryStatus::getCompletionStatus) .map(Enum::name) - .orElse(QueryStatus.CompletionStatus.RUNNING_OR_UNSPECIFIED.name()); + .findFirst() + .orElse("UNKNOWN"); } @Override @@ -126,8 +117,8 @@ private Stream> infiniteChunks() { private long getChunkLimit() throws SQLException { if (headPoller.pollChunkCount() > 1) { - blockUntilReady(tailPoller, timeout); - return tailPoller.pollChunkCount() - 1; + val status = blockUntilReady(timeout); + return status.getChunkCount() - 1; } return 0; @@ -146,23 +137,19 @@ private Stream tryGetQueryResult(long chunkId) { .orElse(Stream.empty()); } - @SneakyThrows - private void blockUntilReady(QueryStatusPoller poller, Duration timeout) { - val end = Instant.now().plus(timeout); - int millis = 1000; - while (!poller.pollIsReady() && Instant.now().isBefore(end)) { - log.info( - "Waiting for additional query results. queryId={}, timeout={}, sleep={}", - queryId, - timeout, - Duration.ofSeconds(millis)); - - Thread.sleep(millis); - millis *= 2; + private DataCloudQueryStatus blockUntilReady(Duration timeout) throws DataCloudJDBCException { + val deadline = Instant.now().plus(timeout); + val last = new AtomicReference(); + + while (Instant.now().isBefore(deadline)) { + val isReady = client.getQueryStatus(queryId) + .peek(last::set) + .anyMatch(t -> t.isResultProduced() || t.isExecutionFinished()); + if (isReady) { + return last.get(); + } } - if (!tailPoller.pollIsReady()) { - throw new DataCloudJDBCException(BEFORE_READY + ". queryId=" + queryId + ", timeout=" + timeout); - } + throw new DataCloudJDBCException(BEFORE_READY + ". queryId=" + queryId + ", timeout=" + timeout); } } diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPoller.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPoller.java index 7c0dfb02..65adc888 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPoller.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPoller.java @@ -44,7 +44,7 @@ public class AdaptiveQueryStatusPoller implements QueryStatusPoller { @SneakyThrows private Iterator getQueryInfoStreaming() { try { - return client.getQueryInfoStreaming(queryId); + return client.getQueryInfo(queryId); } catch (StatusRuntimeException ex) { throw QueryExceptionHandler.createException("Failed when getting query status", ex); } diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListener.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListener.java index 3448fd98..9a36e957 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListener.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListener.java @@ -15,6 +15,7 @@ */ package com.salesforce.datacloud.jdbc.core.listener; +import com.salesforce.datacloud.jdbc.core.DataCloudQueryStatus; import com.salesforce.datacloud.jdbc.core.DataCloudResultSet; import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; import com.salesforce.datacloud.jdbc.core.StreamingResultSet; @@ -23,7 +24,6 @@ import com.salesforce.datacloud.jdbc.util.StreamUtilities; import io.grpc.StatusRuntimeException; import java.sql.SQLException; -import java.util.Optional; import java.util.function.UnaryOperator; import java.util.stream.LongStream; import java.util.stream.Stream; @@ -34,7 +34,6 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import salesforce.cdp.hyperdb.v1.QueryResult; -import salesforce.cdp.hyperdb.v1.QueryStatus; @Slf4j @Builder(access = AccessLevel.PRIVATE) @@ -66,17 +65,21 @@ public static AsyncQueryStatusListener of(String query, HyperGrpcClientExecutor } @Override - public boolean isReady() { - return getPoller().pollIsReady(); + public boolean isReady() throws DataCloudJDBCException { + try { + return client.getQueryStatus(queryId).anyMatch(t -> t.isResultProduced() || t.isExecutionFinished()); + } catch (StatusRuntimeException ex) { + throw QueryExceptionHandler.createQueryException(query, ex); + } } @Override public String getStatus() { - return Optional.of(getPoller()) - .map(AsyncQueryStatusPoller::pollQueryStatus) - .map(QueryStatus::getCompletionStatus) + return client.getQueryStatus(queryId) + .map(DataCloudQueryStatus::getCompletionStatus) .map(Enum::name) - .orElse(null); + .findFirst() + .orElse("UNKNOWN"); } @Override diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPoller.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPoller.java index 378b7a29..aa4dfb79 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPoller.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPoller.java @@ -49,6 +49,7 @@ private Optional getQueryInfo() { } private Optional fetchQueryStatus() { + val status = getQueryInfo().map(QueryInfo::getQueryStatus); if (status.isPresent()) { this.lastStatus.set(status.get()); diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java index ac6773eb..3a8e82c8 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java @@ -16,6 +16,7 @@ package com.salesforce.datacloud.jdbc.core.listener; import com.salesforce.datacloud.jdbc.core.DataCloudResultSet; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; import java.sql.SQLException; import java.util.stream.Stream; import salesforce.cdp.hyperdb.v1.QueryResult; @@ -26,7 +27,7 @@ public interface QueryStatusListener { String getQuery(); - boolean isReady(); + boolean isReady() throws DataCloudJDBCException; String getStatus(); diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusPoller.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusPoller.java index 53ed24b7..cb6efc91 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusPoller.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusPoller.java @@ -19,6 +19,10 @@ import java.util.Optional; import salesforce.cdp.hyperdb.v1.QueryStatus; +/** + * Marked as deprecated since the streaming form of GetQueryInfo makes this construct mostly unnecessary. + */ +@Deprecated public interface QueryStatusPoller { QueryStatus pollQueryStatus(); diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java index 79ef373b..f7d97777 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java @@ -91,32 +91,25 @@ public void beforeEach() { mockParameterManager = mock(ParameterManager.class); - preparedStatement = new DataCloudPreparedStatement(mockConnection, mockParameterManager); + preparedStatement = new DataCloudPreparedStatement(mockConnection, "SELECT * FROM table", mockParameterManager); } @Test @SneakyThrows public void testExecuteQuery() { - setupHyperGrpcClientWithMockedResultSet("query id", ImmutableList.of()); - ResultSet resultSet = preparedStatement.executeQuery("SELECT * FROM table"); - assertNotNull(resultSet); - assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(3); - assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("id"); - assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("name"); - assertThat(resultSet.getMetaData().getColumnName(3)).isEqualTo("grade"); + assertThatThrownBy(() -> preparedStatement.executeQuery("SELECT * FROM table")) + .isInstanceOf(DataCloudJDBCException.class) + .hasMessage( + "Per the JDBC specification this method cannot be called on a PreparedStatement, use DataCloudPreparedStatement::executeQuery() instead."); } @Test @SneakyThrows public void testExecute() { - setupHyperGrpcClientWithMockedResultSet("query id", ImmutableList.of()); - preparedStatement.execute("SELECT * FROM table"); - ResultSet resultSet = preparedStatement.getResultSet(); - assertNotNull(resultSet); - assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(3); - assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("id"); - assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("name"); - assertThat(resultSet.getMetaData().getColumnName(3)).isEqualTo("grade"); + assertThatThrownBy(() -> preparedStatement.execute("SELECT * FROM table")) + .isInstanceOf(DataCloudJDBCException.class) + .hasMessage( + "Per the JDBC specification this method cannot be called on a PreparedStatement, use DataCloudPreparedStatement::execute() instead."); } @SneakyThrows @@ -127,13 +120,13 @@ public void testForceSyncOverride(boolean forceSync) { p.setProperty(Constants.FORCE_SYNC, Boolean.toString(forceSync)); when(mockConnection.getProperties()).thenReturn(p); - val statement = new DataCloudPreparedStatement(mockConnection, mockParameterManager); + val statement = new DataCloudPreparedStatement(mockConnection, "SELECT * FROM table", mockParameterManager); setupHyperGrpcClientWithMockedResultSet( "query id", ImmutableList.of(), forceSync ? QueryParam.TransferMode.SYNC : QueryParam.TransferMode.ADAPTIVE); - ResultSet response = statement.executeQuery("SELECT * FROM table"); + ResultSet response = statement.executeQuery(); AssertionsForClassTypes.assertThat(statement.isReady()).isTrue(); assertNotNull(response); AssertionsForClassTypes.assertThat(response.getMetaData().getColumnCount()) @@ -147,7 +140,7 @@ public void testExecuteQueryWithSqlException() { GrpcMock.stubFor(GrpcMock.unaryMethod(HyperServiceGrpc.getExecuteQueryMethod()) .willReturn(GrpcMock.exception(fakeException))); - assertThrows(DataCloudJDBCException.class, () -> preparedStatement.executeQuery("SELECT * FROM table")); + assertThrows(DataCloudJDBCException.class, () -> preparedStatement.executeQuery()); } @Test diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementFunctionalTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementFunctionalTest.java index 3d43d23e..eb99f1c8 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementFunctionalTest.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementFunctionalTest.java @@ -15,17 +15,89 @@ */ package com.salesforce.datacloud.jdbc.core; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.hyper.HyperServerConfig; import com.salesforce.datacloud.jdbc.hyper.HyperTestBase; import java.sql.ResultSet; +import java.util.stream.Collectors; import lombok.SneakyThrows; import lombok.val; import org.junit.jupiter.api.Test; public class DataCloudStatementFunctionalTest extends HyperTestBase { + private static final HyperServerConfig configWithSleep = + HyperServerConfig.builder().experimentalPgSleep(true).build(); + + @Test + @SneakyThrows + public void canCancelStatementQuery() { + try (val server = configWithSleep.start(); + val statement = server.getConnection().createStatement().unwrap(DataCloudStatement.class)) { + statement.execute("select pg_sleep(5000000);"); + val client = server.getRawClient(); + val queryId = statement.getQueryId(); + val a = client.getQueryStatus(queryId).findFirst().get(); + assertThat(a.getCompletionStatus()).isEqualTo(DataCloudQueryStatus.CompletionStatus.RUNNING); + + statement.cancel(); + + assertThatThrownBy(() -> client.getQueryStatus(queryId).collect(Collectors.toList())) + .hasMessage("FAILED_PRECONDITION: canceled"); + } + } + + @Test + @SneakyThrows + public void canCancelPreparedStatementQuery() { + try (val server = configWithSleep.start(); + val statement = server.getConnection() + .prepareStatement("select pg_sleep(?)") + .unwrap(DataCloudPreparedStatement.class)) { + statement.setInt(1, 5000000); + statement.execute(); + val client = server.getRawClient(); + val queryId = statement.getQueryId(); + val a = client.getQueryStatus(queryId).findFirst().get(); + assertThat(a.getCompletionStatus()).isEqualTo(DataCloudQueryStatus.CompletionStatus.RUNNING); + + statement.cancel(); + + assertThatThrownBy(() -> client.getQueryStatus(queryId).collect(Collectors.toList())) + .hasMessage("FAILED_PRECONDITION: canceled"); + } + } + + @Test + @SneakyThrows + public void canCancelAnotherQueryById() { + try (val server = configWithSleep.start(); + val statement = server.getConnection().createStatement().unwrap(DataCloudStatement.class); + val cancel = server.getConnection().unwrap(DataCloudConnection.class)) { + + statement.execute("select pg_sleep(5000000);"); + val queryId = statement.getQueryId(); + + val client = server.getRawClient(); + + val a = client.getQueryStatus(queryId).findFirst().get(); + assertThat(a.getCompletionStatus()).isEqualTo(DataCloudQueryStatus.CompletionStatus.RUNNING); + + cancel.cancel(queryId); + + assertThatThrownBy(() -> client.getQueryStatus(queryId).collect(Collectors.toList())) + .hasMessage("FAILED_PRECONDITION: canceled"); + } + } + + @Test + @SneakyThrows + public void noErrorOnCancelUnknownQuery() { + assertWithConnection(connection -> connection.cancel("nonsense query id")); + } + @Test @SneakyThrows public void forwardAndReadOnly() { diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementTest.java index 6a32b25d..14691d9e 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementTest.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementTest.java @@ -119,14 +119,17 @@ public void testExecuteQuery() { @Test @SneakyThrows public void testExecute() { - setupHyperGrpcClientWithMockedResultSet("query id", ImmutableList.of()); - statement.execute("SELECT * FROM table"); - ResultSet response = statement.getResultSet(); - assertNotNull(response); - assertThat(response.getMetaData().getColumnCount()).isEqualTo(3); - assertThat(response.getMetaData().getColumnName(1)).isEqualTo("id"); - assertThat(response.getMetaData().getColumnName(2)).isEqualTo("name"); - assertThat(response.getMetaData().getColumnName(3)).isEqualTo("grade"); + try (val connection = getInterceptedClientConnection(); + val statement = connection.createStatement()) { + statement.execute( + "SELECT md5(random()::text) AS id, md5(random()::text) AS name, round((random() * 3 + 1)::numeric, 2) AS grade FROM generate_series(1, 3);"); + val response = statement.getResultSet(); + assertNotNull(response); + assertThat(response.getMetaData().getColumnCount()).isEqualTo(3); + assertThat(response.getMetaData().getColumnName(1)).isEqualTo("id"); + assertThat(response.getMetaData().getColumnName(2)).isEqualTo("name"); + assertThat(response.getMetaData().getColumnName(3)).isEqualTo("grade"); + } } @Test diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/HyperGrpcTestBase.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/HyperGrpcTestBase.java index 16b69e0a..ec766c88 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/HyperGrpcTestBase.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/HyperGrpcTestBase.java @@ -21,6 +21,7 @@ import com.salesforce.datacloud.jdbc.auth.AuthenticationSettings; import com.salesforce.datacloud.jdbc.auth.DataCloudToken; import com.salesforce.datacloud.jdbc.auth.TokenProcessor; +import com.salesforce.datacloud.jdbc.hyper.HyperTestBase; import com.salesforce.datacloud.jdbc.util.RealisticArrowGenerator; import io.grpc.inprocess.InProcessChannelBuilder; import java.io.IOException; @@ -46,7 +47,7 @@ import salesforce.cdp.hyperdb.v1.QueryStatus; @ExtendWith(InProcessGrpcMockExtension.class) -public class HyperGrpcTestBase { +public class HyperGrpcTestBase extends HyperTestBase { protected static HyperGrpcClientExecutor hyperGrpcClient; @@ -170,28 +171,6 @@ public void setupGetQueryResult( .willReturn(results)); } - public void setupAdaptiveInitialResults( - String sql, - String queryId, - int parts, - Integer chunks, - QueryStatus.CompletionStatus status, - List students) { - val results = IntStream.range(0, parts) - .mapToObj(i -> RealisticArrowGenerator.getMockedData(students)) - .flatMap(UnaryOperator.identity()) - .map(r -> ExecuteQueryResponse.newBuilder().setQueryResult(r).build()); - - val response = Stream.concat( - Stream.of(executeQueryResponse(queryId, null, null)), - Stream.concat(results, Stream.of(executeQueryResponse(queryId, status, chunks)))) - .collect(Collectors.toList()); - - GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getExecuteQueryMethod()) - .withRequest(req -> req.getQuery().equals(sql)) - .willReturn(response)); - } - public static ExecuteQueryResponse executeQueryResponseWithData(List students) { val result = RealisticArrowGenerator.getMockedData(students).findFirst().orElseThrow(RuntimeException::new); return ExecuteQueryResponse.newBuilder().setQueryResult(result).build(); diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/StreamingResultSetTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/StreamingResultSetTest.java index 16fdc4b9..69e8bb27 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/StreamingResultSetTest.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/StreamingResultSetTest.java @@ -43,7 +43,7 @@ private static Stream queryModes(int size) { inline("executeSyncQuery", DataCloudStatement::executeSyncQuery, size), inline("executeAdaptiveQuery", DataCloudStatement::executeAdaptiveQuery, size), deferred("executeAsyncQuery", DataCloudStatement::executeAsyncQuery, true, size), - deferred("execute", DataCloudStatement::execute, false, size), + deferred("execute", DataCloudStatement::execute, true, size), deferred("executeQuery", DataCloudStatement::executeQuery, false, size)); } diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListenerTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListenerTest.java index 3706640e..f49b7a99 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListenerTest.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListenerTest.java @@ -25,7 +25,6 @@ import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; import com.salesforce.datacloud.jdbc.util.RealisticArrowGenerator; import java.util.NoSuchElementException; -import java.util.Objects; import java.util.Properties; import java.util.Random; import java.util.UUID; @@ -38,7 +37,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; +import org.junit.jupiter.params.provider.CsvSource; import salesforce.cdp.hyperdb.v1.QueryParam; import salesforce.cdp.hyperdb.v1.QueryStatus; @@ -48,22 +47,14 @@ class AsyncQueryStatusListenerTest extends HyperGrpcTestBase { private final QueryParam.TransferMode mode = QueryParam.TransferMode.ASYNC; @ParameterizedTest - @ValueSource( - ints = { - QueryStatus.CompletionStatus.RUNNING_OR_UNSPECIFIED_VALUE, - QueryStatus.CompletionStatus.RESULTS_PRODUCED_VALUE, - QueryStatus.CompletionStatus.FINISHED_VALUE - }) - void itCanGetStatus(int value) { - val status = QueryStatus.CompletionStatus.forNumber(value); - + @CsvSource({"0, RUNNING", "1, RESULTS_PRODUCED", "2, FINISHED"}) + void itCanGetStatus(int value, String expected) { val queryId = UUID.randomUUID().toString(); setupExecuteQuery(queryId, query, mode); val listener = sut(query); - setupGetQueryInfo(queryId, status); - assertThat(listener.getStatus()) - .isEqualTo(Objects.requireNonNull(status).name()); + setupGetQueryInfo(queryId, QueryStatus.CompletionStatus.forNumber(value)); + assertThat(listener.getStatus()).isEqualTo(expected); } @Test diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerAssert.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerAssert.java index c7b17e5a..f45abd23 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerAssert.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerAssert.java @@ -15,6 +15,7 @@ */ package com.salesforce.datacloud.jdbc.core.listener; +import lombok.SneakyThrows; import org.assertj.core.api.AbstractObjectAssert; import org.assertj.core.util.Objects; @@ -98,6 +99,7 @@ public QueryStatusListenerAssert hasQueryId(String queryId) { * @return this assertion object. * @throws AssertionError - if the actual QueryStatusListener is not ready. */ + @SneakyThrows public QueryStatusListenerAssert isReady() { // check that actual QueryStatusListener we want to make assertions on is not null. isNotNull(); @@ -117,6 +119,7 @@ public QueryStatusListenerAssert isReady() { * @return this assertion object. * @throws AssertionError - if the actual QueryStatusListener is ready. */ + @SneakyThrows public QueryStatusListenerAssert isNotReady() { // check that actual QueryStatusListener we want to make assertions on is not null. isNotNull(); diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/partial/RowBasedTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/partial/RowBasedTest.java index b01aa41f..ed16b476 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/partial/RowBasedTest.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/partial/RowBasedTest.java @@ -44,11 +44,12 @@ import org.junit.jupiter.params.provider.MethodSource; @Slf4j -class RowBasedTest extends HyperTestBase { +public class RowBasedTest extends HyperTestBase { private List sut(String queryId, long offset, long limit, RowBased.Mode mode) { - val connection = getHyperQueryConnection(); - val resultSet = connection.getRowBasedResultSet(queryId, offset, limit, mode); - return toList(resultSet); + try (val connection = getHyperQueryConnection()) { + val resultSet = connection.getRowBasedResultSet(queryId, offset, limit, mode); + return toList(resultSet); + } } private static final int tinySize = 8; @@ -204,7 +205,7 @@ private static List rangeClosed(int start, int end) { return IntStream.rangeClosed(start, end).boxed().collect(Collectors.toList()); } - private static Stream toStream(DataCloudResultSet resultSet) { + public static Stream toStream(DataCloudResultSet resultSet) { val iterator = new Iterator() { @SneakyThrows @Override diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperServerConfig.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperServerConfig.java new file mode 100644 index 00000000..3f989be2 --- /dev/null +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperServerConfig.java @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.hyper; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Builder; +import lombok.Value; +import lombok.val; + +@Builder(toBuilder = true) +@Value +public class HyperServerConfig { + @Builder.Default + @JsonProperty("grpc-request-timeout") + String grpcRequestTimeoutSeconds = null; + + @Builder.Default + @JsonProperty("experimental_pg_sleep") + boolean experimentalPgSleep = false; + + @Override + public String toString() { + val mapper = new ObjectMapper(); + val map = mapper.convertValue(this, new TypeReference>() {}); + return map.entrySet().stream() + .filter(entry -> entry.getValue() != null) + .map(entry -> String.format("--%s=%s", entry.getKey().replace("_", "-"), entry.getValue())) + .collect(Collectors.joining(" ")); + } + + public HyperServerProcess start() { + return new HyperServerProcess(this.toBuilder()); + } +} diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperServerProcess.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperServerProcess.java index 928b913a..490b8ca5 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperServerProcess.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperServerProcess.java @@ -17,8 +17,19 @@ import static java.util.Objects.requireNonNull; -import java.io.*; +import com.google.common.collect.ImmutableMap; +import com.salesforce.datacloud.jdbc.core.DataCloudConnection; +import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; +import com.salesforce.datacloud.jdbc.interceptor.AuthorizationHeaderInterceptor; +import io.grpc.ManagedChannelBuilder; +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.nio.file.Paths; +import java.util.Map; +import java.util.Properties; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -31,19 +42,23 @@ import org.junit.jupiter.api.Assertions; @Slf4j -public class HyperServerProcess { +public class HyperServerProcess implements AutoCloseable { private static final Pattern PORT_PATTERN = Pattern.compile(".*gRPC listening on 127.0.0.1:([0-9]+).*"); private final Process hyperProcess; private final ExecutorService hyperMonitors; private Integer port; - @SneakyThrows public HyperServerProcess() { + this(HyperServerConfig.builder()); + } + + @SneakyThrows + public HyperServerProcess(HyperServerConfig.HyperServerConfigBuilder config) { log.info("starting hyperd, this might take a few seconds"); val executable = new File("../target/hyper/hyperd"); - val properties = Paths.get(requireNonNull(HyperTestBase.class.getResource("/hyper.yaml")) + val yaml = Paths.get(requireNonNull(HyperTestBase.class.getResource("/hyper.yaml")) .toURI()) .toFile(); @@ -52,9 +67,17 @@ public HyperServerProcess() { + executable.getAbsolutePath()); } - hyperProcess = new ProcessBuilder() - .command(executable.getAbsolutePath(), "--config", properties.getAbsolutePath(), "--no-password", "run") - .start(); + val builder = new ProcessBuilder() + .command( + executable.getAbsolutePath(), + config.build().toString(), + "--config", + yaml.getAbsolutePath(), + "--no-password", + "run"); + + log.info("hyper command: {}", builder.command()); + hyperProcess = builder.start(); // Wait until process is listening and extract port on which it is listening val latch = new CountDownLatch(1); @@ -74,18 +97,6 @@ public HyperServerProcess() { } } - @SneakyThrows - void shutdown() throws InterruptedException { - if (hyperProcess != null && hyperProcess.isAlive()) { - log.info("destroy hyper process"); - hyperProcess.destroy(); - hyperProcess.waitFor(); - } - - log.info("shutdown hyper monitors"); - hyperMonitors.shutdown(); - } - int getPort() { return port; } @@ -106,4 +117,41 @@ private static void logStream(InputStream inputStream, Consumer consumer log.error("Caught unexpected exception", e); } } + + @Override + public void close() throws Exception { + if (hyperProcess != null && hyperProcess.isAlive()) { + log.info("destroy hyper process"); + hyperProcess.destroy(); + hyperProcess.waitFor(); + } + + log.info("shutdown hyper monitors"); + hyperMonitors.shutdown(); + } + + public DataCloudConnection getConnection() { + return getConnection(ImmutableMap.of()); + } + + @SneakyThrows + public HyperGrpcClientExecutor getRawClient() { + val auth = AuthorizationHeaderInterceptor.of(new HyperTestBase.NoopTokenSupplier()); + ManagedChannelBuilder channel = ManagedChannelBuilder.forAddress("127.0.0.1", getPort()) + .usePlaintext() + .intercept(auth); + return HyperGrpcClientExecutor.of(channel, new Properties()); + } + + @SneakyThrows + public DataCloudConnection getConnection(Map connectionSettings) { + val properties = new Properties(); + properties.putAll(connectionSettings); + val auth = AuthorizationHeaderInterceptor.of(new HyperTestBase.NoopTokenSupplier()); + log.info("Creating connection to port {}", getPort()); + ManagedChannelBuilder channel = + ManagedChannelBuilder.forAddress("127.0.0.1", getPort()).usePlaintext(); + + return DataCloudConnection.fromTokenSupplier(auth, channel, properties); + } } diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperTestBase.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperTestBase.java index 2f9d7186..d7c91684 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperTestBase.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperTestBase.java @@ -21,21 +21,30 @@ import com.salesforce.datacloud.jdbc.core.DataCloudConnection; import com.salesforce.datacloud.jdbc.core.DataCloudStatement; import com.salesforce.datacloud.jdbc.interceptor.AuthorizationHeaderInterceptor; +import com.salesforce.datacloud.jdbc.interceptor.QueryIdHeaderInterceptor; import io.grpc.ManagedChannelBuilder; +import io.grpc.MethodDescriptor; +import io.grpc.inprocess.InProcessChannelBuilder; import java.sql.ResultSet; +import java.util.Iterator; import java.util.Map; import java.util.Properties; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.assertj.core.api.ThrowingConsumer; +import org.grpcmock.GrpcMock; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.Timeout; +import salesforce.cdp.hyperdb.v1.HyperServiceGrpc; +import salesforce.cdp.hyperdb.v1.QueryInfoParam; +import salesforce.cdp.hyperdb.v1.QueryResultParam; @Slf4j @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -74,24 +83,15 @@ public static DataCloudConnection getHyperQueryConnection() { return getHyperQueryConnection(ImmutableMap.of()); } - @SneakyThrows public static DataCloudConnection getHyperQueryConnection(Map connectionSettings) { - - val properties = new Properties(); - properties.putAll(connectionSettings); - val auth = AuthorizationHeaderInterceptor.of(new NoopTokenSupplier()); - log.info("Creating connection to port {}", instance.getPort()); - ManagedChannelBuilder channel = ManagedChannelBuilder.forAddress("127.0.0.1", instance.getPort()) - .usePlaintext(); - - return DataCloudConnection.fromTokenSupplier(auth, channel, properties); + return instance.getConnection(connectionSettings); } @SneakyThrows @AfterAll @Timeout(5_000) public void afterAll() { - instance.shutdown(); + instance.close(); } @SneakyThrows @@ -111,4 +111,57 @@ public String getToken() { return ""; } } + + @SneakyThrows + protected DataCloudConnection getInterceptedClientConnection() { + val mocked = InProcessChannelBuilder.forName(GrpcMock.getGlobalInProcessName()) + .usePlaintext(); + + val auth = AuthorizationHeaderInterceptor.of(new HyperTestBase.NoopTokenSupplier()); + val channel = ManagedChannelBuilder.forAddress("127.0.0.1", instance.getPort()) + .usePlaintext() + .intercept(auth) + .maxInboundMessageSize(64 * 1024 * 1024) + .build(); + + val stub = HyperServiceGrpc.newBlockingStub(channel); + + proxyStreamingMethod( + stub, + HyperServiceGrpc.getExecuteQueryMethod(), + HyperServiceGrpc.HyperServiceBlockingStub::executeQuery); + proxyStreamingMethod( + stub, + HyperServiceGrpc.getGetQueryInfoMethod(), + HyperServiceGrpc.HyperServiceBlockingStub::getQueryInfo); + proxyStreamingMethod( + stub, + HyperServiceGrpc.getGetQueryResultMethod(), + HyperServiceGrpc.HyperServiceBlockingStub::getQueryResult); + + return DataCloudConnection.fromTokenSupplier(auth, mocked, new Properties()); + } + + public static void proxyStreamingMethod( + HyperServiceGrpc.HyperServiceBlockingStub stub, + MethodDescriptor mock, + BiFunction> method) { + GrpcMock.stubFor(GrpcMock.serverStreamingMethod(mock).willProxyTo((request, observer) -> { + final String queryId; + if (request instanceof salesforce.cdp.hyperdb.v1.QueryInfoParam) { + queryId = ((QueryInfoParam) request).getQueryId(); + } else if (request instanceof salesforce.cdp.hyperdb.v1.QueryResultParam) { + queryId = ((QueryResultParam) request).getQueryId(); + } else { + queryId = null; + } + + val response = method.apply( + queryId == null ? stub : stub.withInterceptors(new QueryIdHeaderInterceptor(queryId)), request); + while (response.hasNext()) { + observer.onNext(response.next()); + } + observer.onCompleted(); + })); + } } diff --git a/jdbc-core/src/test/resources/hyper.yaml b/jdbc-core/src/test/resources/hyper.yaml index e171860e..4891a266 100644 --- a/jdbc-core/src/test/resources/hyper.yaml +++ b/jdbc-core/src/test/resources/hyper.yaml @@ -5,4 +5,4 @@ language: en_US no-password: true use_v3_new_endpoints: true grpc_persist_results: true -log_pipelines: true \ No newline at end of file +log_pipelines: true diff --git a/pom.xml b/pom.xml index 54664875..4a28279c 100644 --- a/pom.xml +++ b/pom.xml @@ -27,6 +27,7 @@ 3.4.2 3.11.2 3.3.1 + 3.3.1 true 1.7.0 UTF-8 @@ -265,6 +266,7 @@ + src/main/java/**/*.java src/test/java/**/*.java @@ -341,6 +343,7 @@ org.apache.maven.plugins maven-source-plugin + ${maven-source-plugin.version} attach-javadoc