Skip to content

Commit 7c80949

Browse files
fix(plugin-arrow): Handle restricted output columns in Arrow Page Source (#26175)
Restricted columns need to be handled when using table valued functions in a query. ## Description <!---Describe your changes in detail--> When using table valued functions (TVF) in queries like the one given below ```sql SELECT id from TABLE(system.query_function( 'SELECT name, id FROM tpch.member WHERE id = 1', 'name VARCHAR, id INTEGER')) ``` `ArrowPageSource` can fail to map the correct `FieldVector` for the column. This happens because a TVF like the one above executes the query natively in the Flight server and gives a result with two columns, but `ArrowPageSource` is expected to only return results for 1 column, i.e. `id` from the above query. The column `name` is restricted from the TVF result. ## Motivation and Context <!---Why is this change required? What problem does it solve?--> <!---If it fixes an open issue, please link to the issue here.--> This change is required when using a TVF against a catalog based on the `presto-base-arrow-flight` module. ## Impact <!---Describe any public API or user-facing feature change or any performance impact--> This fixes problems when using a TVF in the way shown above against Arrow Flight based catalogs. ## Test Plan <!---Please fill in how you tested your change--> New unit tests were added that test this change. This PR also includes an implementation of a TVF called `query_function` that is used to test this change. This change is backward compatible, so existing test cases will also pass with this change. ## Contributor checklist - [ ] Please make sure your submission complies with our [contributing guide](https://github.com/prestodb/presto/blob/master/CONTRIBUTING.md), in particular [code style](https://github.com/prestodb/presto/blob/master/CONTRIBUTING.md#code-style) and [commit standards](https://github.com/prestodb/presto/blob/master/CONTRIBUTING.md#commit-standards). - [ ] PR description addresses the issue accurately and concisely. If the change is non-trivial, a GitHub Issue is referenced. 
- [ ] Documented new properties (with its default value), SQL syntax, functions, or other functionality. - [ ] If release notes are required, they follow the [release notes guidelines](https://github.com/prestodb/presto/wiki/Release-Notes-Guidelines). - [ ] Adequate tests were added if applicable. - [ ] CI passed. ## Release Notes Please follow [release notes guidelines](https://github.com/prestodb/presto/wiki/Release-Notes-Guidelines) and fill in the release notes below. ``` == NO RELEASE NOTE == ```
1 parent 4162b53 commit 7c80949

File tree

11 files changed

+469
-12
lines changed

11 files changed

+469
-12
lines changed

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import com.facebook.presto.spi.ConnectorPageSource;
2121
import com.facebook.presto.spi.ConnectorSession;
2222
import org.apache.arrow.vector.FieldVector;
23+
import org.apache.arrow.vector.VectorSchemaRoot;
2324

2425
import java.util.ArrayList;
2526
import java.util.List;
@@ -97,16 +98,19 @@ public Page getNextPage()
9798

9899
// Create blocks from the loaded Arrow record batch
99100
List<Block> blocks = new ArrayList<>();
100-
List<FieldVector> vectors = flightStreamAndClient.getRoot().getFieldVectors();
101-
for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) {
102-
FieldVector vector = vectors.get(columnIndex);
103-
Type type = columnHandles.get(columnIndex).getColumnType();
101+
VectorSchemaRoot vectorSchemaRoot = flightStreamAndClient.getRoot();
102+
for (ArrowColumnHandle columnHandle : columnHandles) {
103+
// In scenarios where the user query contains a Table Valued Function, the output columns could be in a
104+
// different order or could be a subset of the columns in the flight stream. So we are fetching the requested
105+
// field vector by matching the column name instead of fetching by column index.
106+
FieldVector vector = requireNonNull(vectorSchemaRoot.getVector(columnHandle.getColumnName()), "No field named " + columnHandle.getColumnName() + " in the list of vectors from flight stream");
107+
Type type = columnHandle.getColumnType();
104108
Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, type, flightStreamAndClient.getDictionaryProvider());
105109
blocks.add(block);
106110
}
107111

108112
if (logger.isDebugEnabled()) {
109-
logger.debug("Read Arrow record batch with rows: %s, columns: %s", flightStreamAndClient.getRoot().getRowCount(), vectors.size());
113+
logger.debug("Read Arrow record batch with rows: %s, columns: %s", flightStreamAndClient.getRoot().getRowCount(), vectorSchemaRoot.getFieldVectors().size());
110114
}
111115

112116
return new Page(flightStreamAndClient.getRoot().getRowCount(), blocks.toArray(new Block[0]));

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueries.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,18 @@ protected FeaturesConfig createFeaturesConfig()
118118
return new FeaturesConfig().setNativeExecutionEnabled(true);
119119
}
120120

