diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java index 15c33950ce14b..630f4670f6cc2 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java @@ -86,14 +86,6 @@ public Field(Optional nodeLocation, Optional relati this.aliased = aliased; } - public static Field newUnqualified(Optional name, Type type) - { - requireNonNull(name, "name is null"); - requireNonNull(type, "type is null"); - - return new Field(Optional.empty(), Optional.empty(), name, type, false, Optional.empty(), Optional.empty(), false); - } - public Optional getNodeLocation() { return nodeLocation; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index ca6790b7f4012..ca6b746bcb628 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -1520,7 +1520,7 @@ private void verifyRequiredColumns(TableFunctionInvocation node, Map column < 0 || column >= inputScope.getRelationType().getAllFieldCount()) // hidden columns can be required as well as visible columns + .filter(column -> column < 0 || column >= inputScope.getRelationType().getVisibleFieldCount()) .findFirst() .ifPresent(column -> { throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "Invalid index: %s of required column from table argument %s", column, name); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index e2c79cb02b29c..b775fdfc47521 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -143,7 +143,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -class QueryPlanner +public class QueryPlanner { private final Analysis analysis; private final VariableAllocator variableAllocator; @@ -1355,6 +1355,11 @@ private static List toSymbolReferences(List new NodeLocation(location.getLine(), location.getColumn())), variable.getName()); + } + public static class PlanAndMappings { private final PlanBuilder subPlan; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 0bd4e47fc67b8..a2a514e7b20ae 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -346,6 +346,7 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node outputVariablesBuilder.build(), sources.stream().map(RelationPlan::getRoot).collect(toImmutableList()), inputRelationsProperties, + functionAnalysis.getCopartitioningLists(), new TableFunctionHandle(functionAnalysis.getConnectorId(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle())); return new RelationPlan(root, scope, outputVariables); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 91a92107f9f3c..9185087ef2d27 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -490,6 +490,7 @@ public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext node.getOutputVariables(), node.getSources(), node.getTableArgumentProperties(), + node.getCopartitioningLists(), node.getHandle()); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java index 22d4f18e42ff9..f87c1a1bba5c5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java @@ -61,6 +61,11 @@ public C get() return userContext; } + public SimplePlanRewriter getNodeRewriter() + { + return nodeRewriter; + } + /** * Invoke the rewrite logic recursively on children of the given node and swap it * out with an identical copy with the rewritten children diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java index 97892523498c0..44f304155fca5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java @@ -22,13 +22,17 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Immutable; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @Immutable @@ -37,9 +41,10 @@ public class TableFunctionNode { private final String name; private final Map arguments; - private final List outputVariables; + private final List properOutputs; private final List sources; private final List tableArgumentProperties; + private final List> copartitioningLists; private final TableFunctionHandle handle; @JsonCreator @@ -47,12 +52,13 @@ public TableFunctionNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("name") String name, @JsonProperty("arguments") Map arguments, - @JsonProperty("outputVariables") List outputVariables, + @JsonProperty("properOutputs") List properOutputs, @JsonProperty("sources") List sources, @JsonProperty("tableArgumentProperties") List tableArgumentProperties, + @JsonProperty("copartitioningLists") List> copartitioningLists, @JsonProperty("handle") TableFunctionHandle handle) { - this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, handle); + this(Optional.empty(), id, Optional.empty(), name, arguments, properOutputs, sources, tableArgumentProperties, copartitioningLists, handle); } public TableFunctionNode( @@ -61,17 +67,21 @@ public TableFunctionNode( Optional statsEquivalentPlanNode, String name, Map arguments, - List outputVariables, + List properOutputs, List sources, List tableArgumentProperties, + List> copartitioningLists, TableFunctionHandle handle) { super(sourceLocation, id, statsEquivalentPlanNode); this.name = requireNonNull(name, "name is null"); - this.arguments = requireNonNull(arguments, "arguments is null"); - this.outputVariables = requireNonNull(outputVariables, "outputVariables is null"); - this.sources = requireNonNull(sources, "sources is null"); - this.tableArgumentProperties = requireNonNull(tableArgumentProperties, "tableArgumentProperties is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.sources = ImmutableList.copyOf(sources); + this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); + this.copartitioningLists = requireNonNull(copartitioningLists, "copartitioningLists is null").stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); this.handle = requireNonNull(handle, "handle is null"); } @@ -87,10 +97,25 @@ public Map getArguments() return arguments; } - @JsonProperty + @Override public List getOutputVariables() { - return outputVariables; + ImmutableList.Builder variables = ImmutableList.builder(); + variables.addAll(properOutputs); + + tableArgumentProperties.stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getVariable) + .forEach(variables::add); + + return variables.build(); + } + + public List getProperOutputs() + { + return properOutputs; } @JsonProperty @@ -99,6 +124,12 @@ public List getTableArgumentProperties() return tableArgumentProperties; } + @JsonProperty + public List> getCopartitioningLists() + { + return copartitioningLists; + } + @JsonProperty public TableFunctionHandle getHandle() { @@ -122,35 +153,47 @@ public R accept(InternalPlanVisitor visitor, C context) public PlanNode replaceChildren(List newSources) { checkArgument(sources.size() == newSources.size(), "wrong number of new children"); - return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, handle); + return new TableFunctionNode(getId(), name, arguments, properOutputs, newSources, tableArgumentProperties, copartitioningLists, handle); } @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, handle); + return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, properOutputs, sources, tableArgumentProperties, copartitioningLists, handle); } public static class TableArgumentProperties { + private final String argumentName; private final boolean rowSemantics; private final boolean pruneWhenEmpty; - private final boolean passThroughColumns; + private final PassThroughSpecification passThroughSpecification; + private final List requiredColumns; private final Optional specification; @JsonCreator public TableArgumentProperties( + @JsonProperty("argumentName") String argumentName, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, - @JsonProperty("passThroughColumns") boolean passThroughColumns, + @JsonProperty("passThroughSpecification") PassThroughSpecification passThroughSpecification, + @JsonProperty("requiredColumns") List requiredColumns, @JsonProperty("specification") Optional specification) { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; - this.passThroughColumns = passThroughColumns; + this.passThroughSpecification = requireNonNull(passThroughSpecification, "passThroughSpecification is null"); + this.requiredColumns = ImmutableList.copyOf(requiredColumns); this.specification = requireNonNull(specification, "specification is null"); } + @JsonProperty + public String getArgumentName() + { + return argumentName; + } + @JsonProperty public boolean isRowSemantics() { @@ -164,15 +207,83 @@ public boolean isPruneWhenEmpty() } @JsonProperty - public boolean isPassThroughColumns() + public PassThroughSpecification getPassThroughSpecification() + { + return passThroughSpecification; + } + + @JsonProperty + public List getRequiredColumns() { - return passThroughColumns; + return requiredColumns; } @JsonProperty - public Optional specification() + public Optional getSpecification() { return specification; } } + + /** + * Specifies how columns from source tables are passed through to the output of a table function. + * This class manages both explicitly declared pass-through columns and partitioning columns + * that must be preserved in the output. + */ + public static class PassThroughSpecification + { + private final boolean declaredAsPassThrough; + private final List columns; + + @JsonCreator + public PassThroughSpecification( + @JsonProperty("declaredAsPassThrough") boolean declaredAsPassThrough, + @JsonProperty("columns") List columns) + { + this.declaredAsPassThrough = declaredAsPassThrough; + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + checkArgument( + declaredAsPassThrough || this.columns.stream().allMatch(PassThroughColumn::isPartitioningColumn), + "non-partitioning pass-through column for non-pass-through source of a table function"); + } + + @JsonProperty + public boolean isDeclaredAsPassThrough() + { + return declaredAsPassThrough; + } + + @JsonProperty + public List getColumns() + { + return columns; + } + } + + public static class PassThroughColumn + { + private final VariableReferenceExpression variable; + private final boolean isPartitioningColumn; + + @JsonCreator + public PassThroughColumn( + @JsonProperty("variable") VariableReferenceExpression variable, + @JsonProperty("partitioningColumn") boolean isPartitioningColumn) + { + this.variable = requireNonNull(variable, "variable is null"); + this.isPartitioningColumn = isPartitioningColumn; + } + + @JsonProperty + public VariableReferenceExpression getVariable() + { + return variable; + } + + @JsonProperty + public boolean isPartitioningColumn() + { + return isPartitioningColumn; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index f7fd052f02c3a..6bdf3c742af26 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -772,7 +772,8 @@ public void installPlugin(Plugin plugin) @Override public void createCatalog(String catalogName, String connectorName, Map properties) { - throw new UnsupportedOperationException(); + nodeManager.addCurrentNodeConnector(new ConnectorId(catalogName)); + connectorManager.createConnection(catalogName, connectorName, properties); } @Override diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java index 96373d826b50a..316d98787cf31 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java @@ -67,18 +67,17 @@ public class TestingTableFunctions public static class TestConnectorTableFunction extends AbstractConnectorTableFunction { - private static final String TEST_FUNCTION = "test_function"; - + private static final String FUNCTION_NAME = "test_function"; public TestConnectorTableFunction() { - super(SCHEMA_NAME, TEST_FUNCTION, ImmutableList.of(), ReturnTypeSpecification.GenericTable.GENERIC_TABLE); + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), ReturnTypeSpecification.GenericTable.GENERIC_TABLE); } @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { return TableFunctionAnalysis.builder() - .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, TEST_FUNCTION))) + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("c1", Optional.of(BOOLEAN))))) .build(); } @@ -87,11 +86,10 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class TestConnectorTableFunction2 extends AbstractConnectorTableFunction { - private static final String TEST_FUNCTION_2 = "test_function2"; - + private static final String FUNCTION_NAME = "test_function2"; public TestConnectorTableFunction2() { - super(SCHEMA_NAME, TEST_FUNCTION_2, ImmutableList.of(), ONLY_PASS_THROUGH); + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), ONLY_PASS_THROUGH); } @Override @@ -104,11 +102,10 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class NullArgumentsTableFunction extends AbstractConnectorTableFunction { - private static final String NULL_ARGUMENTS_FUNCTION = "null_arguments_function"; - + private static final String FUNCTION_NAME = "null_arguments_function"; public NullArgumentsTableFunction() { - super(SCHEMA_NAME, NULL_ARGUMENTS_FUNCTION, null, ONLY_PASS_THROUGH); + super(SCHEMA_NAME, FUNCTION_NAME, null, ONLY_PASS_THROUGH); } @Override @@ -121,12 +118,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class DuplicateArgumentsTableFunction extends AbstractConnectorTableFunction { - private static final String DUPLICATE_ARGUMENTS_FUNCTION = "duplicate_arguments_function"; + private static final String FUNCTION_NAME = "duplicate_arguments_function"; public DuplicateArgumentsTableFunction() { super( SCHEMA_NAME, - DUPLICATE_ARGUMENTS_FUNCTION, + FUNCTION_NAME, ImmutableList.of( ScalarArgumentSpecification.builder().name("a").type(INTEGER).build(), ScalarArgumentSpecification.builder().name("a").type(INTEGER).build()), @@ -143,12 +140,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class MultipleRSTableFunction extends AbstractConnectorTableFunction { - private static final String MULTIPLE_SOURCES_FUNCTION = "multiple_sources_function"; + private static final String FUNCTION_NAME = "multiple_sources_function"; public MultipleRSTableFunction() { super( SCHEMA_NAME, - MULTIPLE_SOURCES_FUNCTION, + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder().name("t").rowSemantics().build(), TableArgumentSpecification.builder().name("t2").rowSemantics().build()), ONLY_PASS_THROUGH); @@ -172,7 +169,6 @@ public static class SimpleTableFunction { private static final String FUNCTION_NAME = "simple_table_function"; private static final String TABLE_NAME = "simple_table"; - public SimpleTableFunction() { super( @@ -227,11 +223,12 @@ public TestTVFConnectorTableHandle getTableHandle() public static class TwoScalarArgumentsFunction extends AbstractConnectorTableFunction { + private static final String FUNCTION_NAME = "two_scalar_arguments_function"; public TwoScalarArgumentsFunction() { super( SCHEMA_NAME, - "two_arguments_function", + FUNCTION_NAME, ImmutableList.of( ScalarArgumentSpecification.builder() .name("TEXT") @@ -256,7 +253,6 @@ public static class TableArgumentFunction extends AbstractConnectorTableFunction { public static final String FUNCTION_NAME = "table_argument_function"; - public TableArgumentFunction() { super( @@ -284,11 +280,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class DescriptorArgumentFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "descriptor_argument_function"; public DescriptorArgumentFunction() { super( SCHEMA_NAME, - "descriptor_argument_function", + FUNCTION_NAME, ImmutableList.of( DescriptorArgumentSpecification.builder() .name("SCHEMA") @@ -327,11 +324,16 @@ public TestTVFConnectorTableHandle getTableHandle() public static class TestingTableFunctionHandle implements ConnectorTableFunctionHandle { + private final TestTVFConnectorTableHandle tableHandle; private final SchemaFunctionName schemaFunctionName; @JsonCreator public TestingTableFunctionHandle(@JsonProperty("schemaFunctionName") SchemaFunctionName schemaFunctionName) { + this.tableHandle = new TestTVFConnectorTableHandle( + new SchemaTableName(SCHEMA_NAME, TABLE_NAME), + Optional.of(ImmutableList.of(new TestTVFConnectorColumnHandle(COLUMN_NAME, BOOLEAN))), + TupleDomain.all()); this.schemaFunctionName = requireNonNull(schemaFunctionName, "schemaFunctionName is null"); } @@ -340,16 +342,22 @@ public SchemaFunctionName getSchemaFunctionName() { return schemaFunctionName; } + + public TestTVFConnectorTableHandle getTableHandle() + { + return tableHandle; + } } public static class TableArgumentRowSemanticsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "table_argument_row_semantics_function"; public TableArgumentRowSemanticsFunction() { super( SCHEMA_NAME, - "table_argument_row_semantics_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") @@ -372,17 +380,20 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class TwoTableArgumentsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "two_table_arguments_function"; public TwoTableArgumentsFunction() { super( SCHEMA_NAME, - "two_table_arguments_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT1") + .keepWhenEmpty() .build(), TableArgumentSpecification.builder() .name("INPUT2") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -402,11 +413,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class OnlyPassThroughFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "only_pass_through_function"; public OnlyPassThroughFunction() { super( SCHEMA_NAME, - "only_pass_through_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") @@ -425,11 +437,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class MonomorphicStaticReturnTypeFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "monomorphic_static_return_type_function"; public MonomorphicStaticReturnTypeFunction() { super( SCHEMA_NAME, - "monomorphic_static_return_type_function", + FUNCTION_NAME, ImmutableList.of(), new DescribedTable(Descriptor.descriptor( ImmutableList.of("a", "b"), @@ -448,11 +461,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class PolymorphicStaticReturnTypeFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "polymorphic_static_return_type_function"; public PolymorphicStaticReturnTypeFunction() { super( SCHEMA_NAME, - "polymorphic_static_return_type_function", + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .build()), @@ -471,14 +485,16 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class PassThroughFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "pass_through_function"; public PassThroughFunction() { super( SCHEMA_NAME, - "pass_through_function", + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .passThroughColumns() + .keepWhenEmpty() .build()), new DescribedTable(Descriptor.descriptor( ImmutableList.of("x"), @@ -495,14 +511,16 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class RequiredColumnsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "required_columns_function"; public RequiredColumnsFunction() { super( SCHEMA_NAME, - "required_columns_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -517,4 +535,51 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact .build(); } } + + public static class DifferentArgumentTypesFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "different_arguments_function"; + public DifferentArgumentTypesFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + DescriptorArgumentSpecification.builder() + .name("LAYOUT") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .rowSemantics() + .passThroughColumns() + .build(), + ScalarArgumentSpecification.builder() + .name("ID") + .type(BIGINT) + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .pruneWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .requiredColumns("INPUT_3", ImmutableList.of(0)) + .build(); + } + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index 2e3f4d2534806..7af314e25d42f 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -1964,59 +1964,59 @@ public void testTableFunctionNotFound() @Test public void testTableFunctionArguments() { - assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, "line 1:51: Too many arguments. Expected at most 2 arguments, got 3 arguments", "SELECT * FROM TABLE(system.two_arguments_function(1, 2, 3))"); + assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, "line 1:58: Too many arguments. Expected at most 2 arguments, got 3 arguments", "SELECT * FROM TABLE(system.two_scalar_arguments_function(1, 2, 3))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function('foo'))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function(text => 'foo'))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function('foo', 1))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', number => 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', number => 1))"); assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, - "line 1:51: All arguments must be passed by name or all must be passed positionally", - "SELECT * FROM TABLE(system.two_arguments_function('foo', number => 1))"); + "line 1:58: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', number => 1))"); assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, - "line 1:51: All arguments must be passed by name or all must be passed positionally", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', 1))"); + "line 1:58: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', 1))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:66: Duplicate argument name: TEXT", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', text => 'bar'))"); + "line 1:73: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', text => 'bar'))"); // argument names are resolved in the canonical form assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:66: Duplicate argument name: TEXT", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', TeXt => 'bar'))"); + "line 1:73: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', TeXt => 'bar'))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:66: Unexpected argument name: BAR", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', bar => 'bar'))"); + "line 1:73: Unexpected argument name: BAR", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', bar => 'bar'))"); assertFails(TABLE_FUNCTION_MISSING_ARGUMENT, - "line 1:51: Missing argument: TEXT", - "SELECT * FROM TABLE(system.two_arguments_function(number => 1))"); + "line 1:58: Missing argument: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(number => 1))"); } @Test public void testScalarArgument() { - analyze("SELECT * FROM TABLE(system.two_arguments_function('foo', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', 1))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:64: Invalid argument NUMBER. Expected expression, got descriptor", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => DESCRIPTOR(x integer, y boolean)))"); + "line 1:71: Invalid argument NUMBER. Expected expression, got descriptor", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => DESCRIPTOR(x integer, y boolean)))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:64: 'descriptor' function is not allowed as a table function argument", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => DESCRIPTOR(1 + 2)))"); + "line 1:71: 'descriptor' function is not allowed as a table function argument", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => DESCRIPTOR(1 + 2)))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:64: Invalid argument NUMBER. Expected expression, got table", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => TABLE(t1)))"); + "line 1:71: Invalid argument NUMBER. Expected expression, got table", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => TABLE(t1)))"); assertFails(EXPRESSION_NOT_CONSTANT, - "line 1:74: Constant expression cannot contain a subquery", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => (SELECT 1)))"); + "line 1:81: Constant expression cannot contain a subquery", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => (SELECT 1)))"); } @Test @@ -2228,10 +2228,10 @@ public void testNullArguments() // the default value for the argument schema is null analyze("SELECT * FROM TABLE(system.descriptor_argument_function())"); - analyze("SELECT * FROM TABLE(system.two_arguments_function(null, null))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(null, null))"); // the default value for the second argument is null - analyze("SELECT * FROM TABLE(system.two_arguments_function('a'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a'))"); } @Test @@ -2243,8 +2243,8 @@ public void testTableFunctionInvocationContext() "SELECT * FROM TABLE(system.only_pass_through_function(TABLE(t1))) f(x)"); // per SQL standard, relation alias is required for table function with GENERIC TABLE return type. We don't require it. - analyze("SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x)"); - analyze("SELECT * FROM TABLE(system.two_arguments_function('a', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1)) f(x)"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1))"); // per SQL standard, relation alias is required for table function with statically declared return type, only if the function is polymorphic. // We don't require aliasing polymorphic functions. @@ -2261,7 +2261,7 @@ public void testTableFunctionInvocationContext() // aliased + sampled assertFails(TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION, "line 1:15: Cannot apply sample to polymorphic table function invocation", - "SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x) TABLESAMPLE BERNOULLI (10)"); + "SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1)) f(x) TABLESAMPLE BERNOULLI (10)"); } @Test @@ -2279,19 +2279,19 @@ public void testTableFunctionAliasing() analyze("SELECT * FROM TABLE(system.table_argument_function(TABLE(t1) t2)) T1(x)"); // the original returned relation type is ("column" : BOOLEAN) - analyze("SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias"); + analyze("SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias"); - analyze("SELECT column_alias FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + analyze("SELECT column_alias FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); - analyze("SELECT table_alias.column_alias FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + analyze("SELECT table_alias.column_alias FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); assertFails(MISSING_ATTRIBUTE, "line 1:8: Column 'column' cannot be resolved", - "SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + "SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); assertFails(MISMATCHED_COLUMN_ALIASES, "line 1:20: Column alias list has 3 entries but table function has 1 proper columns", - "SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias(col1, col2, col3)"); + "SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(col1, col2, col3)"); // the original returned relation type is ("a" : BOOLEAN, "b" : INTEGER) analyze("SELECT column_alias_1, column_alias_2 FROM TABLE(system.monomorphic_static_return_type_function()) table_alias(column_alias_1, column_alias_2)"); @@ -2333,7 +2333,9 @@ public void testTableFunctionRequiredColumns() "Invalid index: 1 of required column from table argument INPUT", "SELECT * FROM TABLE(system.required_columns_function(input => TABLE(SELECT 1)))"); - // table s1.t5 has two columns. The second column is hidden. Table function can require a hidden column. - analyze("SELECT * FROM TABLE(system.required_columns_function(input => TABLE(s1.t5)))"); + // table s1.t5 has two columns. The second column is hidden. Table function cannot require a hidden column. + assertFails(TABLE_FUNCTION_IMPLEMENTATION_ERROR, + "Invalid index: 1 of required column from table argument INPUT", + "SELECT * FROM TABLE(system.required_columns_function(input => TABLE(s1.t5)))"); } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java index 9d28b9d4b4219..8b3840e633f40 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java @@ -1667,7 +1667,7 @@ public Node visitDescriptorArgument(SqlBaseParser.DescriptorArgumentContext cont @Override public Node visitDescriptorField(SqlBaseParser.DescriptorFieldContext context) { - return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), Optional.of(getType(context.type()))); + return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), Optional.ofNullable(context.type()).map(this::getType)); } /**