diff --git a/README.md b/README.md index 57401a6a..67ff59a0 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ properties.put("clientSecret", "${clientSecret}"); The documentation for jwt authentication can be found [here][jwt flow]. -Instuctions to generate a private key can be found [here](#generating-a-private-key-for-jwt-authentication) +Instructions to generate a private key can be found [here](#generating-a-private-key-for-jwt-authentication) ```java Properties properties = new Properties(); diff --git a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java index 4d2406c3..c020fdcd 100644 --- a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java +++ b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java @@ -22,6 +22,7 @@ import com.salesforce.datacloud.jdbc.auth.AuthenticationSettings; import com.salesforce.datacloud.jdbc.auth.DataCloudTokenProcessor; import com.salesforce.datacloud.jdbc.auth.TokenProcessor; +import com.salesforce.datacloud.jdbc.core.partial.RowBased; import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; import com.salesforce.datacloud.jdbc.http.ClientBuilder; import com.salesforce.datacloud.jdbc.interceptor.AuthorizationHeaderInterceptor; @@ -29,6 +30,7 @@ import com.salesforce.datacloud.jdbc.interceptor.HyperExternalClientContextHeaderInterceptor; import com.salesforce.datacloud.jdbc.interceptor.HyperWorkloadHeaderInterceptor; import com.salesforce.datacloud.jdbc.interceptor.TracingHeadersInterceptor; +import com.salesforce.datacloud.jdbc.util.Unstable; import io.grpc.ClientInterceptor; import io.grpc.ManagedChannelBuilder; import java.sql.Array; @@ -173,6 +175,33 @@ private DataCloudPreparedStatement getQueryPreparedStatement(String sql) { return new DataCloudPreparedStatement(this, sql, new DefaultParameterManager()); } + /** + * Retrieves 
a collection of rows for the specified query once it is ready. + * Use {@link #getQueryStatus(String)} to check if the query has produced results or finished execution before calling this method. + *
+ * <p>
+ * When using {@link RowBased.Mode#FULL_RANGE}, this method does not handle pagination near the end of available rows. + * The caller is responsible for calculating the correct offset and limit to avoid out-of-range errors. + * + * @param queryId The identifier of the query to fetch results for. + * @param offset The starting row offset. + * @param limit The maximum number of rows to retrieve. + * @param mode The fetching mode—either {@link RowBased.Mode#SINGLE_RPC} for a single request or + * {@link RowBased.Mode#FULL_RANGE} to iterate through all available rows. + * @return A {@link DataCloudResultSet} containing the query results. + */ + public DataCloudResultSet getRowBasedResultSet(String queryId, long offset, long limit, RowBased.Mode mode) { + val iterator = RowBased.of(executor, queryId, offset, limit, mode); + return StreamingResultSet.of(queryId, executor, iterator); + } + + /** + * Use this to determine when a given query is complete by filtering the responses and a subsequent findFirst() + */ + @Unstable + public Stream getQueryStatus(String queryId) { + return executor.getQueryStatus(queryId); + } + @Override public CallableStatement prepareCall(String sql) { return null; diff --git a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudQueryStatus.java b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudQueryStatus.java new file mode 100644 index 00000000..08064533 --- /dev/null +++ b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudQueryStatus.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core; + +import java.util.Optional; +import lombok.Value; +import lombok.val; +import salesforce.cdp.hyperdb.v1.QueryInfo; +import salesforce.cdp.hyperdb.v1.QueryStatus; + +/** + * Represents the status of a query. + * The {@link CompletionStatus} enum defines the possible states of the query, which are: + *
+ * {@code RUNNING}, {@code RESULTS_PRODUCED}, and {@code FINISHED}.
status=" + completionStatus); + } + } +} diff --git a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java index 63c1865d..1140371b 100644 --- a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java +++ b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java @@ -25,6 +25,7 @@ import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; import com.salesforce.datacloud.jdbc.util.Constants; import com.salesforce.datacloud.jdbc.util.SqlErrorCodes; +import com.salesforce.datacloud.jdbc.util.Unstable; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -73,16 +74,27 @@ protected HyperGrpcClientExecutor getQueryExecutor(QueryParam additionalQueryPar return clientBuilder.queryTimeout(getQueryTimeout()).build(); } - private void assertQueryReady() throws SQLException { + private void assertQueryExecuted() throws SQLException { if (listener == null) { throw new DataCloudJDBCException("a query was not executed before attempting to access results"); } + } + + private void assertQueryReady() throws SQLException { + assertQueryExecuted(); if (!listener.isReady()) { throw new DataCloudJDBCException("query results were not ready"); } } + @Unstable + public String getQueryId() throws SQLException { + assertQueryExecuted(); + + return listener.getQueryId(); + } + public boolean isReady() { return listener.isReady(); } diff --git a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java index 778a968c..13985d2e 100644 --- a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java +++ b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java @@ -20,6 +20,8 @@ import 
com.salesforce.datacloud.jdbc.config.DriverVersion; import com.salesforce.datacloud.jdbc.interceptor.QueryIdHeaderInterceptor; import com.salesforce.datacloud.jdbc.util.PropertiesExtensions; +import com.salesforce.datacloud.jdbc.util.StreamUtilities; +import com.salesforce.datacloud.jdbc.util.Unstable; import io.grpc.ClientInterceptor; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; @@ -29,8 +31,10 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Properties; import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; import lombok.AccessLevel; import lombok.Builder; import lombok.Getter; @@ -45,6 +49,7 @@ import salesforce.cdp.hyperdb.v1.QueryParam; import salesforce.cdp.hyperdb.v1.QueryResult; import salesforce.cdp.hyperdb.v1.QueryResultParam; +import salesforce.cdp.hyperdb.v1.ResultRange; @Slf4j @Builder(toBuilder = true) @@ -135,6 +140,29 @@ public Iterator getQueryInfoStreaming(String queryId) { return getStub(queryId).getQueryInfo(param); } + @Unstable + public Stream getQueryStatus(String queryId) { + val iterator = getQueryInfo(queryId); + return StreamUtilities.toStream(iterator) + .map(DataCloudQueryStatus::of) + .filter(Optional::isPresent) + .map(Optional::get); + } + + public Iterator getQueryResult(String queryId, long offset, long limit, boolean omitSchema) { + val rowRange = + ResultRange.newBuilder().setRowOffset(offset).setRowLimit(limit).setByteLimit(1024); + + final QueryResultParam param = QueryResultParam.newBuilder() + .setQueryId(queryId) + .setResultRange(rowRange) + .setOmitSchema(omitSchema) + .setOutputFormat(OutputFormat.ARROW_IPC) + .build(); + + return getStub(queryId).getQueryResult(param); + } + public Iterator getQueryResult(String queryId, long chunkId, boolean omitSchema) { val param = getQueryResultParam(queryId, chunkId, omitSchema); return getStub(queryId).getQueryResult(param); @@ -161,12 +189,9 @@ private QueryResultParam 
getQueryResultParam(String queryId, long chunkId, boole val builder = QueryResultParam.newBuilder() .setQueryId(queryId) .setChunkId(chunkId) + .setOmitSchema(omitSchema) .setOutputFormat(OutputFormat.ARROW_IPC); - if (omitSchema) { - builder.setOmitSchema(true); - } - return builder.build(); } diff --git a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java index 76d19cec..be1c40f2 100644 --- a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java +++ b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java @@ -18,11 +18,15 @@ import com.salesforce.datacloud.jdbc.core.listener.QueryStatusListener; import com.salesforce.datacloud.jdbc.exception.QueryExceptionHandler; import com.salesforce.datacloud.jdbc.util.ArrowUtils; +import com.salesforce.datacloud.jdbc.util.StreamUtilities; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.util.Collections; +import java.util.Iterator; import java.util.TimeZone; +import java.util.stream.Stream; import lombok.SneakyThrows; +import lombok.Value; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.arrow.memory.RootAllocator; @@ -32,6 +36,7 @@ import org.apache.calcite.avatica.AvaticaStatement; import org.apache.calcite.avatica.Meta; import org.apache.calcite.avatica.QueryState; +import salesforce.cdp.hyperdb.v1.QueryResult; @Slf4j public class StreamingResultSet extends AvaticaResultSet implements DataCloudResultSet { @@ -51,6 +56,7 @@ private StreamingResultSet( this.listener = listener; } + @Deprecated @SneakyThrows public static StreamingResultSet of(String sql, QueryStatusListener listener) { try { @@ -73,6 +79,30 @@ public static StreamingResultSet of(String sql, QueryStatusListener listener) { } } + @SneakyThrows + public static StreamingResultSet of( + String queryId, HyperGrpcClientExecutor client, Iterator 
iterator) { + try { + val channel = ExecuteQueryResponseChannel.of(StreamUtilities.toStream(iterator)); + val reader = new ArrowStreamReader(channel, new RootAllocator(ROOT_ALLOCATOR_MB_FROM_V2)); + val schemaRoot = reader.getVectorSchemaRoot(); + val columns = ArrowUtils.toColumnMetaData(schemaRoot.getSchema().getFields()); + val timezone = TimeZone.getDefault(); + val state = new QueryState(); + val signature = new Meta.Signature( + columns, null, Collections.emptyList(), Collections.emptyMap(), null, Meta.StatementType.SELECT); + val metadata = new AvaticaResultSetMetaData(null, null, signature); + val listener = new AlreadyReadyNoopListener(queryId); + val result = new StreamingResultSet(listener, null, state, signature, metadata, timezone, null); + val cursor = new ArrowStreamReaderCursor(reader); + result.execute2(cursor, columns); + + return result; + } catch (Exception ex) { + throw QueryExceptionHandler.createException(QUERY_FAILURE + queryId, ex); + } + } + @Override public String getQueryId() { return listener.getQueryId(); @@ -87,4 +117,24 @@ public String getStatus() { public boolean isReady() { return listener.isReady(); } + + private static final String QUERY_FAILURE = "Failed to execute query: "; + + @Value + private static class AlreadyReadyNoopListener implements QueryStatusListener { + String queryId; + String status = "Status should be determined via DataCloudConnection::getStatus"; + String query = null; + boolean ready = true; + + @Override + public DataCloudResultSet generateResultSet() { + return null; + } + + @Override + public Stream stream() { + return Stream.empty(); + } + } } diff --git a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java index b2656f16..ac6773eb 100644 --- a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java +++ 
b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java @@ -20,6 +20,7 @@ import java.util.stream.Stream; import salesforce.cdp.hyperdb.v1.QueryResult; +@Deprecated public interface QueryStatusListener { String BEFORE_READY = "Results were requested before ready"; diff --git a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/partial/RowBased.java b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/partial/RowBased.java new file mode 100644 index 00000000..27dcaa78 --- /dev/null +++ b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/core/partial/RowBased.java @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.partial; + +import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicLong; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.val; +import salesforce.cdp.hyperdb.v1.QueryResult; + +@Builder +class RowBasedContext { + @NonNull private final HyperGrpcClientExecutor client; + + @NonNull private final String queryId; + + private final long offset; + + @Getter + private final long limit; + + @Getter + private final AtomicLong seen = new AtomicLong(0); + + public Iterator getQueryResult(boolean omitSchema) { + val currentOffset = offset + seen.get(); + val currentLimit = limit - seen.get(); + return client.getQueryResult(queryId, currentOffset, currentLimit, omitSchema); + } +} + +/** + * Row based results can be acquired with a QueryId and a row range, the behavior of getting more rows is determined by the {@link RowBased.Mode}: + * {@link RowBased.Mode#SINGLE_RPC} execute a single GetQueryResult calls, we will return whatever data is on this response if any is available + * {@link RowBased.Mode#FULL_RANGE} execute as many GetQueryResult calls until all available rows are exhausted + */ +public interface RowBased extends Iterator { + enum Mode { + SINGLE_RPC, + FULL_RANGE + } + + static RowBased of( + @NonNull HyperGrpcClientExecutor client, + @NonNull String queryId, + long offset, + long limit, + @NonNull Mode mode) { + val context = RowBasedContext.builder() + .client(client) + .queryId(queryId) + .offset(offset) + .limit(limit) + .build(); + switch (mode) { + case SINGLE_RPC: + return RowBasedSingleRpc.builder() + .iterator(context.getQueryResult(false)) + .build(); + case FULL_RANGE: + return RowBasedFullRange.builder().context(context).build(); + } + throw new IllegalArgumentException("Unknown mode not supported. 
mode=" + mode); + } +} + +@Builder +class RowBasedSingleRpc implements RowBased { + private final Iterator iterator; + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public QueryResult next() { + return iterator.next(); + } +} + +@Builder +class RowBasedFullRange implements RowBased { + private final RowBasedContext context; + + private Iterator iterator; + + @Override + public boolean hasNext() { + if (iterator == null) { + iterator = context.getQueryResult(false); + return iterator.hasNext(); + } + + if (iterator.hasNext()) { + return true; + } + + if (context.getSeen().get() < context.getLimit()) { + iterator = context.getQueryResult(true); + } + + return iterator.hasNext(); + } + + @Override + public QueryResult next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + val next = iterator.next(); + context.getSeen().addAndGet(next.getResultPartRowCount()); + return next; + } +} diff --git a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/util/StreamUtilities.java b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/util/StreamUtilities.java index 1ca8df2a..d9c16242 100644 --- a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/util/StreamUtilities.java +++ b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/util/StreamUtilities.java @@ -21,6 +21,7 @@ import java.util.Spliterators; import java.util.function.Consumer; import java.util.function.LongSupplier; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.function.UnaryOperator; import java.util.stream.Stream; @@ -54,4 +55,26 @@ public Optional tryTimes( .findFirst() .flatMap(Result::get); } + + public Stream takeWhile(Stream stream, Predicate predicate) { + val split = stream.spliterator(); + + return StreamSupport.stream( + new Spliterators.AbstractSpliterator(split.estimateSize(), split.characteristics()) { + boolean shouldContinue = true; + + @Override + public boolean tryAdvance(Consumer 
action) { + return shouldContinue + && split.tryAdvance(elem -> { + if (predicate.test(elem)) { + action.accept(elem); + } else { + shouldContinue = false; + } + }); + } + }, + false); + } } diff --git a/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/util/Unstable.java b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/util/Unstable.java new file mode 100644 index 00000000..ac036904 --- /dev/null +++ b/jdbc-slim/src/main/java/com/salesforce/datacloud/jdbc/util/Unstable.java @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import java.lang.annotation.Documented; + +@Documented +public @interface Unstable {} diff --git a/jdbc-slim/src/test/java/com/salesforce/datacloud/jdbc/util/StreamUtilitiesTest.java b/jdbc-slim/src/test/java/com/salesforce/datacloud/jdbc/util/StreamUtilitiesTest.java new file mode 100644 index 00000000..4998722b --- /dev/null +++ b/jdbc-slim/src/test/java/com/salesforce/datacloud/jdbc/util/StreamUtilitiesTest.java @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.jupiter.api.Test; + +@Slf4j +class StreamUtilitiesTest { + @Test + void testTakeWhileSomeMatch() { + val stream = Stream.of(1, 2, 3, 4, 5, 6, 7); + + val result = StreamUtilities.takeWhile(stream, x -> x < 5).collect(Collectors.toList()); + + assertThat(result).containsExactly(1, 2, 3, 4); + } + + @Test + void testTakeWhileAllMatch() { + val stream = Stream.of(1, 2, 3, 4); + + val result = StreamUtilities.takeWhile(stream, x -> x < 10).collect(Collectors.toList()); + + assertThat(result).containsExactly(1, 2, 3, 4); + } + + @Test + void testTakeWhileNoMatch() { + val stream = Stream.of(1, 2, 3); + + val result = StreamUtilities.takeWhile(stream, x -> x < 0).collect(Collectors.toList()); + + assertThat(result).isEmpty(); + } + + @Test + void testTakeWhileEmptyStream() { + Stream stream = Stream.empty(); + Predicate predicate = x -> x < 5; + + val result = StreamUtilities.takeWhile(stream, predicate).collect(Collectors.toList()); + + assertThat(result).isEmpty(); + } + + @Test + void 
testTryTimesSuccessFirstTry() { + val result = StreamUtilities.tryTimes(3, () -> Stream.of("Success"), this::consumer); + + assertThat(result).isPresent(); + assertThat(result.get().collect(Collectors.toList())).containsExactly("Success"); + } + + @Test + void testTryTimesSomeFailures() { + Consumer mockConsumer = mock(Consumer.class); + + val counter = new AtomicInteger(0); + + val result = StreamUtilities.tryTimes( + 3, + () -> { + if (counter.incrementAndGet() < 3) { + throw new RuntimeException("Failure " + counter.get()); + } + return Stream.of("Success"); + }, + mockConsumer); + + verify(mockConsumer, times(2)).accept(any(Throwable.class)); + + assertThat(result).isPresent(); + assertThat(result.get().collect(Collectors.toList())).containsExactly("Success"); + } + + @Test + void testTryTimesNoAttemptsAllowed() { + val result = StreamUtilities.tryTimes(0, () -> Stream.of("Never runs"), this::consumer); + + assertThat(result).isEmpty(); + } + + @Test + void testTryTimesAlwaysFails() { + val result = StreamUtilities.tryTimes( + 3, + () -> { + throw new RuntimeException("Always fails"); + }, + this::consumer); + + assertThat(result).isNotPresent(); + } + + private void consumer(Throwable err) { + log.error("consumed throwable", err); + } +} diff --git a/pom.xml b/pom.xml index e7a2690b..d5a75607 100644 --- a/pom.xml +++ b/pom.xml @@ -451,8 +451,9 @@ ${project.build.directory}/delombok;${project.build.directory}/generated-sources/protobuf ${java.version} true - com.salesforce.hyperdb.grpc + salesforce.cdp.hyperdb.v1.* ${project.build.directory}/apidocs + false diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/partial/RowBasedTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/partial/RowBasedTest.java new file mode 100644 index 00000000..e616e609 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/partial/RowBasedTest.java @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.partial; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.mockito.Mockito.mock; + +import com.salesforce.datacloud.jdbc.core.DataCloudConnection; +import com.salesforce.datacloud.jdbc.core.DataCloudQueryStatus; +import com.salesforce.datacloud.jdbc.core.DataCloudResultSet; +import com.salesforce.datacloud.jdbc.core.DataCloudStatement; +import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; +import com.salesforce.datacloud.jdbc.hyper.HyperTestBase; +import com.salesforce.datacloud.jdbc.util.StreamUtilities; +import io.grpc.StatusRuntimeException; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.Value; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +@Slf4j +class RowBasedTest extends HyperTestBase { + private List sut(String queryId, long offset, long limit, RowBased.Mode mode) { + val connection = getHyperQueryConnection(); + val resultSet = 
connection.getRowBasedResultSet(queryId, offset, limit, mode); + return toList(resultSet); + } + + private static final int tinySize = 8; + private static final int smallSize = 32; + private static final int largeSize = 1024 * 1024 * 10; + + private String tiny; + private String small; + private String large; + + @BeforeAll + void setupQueries() { + large = getQueryId(largeSize); + small = getQueryId(smallSize); + tiny = getQueryId(tinySize); + waitForQuery(large); + waitForQuery(small); + waitForQuery(tiny); + } + + @Test + void singleRpcReturnsIteratorButNotRowBasedFullRange() { + val client = mock(HyperGrpcClientExecutor.class); + val single = RowBased.of(client, "select 1", 0, 1, RowBased.Mode.SINGLE_RPC); + + assertThat(single).isInstanceOf(RowBasedSingleRpc.class).isNotInstanceOf(RowBasedFullRange.class); + } + + @Test + void fullRangeReturnsRowBasedFullRange() { + val client = mock(HyperGrpcClientExecutor.class); + val single = RowBased.of(client, "select 1", 0, 1, RowBased.Mode.FULL_RANGE); + + assertThat(single).isInstanceOf(RowBasedFullRange.class).isNotInstanceOf(RowBasedSingleRpc.class); + } + + @SneakyThrows + @ParameterizedTest + @EnumSource(RowBased.Mode.class) + void fetchWhereActualLessThanPageSize(RowBased.Mode mode) { + val limit = 10; + + assertThat(sut(small, 0, limit, mode)).containsExactlyElementsOf(rangeClosed(1, 10)); + assertThat(sut(small, 10, limit, mode)).containsExactlyElementsOf(rangeClosed(11, 20)); + assertThat(sut(small, 20, limit, mode)).containsExactlyElementsOf(rangeClosed(21, 30)); + assertThat(sut(small, 30, 2, mode)).containsExactlyElementsOf(rangeClosed(31, 32)); + } + + @Test + void fetchWhereActualMoreThanPageSize_SINGLE_RPC() { + val actual = sut(small, 0, smallSize * 2, RowBased.Mode.SINGLE_RPC); + assertThat(actual).isNotEmpty().isSubsetOf(rangeClosed(1, smallSize)); + } + + /** + * DataCloudConnection::getRowBasedResultSet is not responsible for calculating the offset near the end of available rows + */ + @SneakyThrows 
+ @Test + void throwsWhenFullRangeOverrunsAvailableRows() { + assertThatThrownBy(() -> sut(tiny, 0, tinySize * 3, RowBased.Mode.FULL_RANGE)) + .hasRootCauseInstanceOf(StatusRuntimeException.class) + .hasRootCauseMessage(String.format( + "OUT_OF_RANGE: Request out of range: The specified offset is %d, but only %d tuples are available", + tinySize, tinySize)); + } + + @SneakyThrows + @Test + void fetchWithRowsNearEndRange_FULL_RANGE() { + val threads = 3; + final long rows; + try (val conn = getHyperQueryConnection()) { + rows = conn.getQueryStatus(small) + .filter(t -> t.isResultProduced() || t.isExecutionFinished()) + .map(DataCloudQueryStatus::getRowCount) + .findFirst() + .orElseThrow(() -> new RuntimeException("boom")); + } + + val baseSize = rows / threads; + val remainder = rows % threads; + + val pages = StreamUtilities.takeWhile( + LongStream.range(0, threads).mapToObj(i -> { + val limit = baseSize + (i < remainder ? 1 : 0); + val offset = baseSize * i + Math.min(i, remainder); + return new Page(offset, limit); + }), + p -> p.limit > 0); + + try (val conn = getHyperQueryConnection()) { + val actual = pages.parallel() + .map(p -> { + log.info("Executing FULL_RANGE request for page {}", p); + return conn.getRowBasedResultSet(small, p.offset, p.limit, RowBased.Mode.FULL_RANGE); + }) + .flatMap(RowBasedTest::toStream) + .collect(Collectors.toList()); + assertThat(actual).containsExactlyElementsOf(rangeClosed(1, smallSize)); + } + } + + @Value + private static class Page { + long offset; + long limit; + } + + @SneakyThrows + private String getQueryId(int max) { + val query = String.format( + "select a, cast(a as numeric(38,18)) b, cast(a as numeric(38,18)) c, cast(a as numeric(38,18)) d from generate_series(1, %d) as s(a) order by a asc", + max); + + try (val client = getHyperQueryConnection(); + val statement = client.createStatement().unwrap(DataCloudStatement.class)) { + statement.executeAsyncQuery(query); + return statement.getQueryId(); + } + } + + 
@SneakyThrows + private void waitForQuery(String queryId) { + try (val conn = getHyperQueryConnection()) { + while (!isReady(conn, queryId)) { + Thread.sleep(250); + } + } + } + + private boolean isReady(DataCloudConnection connection, String queryId) { + return connection.getQueryStatus(queryId).anyMatch(t -> t.isExecutionFinished() || t.isResultProduced()); + } + + private static List rangeClosed(int start, int end) { + return IntStream.rangeClosed(start, end).boxed().collect(Collectors.toList()); + } + + private static Stream toStream(DataCloudResultSet resultSet) { + val iterator = new Iterator() { + @SneakyThrows + @Override + public boolean hasNext() { + return resultSet.next(); + } + + @SneakyThrows + @Override + public Integer next() { + return resultSet.getInt(1); + } + }; + + return StreamUtilities.toStream(iterator); + } + + private static List toList(DataCloudResultSet resultSet) { + + return toStream(resultSet).collect(Collectors.toList()); + } +}