121+
@Test
122+
public void testQueryFunctionWithRestrictedColumns()
123+
{
124+
assertQuery("SELECT NAME FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME FROM nation WHERE NATIONKEY = 4");
125+
}
126+
127+
@Test
128+
public void testQueryFunctionWithoutRestrictedColumns() throws InterruptedException
129+
{
130+
assertQuery("SELECT NAME, NATIONKEY FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME, NATIONKEY FROM nation WHERE NATIONKEY = 4");
131+
}
132+
121133
@Test
122134
public void testFiltersAndProjections1()
123135
{

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,30 @@ public void testDescribeUnknownTable()
166166
assertEquals(actualRows, expectedRows);
167167
}
168168

169+
@Test
170+
public void testQueryFunctionWithRestrictedColumns()
171+
{
172+
assertQuery("SELECT NAME FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME FROM nation WHERE NATIONKEY = 4");
173+
}
174+
175+
@Test
176+
public void testQueryFunctionWithoutRestrictedColumns()
177+
{
178+
assertQuery("SELECT NATIONKEY, NAME FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NATIONKEY, NAME FROM nation WHERE NATIONKEY = 4");
179+
}
180+
181+
@Test
182+
public void testQueryFunctionWithDifferentColumnOrder()
183+
{
184+
assertQuery("SELECT NAME, NATIONKEY FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME, NATIONKEY FROM nation WHERE NATIONKEY = 4");
185+
}
186+
187+
@Test
188+
public void testQueryFunctionWithInvalidColumn()
189+
{
190+
assertQueryFails("SELECT NAME, NATIONKEY, INVALID_COLUMN FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "Column 'invalid_column' cannot be resolved", true);
191+
}
192+
169193
private LocalDate getDate(String dateString)
170194
{
171195
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow.testingConnector;
15+
16+
import com.facebook.presto.common.type.BigintType;
17+
import com.facebook.presto.common.type.BooleanType;
18+
import com.facebook.presto.common.type.DateType;
19+
import com.facebook.presto.common.type.DoubleType;
20+
import com.facebook.presto.common.type.IntegerType;
21+
import com.facebook.presto.common.type.RealType;
22+
import com.facebook.presto.common.type.SmallintType;
23+
import com.facebook.presto.common.type.TimeType;
24+
import com.facebook.presto.common.type.TimestampType;
25+
import com.facebook.presto.common.type.Type;
26+
import com.facebook.presto.spi.PrestoException;
27+
28+
import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType;
29+
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
30+
31+
public final class PrimitiveToPrestoTypeMappings
32+
{
33+
private PrimitiveToPrestoTypeMappings()
34+
{
35+
throw new UnsupportedOperationException();
36+
}
37+
38+
public static Type fromPrimitiveToPrestoType(String dataType)
39+
{
40+
switch (dataType) {
41+
case "INTEGER":
42+
return IntegerType.INTEGER;
43+
case "VARCHAR":
44+
return createUnboundedVarcharType();
45+
case "DOUBLE":
46+
return DoubleType.DOUBLE;
47+
case "SMALLINT":
48+
return SmallintType.SMALLINT;
49+
case "BOOLEAN":
50+
return BooleanType.BOOLEAN;
51+
case "TIMESTAMP":
52+
return TimestampType.TIMESTAMP;
53+
case "TIME":
54+
return TimeType.TIME;
55+
case "REAL":
56+
return RealType.REAL;
57+
case "DATE":
58+
return DateType.DATE;
59+
case "BIGINT":
60+
return BigintType.BIGINT;
61+
}
62+
throw new PrestoException(NOT_SUPPORTED, "Unsupported datatype '" + dataType + "' in the selected table.");
63+
}
64+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow.testingConnector;
15+
16+
import com.facebook.plugin.arrow.ArrowConnector;
17+
import com.facebook.presto.spi.connector.ConnectorMetadata;
18+
import com.facebook.presto.spi.connector.ConnectorPageSourceProvider;
19+
import com.facebook.presto.spi.connector.ConnectorSplitManager;
20+
import com.facebook.presto.spi.function.table.ConnectorTableFunction;
21+
import com.google.common.collect.ImmutableSet;
22+
import com.google.inject.Inject;
23+
import org.apache.arrow.memory.BufferAllocator;
24+
25+
import java.util.Set;
26+
27+
import static java.util.Objects.requireNonNull;
28+
29+
public class TestingArrowConnector
30+
extends ArrowConnector
31+
{
32+
private final Set<ConnectorTableFunction> connectorTableFunctions;
33+
34+
@Inject
35+
public TestingArrowConnector(ConnectorMetadata metadata, ConnectorSplitManager splitManager, ConnectorPageSourceProvider pageSourceProvider, Set<ConnectorTableFunction> connectorTableFunctions, BufferAllocator allocator)
36+
{
37+
super(metadata, splitManager, pageSourceProvider, allocator);
38+
this.connectorTableFunctions = ImmutableSet.copyOf(requireNonNull(connectorTableFunctions, "connectorTableFunctions is null"));
39+
}
40+
41+
@Override
42+
public Set<ConnectorTableFunction> getTableFunctions()
43+
{
44+
return connectorTableFunctions;
45+
}
46+
}

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightClientHandler.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,25 @@ public List<SchemaTableName> listTables(ConnectorSession session, Optional<Strin
143143
public FlightDescriptor getFlightDescriptorForTableScan(ConnectorSession session, ArrowTableLayoutHandle tableLayoutHandle)
144144
{
145145
ArrowTableHandle tableHandle = tableLayoutHandle.getTable();
146-
String query = new TestingArrowQueryBuilder().buildSql(
147-
tableHandle.getSchema(),
148-
tableHandle.getTable(),
149-
tableLayoutHandle.getColumnHandles(), ImmutableMap.of(),
150-
tableLayoutHandle.getTupleDomain());
151-
TestingArrowFlightRequest request = TestingArrowFlightRequest.createQueryRequest(tableHandle.getSchema(), tableHandle.getTable(), query);
146+
147+
String query;
148+
String table;
149+
150+
if (tableHandle instanceof TestingQueryArrowTableHandle) {
151+
TestingQueryArrowTableHandle testingQueryArrowTableHandle = (TestingQueryArrowTableHandle) tableHandle;
152+
query = testingQueryArrowTableHandle.getQuery();
153+
table = null;
154+
}
155+
else {
156+
query = new TestingArrowQueryBuilder().buildSql(
157+
tableHandle.getSchema(),
158+
tableHandle.getTable(),
159+
tableLayoutHandle.getColumnHandles(), ImmutableMap.of(),
160+
tableLayoutHandle.getTupleDomain());
161+
table = tableHandle.getTable();
162+
}
163+
164+
TestingArrowFlightRequest request = TestingArrowFlightRequest.createQueryRequest(tableHandle.getSchema(), table, query);
152165
return FlightDescriptor.command(requestCodec.toBytes(request));
153166
}
154167

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow.testingConnector;
15+
16+
import com.facebook.plugin.arrow.ArrowBlockBuilder;
17+
import com.facebook.plugin.arrow.ArrowColumnHandle;
18+
import com.facebook.plugin.arrow.ArrowFlightConfig;
19+
import com.facebook.plugin.arrow.ArrowMetadata;
20+
import com.facebook.plugin.arrow.BaseArrowFlightClientHandler;
21+
import com.facebook.plugin.arrow.testingConnector.tvf.QueryFunctionProvider;
22+
import com.facebook.presto.spi.ColumnHandle;
23+
import com.facebook.presto.spi.ColumnMetadata;
24+
import com.facebook.presto.spi.ConnectorSession;
25+
import com.facebook.presto.spi.ConnectorTableHandle;
26+
import com.facebook.presto.spi.ConnectorTableMetadata;
27+
import com.facebook.presto.spi.SchemaTableName;
28+
import com.facebook.presto.spi.connector.TableFunctionApplicationResult;
29+
import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle;
30+
import jakarta.inject.Inject;
31+
32+
import java.util.ArrayList;
33+
import java.util.List;
34+
import java.util.Map;
35+
import java.util.Optional;
36+
import java.util.stream.Collectors;
37+
38+
public class TestingArrowMetadata
39+
extends ArrowMetadata
40+
{
41+
@Inject
42+
public TestingArrowMetadata(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder, ArrowFlightConfig config)
43+
{
44+
super(clientHandler, arrowBlockBuilder, config);
45+
}
46+
47+
@Override
48+
public Optional<TableFunctionApplicationResult<ConnectorTableHandle>> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle)
49+
{
50+
if (handle instanceof QueryFunctionProvider.QueryFunctionHandle) {
51+
QueryFunctionProvider.QueryFunctionHandle functionHandle = (QueryFunctionProvider.QueryFunctionHandle) handle;
52+
return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), new ArrayList<>(functionHandle.getTableHandle().getColumns())));
53+
}
54+
return Optional.empty();
55+
}
56+
57+
@Override
58+
public Map<String, ColumnHandle> getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle)
59+
{
60+
if (tableHandle instanceof TestingQueryArrowTableHandle) {
61+
TestingQueryArrowTableHandle queryArrowTableHandle = (TestingQueryArrowTableHandle) tableHandle;
62+
return queryArrowTableHandle.getColumns().stream().collect(Collectors.toMap(c -> normalizeIdentifier(session, c.getColumnName()), c -> c));
63+
}
64+
else {
65+
return super.getColumnHandles(session, tableHandle);
66+
}
67+
}
68+
69+
@Override
70+
public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle)
71+
{
72+
if (tableHandle instanceof TestingQueryArrowTableHandle) {
73+
TestingQueryArrowTableHandle queryArrowTableHandle = (TestingQueryArrowTableHandle) tableHandle;
74+
75+
List<ColumnMetadata> meta = new ArrayList<>();
76+
for (ArrowColumnHandle columnHandle : queryArrowTableHandle.getColumns()) {
77+
meta.add(ColumnMetadata.builder().setName(normalizeIdentifier(session, columnHandle.getColumnName())).setType(columnHandle.getColumnType()).build());
78+
}
79+
return new ConnectorTableMetadata(new SchemaTableName(queryArrowTableHandle.getSchema(), queryArrowTableHandle.getTable()), meta);
80+
}
81+
else {
82+
return super.getTableMetadata(session, tableHandle);
83+
}
84+
}
85+
}

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414
package com.facebook.plugin.arrow.testingConnector;
1515

