diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/MergedResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/MergedResultSet.java index fcbc49f346d..1cbbf0818c5 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/MergedResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/MergedResultSet.java @@ -25,6 +25,7 @@ import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Struct; import com.google.cloud.spanner.Type; +import com.google.cloud.spanner.Type.Code; import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.spanner.v1.ResultSetMetadata; @@ -82,9 +83,11 @@ public void run() { break; } } - if (first) { - // Special case: The result set did not return any rows. Push the metadata to the merged - // result set. + if (first + && resultSet.getType().getCode() == Code.STRUCT + && !resultSet.getType().getStructFields().isEmpty()) { + // Special case: The result set did not return any rows, but did return metadata. + // Push the metadata to the merged result set. queue.put( PartitionExecutorResult.typeAndMetadata( resultSet.getType(), resultSet.getMetadata())); @@ -319,13 +322,17 @@ public Struct get() { return currentRow; } - private PartitionExecutorResult getFirstResult() { + private PartitionExecutorResult getFirstResultWithMetadata() { try { metadataAvailableLatch.await(); } catch (InterruptedException interruptedException) { throw SpannerExceptionFactory.propagateInterrupt(interruptedException); } - PartitionExecutorResult result = queue.peek(); + PartitionExecutorResult result = + queue.stream() + .filter(rs -> rs.metadata != null || rs.exception != null) + .findFirst() + .orElse(null); if (result == null) { throw SpannerExceptionFactory.newSpannerException( ErrorCode.FAILED_PRECONDITION, "Thread-unsafe access to ResultSet"); @@ -338,7 +345,7 @@ private PartitionExecutorResult getFirstResult() { public ResultSetMetadata getMetadata() { if (metadata == null) { - return getFirstResult().metadata; + return getFirstResultWithMetadata().metadata; } return metadata; } @@ -355,7 +362,7 @@ public int getParallelism() { public Type getType() { if (type == null) { - return getFirstResult().type; + return getFirstResultWithMetadata().type; } return type; } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/MergedResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/MergedResultSetTest.java index b0465be6106..6d3950efbc3 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/MergedResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/MergedResultSetTest.java @@ -32,6 +32,8 @@ import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Struct; import com.google.cloud.spanner.Type; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.StructType; import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; @@ -103,7 +105,7 @@ public static Collection parameters() { return params; } - private MockedResults setupResults(boolean withErrors) { + private MockedResults setupResults(boolean withErrors, boolean withEmptyResults) { Random random = new Random(); Connection connection = mock(Connection.class); List partitions = new ArrayList<>(); @@ -122,10 +124,22 @@ private MockedResults setupResults(boolean withErrors) { when(connection.runPartition(partition)) .thenReturn(new ResultSetWithError(ResultSetsHelper.fromProto(proto), errorIndex)); } else { - when(connection.runPartition(partition)).thenReturn(ResultSetsHelper.fromProto(proto)); - try (ResultSet resultSet = ResultSetsHelper.fromProto(proto)) { - while (resultSet.next()) { - allRows.add(resultSet.getCurrentRowAsStruct()); + if (withEmptyResults && numPartitions > 1 && index == 0) { + when(connection.runPartition(partition)) + .thenReturn( + ResultSetsHelper.fromProto( + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType(StructType.newBuilder().build()) + .build()) + .build())); + } else { + when(connection.runPartition(partition)).thenReturn(ResultSetsHelper.fromProto(proto)); + try (ResultSet resultSet = ResultSetsHelper.fromProto(proto)) { + while (resultSet.next()) { + allRows.add(resultSet.getCurrentRowAsStruct()); + } } } } @@ -135,7 +149,7 @@ private MockedResults setupResults(boolean withErrors) { @Test public void testAllResultsAreReturned() { - MockedResults results = setupResults(false); + MockedResults results = setupResults(/* withErrors= */ false, /* withEmptyResults= */ false); BitSet rowsFound = new BitSet(results.allRows.size()); try (MergedResultSet resultSet = new MergedResultSet(results.connection, results.partitions, maxParallelism)) { @@ -170,7 +184,7 @@ public void testAllResultsAreReturned() { @Test public void testResultSetStopsAfterFirstError() { - MockedResults results = setupResults(true); + MockedResults results = setupResults(/* withErrors= */ true, /* withEmptyResults= */ false); try (MergedResultSet resultSet = new MergedResultSet(results.connection, results.partitions, maxParallelism)) { if (numPartitions > 0) { @@ -194,6 +208,40 @@ public void testResultSetStopsAfterFirstError() { } } + @Test + public void testResultSetReturnsNonEmptyMetadata() { + MockedResults results = setupResults(/* withErrors= */ false, /* withEmptyResults= */ true); + BitSet rowsFound = new BitSet(results.allRows.size()); + try (MergedResultSet resultSet = + new MergedResultSet(results.connection, results.partitions, maxParallelism)) { + if (numPartitions > 0) { + assertNotNull(resultSet.getMetadata()); + assertEquals(26, resultSet.getMetadata().getRowType().getFieldsCount()); + } + while (resultSet.next()) { + assertRowExists(results.allRows, resultSet.getCurrentRowAsStruct(), rowsFound); + } + if (numPartitions == 0) { + assertEquals(0, resultSet.getColumnCount()); + } else { + assertEquals(26, resultSet.getColumnCount()); + assertEquals(Type.bool(), resultSet.getColumnType(0)); + assertEquals(Type.bool(), resultSet.getColumnType("COL0")); + assertEquals(10, resultSet.getColumnIndex("COL10")); + } + // Check that all rows were found. + assertEquals(results.allRows.size(), rowsFound.nextClearBit(0)); + // Check extended metadata. + assertEquals(numPartitions, resultSet.getNumPartitions()); + if (maxParallelism > 0) { + assertEquals(Math.min(numPartitions, maxParallelism), resultSet.getParallelism()); + } else { + int processors = Runtime.getRuntime().availableProcessors(); + assertEquals(Math.min(numPartitions, processors), resultSet.getParallelism()); + } + } + } + private void assertRowExists(List expectedRows, Struct row, BitSet rowsFound) { for (int i = 0; i < expectedRows.size(); i++) { if (row.equals(expectedRows.get(i))) {