1616
import com.facebook.plugin.arrow.ArrowBlockBuilder;
17+
import com.facebook.plugin.arrow.ArrowConnector;
1718
import com.facebook.plugin.arrow.BaseArrowFlightClientHandler;
19+
import com.facebook.plugin.arrow.testingConnector.tvf.QueryFunctionProvider;
1820
import com.facebook.plugin.arrow.testingServer.TestingArrowFlightRequest;
1921
import com.facebook.plugin.arrow.testingServer.TestingArrowFlightResponse;
22+
import com.facebook.presto.spi.connector.ConnectorMetadata;
23+
import com.facebook.presto.spi.function.table.ConnectorTableFunction;
2024
import com.google.inject.Binder;
2125
import com.google.inject.Module;
2226
import com.google.inject.Scopes;
2327

2428
import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder;
29+
import static com.google.inject.multibindings.Multibinder.newSetBinder;
2530

2631
public class TestingArrowModule
2732
implements Module
@@ -36,6 +41,9 @@ public TestingArrowModule(boolean nativeExecution)
3641
@Override
3742
public void configure(Binder binder)
3843
{
44+
binder.bind(ConnectorMetadata.class).to(TestingArrowMetadata.class).in(Scopes.SINGLETON);
45+
newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(QueryFunctionProvider.class).in(Scopes.SINGLETON);
46+
binder.bind(ArrowConnector.class).to(TestingArrowConnector.class).in(Scopes.SINGLETON);
3947
// Concrete implementation of the BaseFlightClientHandler
4048
binder.bind(BaseArrowFlightClientHandler.class).to(TestingArrowFlightClientHandler.class).in(Scopes.SINGLETON);
4149
// Override the ArrowBlockBuilder with an implementation that handles h2 types, skip for native
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow.testingConnector;
15+
16+
import com.facebook.plugin.arrow.ArrowColumnHandle;
17+
import com.facebook.plugin.arrow.ArrowTableHandle;
18+
import com.fasterxml.jackson.annotation.JsonCreator;
19+
import com.fasterxml.jackson.annotation.JsonProperty;
20+
21+
import java.util.Collections;
22+
import java.util.List;
23+
import java.util.UUID;
24+
25+
import static java.util.Objects.requireNonNull;
26+
public class TestingQueryArrowTableHandle
27+
extends ArrowTableHandle
28+
{
29+
private final String query;
30+
private final List<ArrowColumnHandle> columns;
31+
32+
@JsonCreator
33+
public TestingQueryArrowTableHandle(String query, List<ArrowColumnHandle> columns)
34+
{
35+
super("schema-" + UUID.randomUUID(), "table-" + UUID.randomUUID());
36+
this.columns = Collections.unmodifiableList(requireNonNull(columns));
37+
this.query = requireNonNull(query);
38+
}
39+
40+
@JsonProperty
41+
public String getQuery()
42+
{
43+
return query;
44+
}
45+
46+
@JsonProperty
47+
public List<ArrowColumnHandle> getColumns()
48+
{
49+
return columns;
50+
}
51+
}

0 commit comments

Comments
 (0)