diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index 0d7c27769823b..4ad814040b594 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -32,6 +32,7 @@ import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.table.Argument; import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.AccessControlContext; import com.facebook.presto.spi.security.AllowAllAccessControl; @@ -174,6 +175,13 @@ public class Analysis private final Multiset columnMaskScopes = HashMultiset.create(); private final Map, Map> columnMasks = new LinkedHashMap<>(); + // for call distributed procedure + private Optional distributedProcedureType = Optional.empty(); + private Optional procedureName = Optional.empty(); + private Optional procedureArguments = Optional.empty(); + private Optional callTarget = Optional.empty(); + private Optional targetQuery = Optional.empty(); + // for create table private Optional createTableDestination = Optional.empty(); private Map createTableProperties = ImmutableMap.of(); @@ -666,6 +674,46 @@ public Optional getCreateTableDestination() return createTableDestination; } + public Optional getProcedureName() + { + return procedureName; + } + + public void setProcedureName(Optional procedureName) + { + this.procedureName = procedureName; + } + + public Optional getDistributedProcedureType() + { + return distributedProcedureType; + } + + public void setDistributedProcedureType(Optional distributedProcedureType) + { + this.distributedProcedureType = distributedProcedureType; + } + + public Optional getProcedureArguments() + { + return procedureArguments; + } + + public void setProcedureArguments(Optional procedureArguments) + { + this.procedureArguments = procedureArguments; + } + + public Optional getCallTarget() + { + return callTarget; + } + + public void setCallTarget(TableHandle callTarget) + { + this.callTarget = Optional.of(callTarget); + } + public Optional getAnalyzeTarget() { return analyzeTarget; @@ -1020,6 +1068,16 @@ public Optional getCurrentQuerySpecification() return currentQuerySpecification; } + public void setTargetQuery(QuerySpecification targetQuery) + { + this.targetQuery = Optional.of(targetQuery); + } + + public Optional getTargetQuery() + { + return this.targetQuery; + } + public Map> getInvokedFunctions() { Map> functionMap = new HashMap<>(); diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparer.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparer.java index c3c91c5a46bb5..5ad696f46edfc 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparer.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparer.java @@ -13,19 +13,25 @@ */ package com.facebook.presto.sql.analyzer; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.analyzer.PreparedQuery; import com.facebook.presto.common.resourceGroups.QueryType; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.PrestoWarning; +import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AnalyzerOptions; import com.facebook.presto.spi.analyzer.QueryPreparer; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.sql.analyzer.utils.StatementUtils; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.tree.Call; import com.facebook.presto.sql.tree.Execute; import com.facebook.presto.sql.tree.Explain; import com.facebook.presto.sql.tree.ExplainType; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.Statement; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -36,6 +42,7 @@ import java.util.Optional; import static com.facebook.presto.common.WarningHandlingLevel.AS_ERROR; +import static com.facebook.presto.common.resourceGroups.QueryType.CALL_DISTRIBUTED_PROCEDURE; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.WARNING_AS_ERROR; @@ -43,6 +50,7 @@ import static com.facebook.presto.sql.analyzer.ConstantExpressionVerifier.verifyExpressionIsConstant; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; import static com.facebook.presto.sql.analyzer.utils.AnalyzerUtil.createParsingOptions; +import static com.facebook.presto.sql.analyzer.utils.MetadataUtils.createQualifiedObjectName; import static com.facebook.presto.sql.analyzer.utils.ParameterExtractor.getParameterCount; import static com.facebook.presto.sql.tree.ExplainType.Type.VALIDATE; import static java.lang.String.format; @@ -56,11 +64,15 @@ public class BuiltInQueryPreparer implements QueryPreparer { private final SqlParser sqlParser; + private final ProcedureRegistry procedureRegistry; @Inject - public BuiltInQueryPreparer(SqlParser sqlParser) + public BuiltInQueryPreparer( + SqlParser sqlParser, + ProcedureRegistry procedureRegistry) { this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); } @Override @@ -87,6 +99,18 @@ public BuiltInPreparedQuery prepareQuery(AnalyzerOptions analyzerOptions, Statem statement = sqlParser.createStatement(query, createParsingOptions(analyzerOptions)); } + Optional distributedProcedureName = Optional.empty(); + if (statement instanceof Call) { + QualifiedName qualifiedName = ((Call) statement).getName(); + QualifiedObjectName qualifiedObjectName = createQualifiedObjectName(analyzerOptions.getSessionCatalogName(), analyzerOptions.getSessionSchemaName(), + statement, qualifiedName, (catalogName, objectName) -> objectName); + if (procedureRegistry.isDistributedProcedure( + new ConnectorId(qualifiedObjectName.getCatalogName()), + new SchemaTableName(qualifiedObjectName.getSchemaName(), qualifiedObjectName.getObjectName()))) { + distributedProcedureName = Optional.of(qualifiedObjectName); + } + } + if (statement instanceof Explain && ((Explain) statement).isAnalyze()) { Statement innerStatement = ((Explain) statement).getStatement(); Optional innerQueryType = StatementUtils.getQueryType(innerStatement.getClass()); @@ -103,7 +127,7 @@ public BuiltInPreparedQuery prepareQuery(AnalyzerOptions analyzerOptions, Statem if (analyzerOptions.isLogFormattedQueryEnabled()) { formattedQuery = Optional.of(getFormattedQuery(statement, parameters)); } - return new BuiltInPreparedQuery(wrappedStatement, statement, parameters, formattedQuery, prepareSql); + return new BuiltInPreparedQuery(wrappedStatement, statement, parameters, formattedQuery, prepareSql, distributedProcedureName); } private static String getFormattedQuery(Statement statement, List parameters) @@ -131,13 +155,19 @@ public static class BuiltInPreparedQuery private final Statement statement; private final Statement wrappedStatement; private final List parameters; + private final Optional distributedProcedureName; - public BuiltInPreparedQuery(Statement wrappedStatement, Statement statement, List parameters, Optional formattedQuery, Optional prepareSql) + public BuiltInPreparedQuery( + Statement wrappedStatement, + Statement statement, List parameters, + Optional formattedQuery, Optional prepareSql, + Optional distributedProcedureName) { super(formattedQuery, prepareSql); this.wrappedStatement = requireNonNull(wrappedStatement, "wrappedStatement is null"); this.statement = requireNonNull(statement, "statement is null"); this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + this.distributedProcedureName = requireNonNull(distributedProcedureName, "distributedProcedureName is null"); } public Statement getStatement() @@ -157,9 +187,17 @@ public List getParameters() public Optional getQueryType() { + if (getDistributedProcedureName().isPresent()) { + return Optional.of(CALL_DISTRIBUTED_PROCEDURE); + } return StatementUtils.getQueryType(statement.getClass()); } + public Optional getDistributedProcedureName() + { + return this.distributedProcedureName; + } + public boolean isTransactionControlStatement() { return StatementUtils.isTransactionControlStatement(getStatement()); diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java index 80eef0f465959..957764db853d9 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java @@ -90,6 +90,7 @@ public enum SemanticErrorCode SAMPLE_PERCENTAGE_OUT_OF_RANGE, + PROCEDURE_NOT_FOUND, INVALID_PROCEDURE_ARGUMENTS, INVALID_SESSION_PROPERTY, diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/MetadataUtils.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/MetadataUtils.java new file mode 100644 index 0000000000000..9d241ccd22992 --- /dev/null +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/MetadataUtils.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.analyzer.utils; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Optional; +import java.util.function.BiFunction; + +import static com.facebook.presto.spi.StandardErrorCode.SYNTAX_ERROR; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SCHEMA_NOT_SPECIFIED; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +public class MetadataUtils +{ + private MetadataUtils() + {} + + public static QualifiedObjectName createQualifiedObjectName(Optional sessionCatalogName, Optional sessionSchemaName, Node node, QualifiedName name, + BiFunction normalizer) + { + requireNonNull(sessionCatalogName, "sessionCatalogName is null"); + requireNonNull(sessionSchemaName, "sessionSchemaName is null"); + requireNonNull(name, "name is null"); + if (name.getParts().size() > 3) { + throw new PrestoException(SYNTAX_ERROR, format("Too many dots in table name: %s", name)); + } + + List parts = Lists.reverse(name.getOriginalParts()); + String objectName = parts.get(0).getValue(); + String schemaName = (parts.size() > 1) ? parts.get(1).getValue() : sessionSchemaName.orElseThrow(() -> + new SemanticException(SCHEMA_NOT_SPECIFIED, node, "Schema must be specified when session schema is not set")); + String catalogName = (parts.size() > 2) ? parts.get(2).getValue() : sessionCatalogName.orElseThrow(() -> + new SemanticException(CATALOG_NOT_SPECIFIED, node, "Catalog must be specified when session catalog is not set")); + + catalogName = catalogName.toLowerCase(ENGLISH); + schemaName = normalizer.apply(catalogName, schemaName); + objectName = normalizer.apply(catalogName, objectName); + return new QualifiedObjectName(catalogName, schemaName, objectName); + } +} diff --git a/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestBuiltInQueryPreparer.java b/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestBuiltInQueryPreparer.java index bdc99102d1909..319ca019a87bb 100644 --- a/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestBuiltInQueryPreparer.java +++ b/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestBuiltInQueryPreparer.java @@ -13,47 +13,110 @@ */ package com.facebook.presto.sql.analyzer; +import com.facebook.presto.common.resourceGroups.QueryType; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AnalyzerOptions; +import com.facebook.presto.spi.procedure.LocalProcedure; +import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; +import com.facebook.presto.spi.procedure.ProcedureRegistry; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; import com.facebook.presto.sql.analyzer.BuiltInQueryPreparer.BuiltInPreparedQuery; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.AllColumns; +import com.facebook.presto.sql.tree.Call; +import com.facebook.presto.sql.tree.CallArgument; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.StringLiteral; import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Optional; +import static com.facebook.presto.common.type.StandardTypes.VARCHAR; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.SCHEMA; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.TABLE_NAME; import static com.facebook.presto.sql.QueryUtil.selectList; import static com.facebook.presto.sql.QueryUtil.simpleQuery; import static com.facebook.presto.sql.QueryUtil.table; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; public class TestBuiltInQueryPreparer { private static final SqlParser SQL_PARSER = new SqlParser(); - private static final BuiltInQueryPreparer QUERY_PREPARER = new BuiltInQueryPreparer(SQL_PARSER); private static final Map emptyPreparedStatements = ImmutableMap.of(); private static final AnalyzerOptions testAnalyzerOptions = AnalyzerOptions.builder().build(); + private static ProcedureRegistry procedureRegistry; + private static BuiltInQueryPreparer queryPreparer; + + @BeforeClass + public void setup() + { + procedureRegistry = new TestProcedureRegistry(); + List arguments = new ArrayList<>(); + arguments.add(new Argument(SCHEMA, VARCHAR)); + arguments.add(new Argument(TABLE_NAME, VARCHAR)); + + List procedures = new ArrayList<>(); + procedures.add(new LocalProcedure("system", "fun", arguments)); + procedures.add(new TableDataRewriteDistributedProcedure("system", "distributed_fun", + arguments, + (session, transactionContext, procedureHandle, fragments) -> null, + (transactionContext, procedureHandle, fragments) -> {}, + TestProcedureRegistry.TestProcedureContext::new)); + procedureRegistry.addProcedures(new ConnectorId("test"), procedures); + queryPreparer = new BuiltInQueryPreparer(SQL_PARSER, procedureRegistry); + } @Test public void testSelectStatement() { - BuiltInPreparedQuery preparedQuery = QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "SELECT * FROM foo", emptyPreparedStatements, WarningCollector.NOOP); + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery(testAnalyzerOptions, "SELECT * FROM foo", emptyPreparedStatements, WarningCollector.NOOP); assertEquals(preparedQuery.getStatement(), simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("foo")))); } + @Test + public void testCallProcedureStatement() + { + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery(testAnalyzerOptions, "call test.system.fun('a', 'b')", emptyPreparedStatements, WarningCollector.NOOP); + List arguments = new ArrayList<>(); + arguments.add(new CallArgument(new StringLiteral("a"))); + arguments.add(new CallArgument(new StringLiteral("b"))); + assertEquals(preparedQuery.getStatement(), + new Call(QualifiedName.of("test", "system", "fun"), arguments)); + assertTrue(preparedQuery.getQueryType().isPresent()); + assertEquals(preparedQuery.getQueryType().get(), QueryType.DATA_DEFINITION); + } + + @Test + public void testCallDistributedProcedureStatement() + { + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery(testAnalyzerOptions, "call test.system.distributed_fun('a', 'b')", emptyPreparedStatements, WarningCollector.NOOP); + List arguments = new ArrayList<>(); + arguments.add(new CallArgument(new StringLiteral("a"))); + arguments.add(new CallArgument(new StringLiteral("b"))); + assertEquals(preparedQuery.getStatement(), + new Call(QualifiedName.of("test", "system", "distributed_fun"), arguments)); + assertTrue(preparedQuery.getQueryType().isPresent()); + assertEquals(preparedQuery.getQueryType().get(), QueryType.CALL_DISTRIBUTED_PROCEDURE); + } + @Test public void testExecuteStatement() { Map preparedStatements = ImmutableMap.of("my_query", "SELECT * FROM foo"); - BuiltInPreparedQuery preparedQuery = QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "EXECUTE my_query", preparedStatements, WarningCollector.NOOP); + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery(testAnalyzerOptions, "EXECUTE my_query", preparedStatements, WarningCollector.NOOP); assertEquals(preparedQuery.getStatement(), simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("foo")))); } @@ -62,7 +125,7 @@ public void testExecuteStatement() public void testExecuteStatementDoesNotExist() { try { - QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "execute my_query", emptyPreparedStatements, WarningCollector.NOOP); + queryPreparer.prepareQuery(testAnalyzerOptions, "execute my_query", emptyPreparedStatements, WarningCollector.NOOP); fail("expected exception"); } catch (PrestoException e) { @@ -75,7 +138,7 @@ public void testTooManyParameters() { try { Map preparedStatements = ImmutableMap.of("my_query", "SELECT * FROM foo where col1 = ?"); - QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "EXECUTE my_query USING 1,2", preparedStatements, WarningCollector.NOOP); + queryPreparer.prepareQuery(testAnalyzerOptions, "EXECUTE my_query USING 1,2", preparedStatements, WarningCollector.NOOP); fail("expected exception"); } catch (SemanticException e) { @@ -88,7 +151,7 @@ public void testTooFewParameters() { try { Map preparedStatements = ImmutableMap.of("my_query", "SELECT ? FROM foo where col1 = ?"); - QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "EXECUTE my_query USING 1", preparedStatements, WarningCollector.NOOP); + queryPreparer.prepareQuery(testAnalyzerOptions, "EXECUTE my_query USING 1", preparedStatements, WarningCollector.NOOP); fail("expected exception"); } catch (SemanticException e) { @@ -100,7 +163,7 @@ public void testTooFewParameters() public void testFormattedQuery() { AnalyzerOptions analyzerOptions = AnalyzerOptions.builder().setLogFormattedQueryEnabled(true).build(); - BuiltInPreparedQuery preparedQuery = QUERY_PREPARER.prepareQuery( + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery( analyzerOptions, "PREPARE test FROM SELECT * FROM foo where col1 = ?", emptyPreparedStatements, @@ -112,7 +175,7 @@ public void testFormattedQuery() " foo\n" + " WHERE (col1 = ?)\n")); - preparedQuery = QUERY_PREPARER.prepareQuery( + preparedQuery = queryPreparer.prepareQuery( analyzerOptions, "PREPARE test FROM SELECT * FROM foo", emptyPreparedStatements, diff --git a/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestProcedureRegistry.java b/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestProcedureRegistry.java new file mode 100644 index 0000000000000..1daa139c57845 --- /dev/null +++ b/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestProcedureRegistry.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.analyzer; + +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.connector.ConnectorProcedureContext; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; + +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.facebook.presto.spi.StandardErrorCode.PROCEDURE_NOT_FOUND; +import static java.util.Objects.requireNonNull; + +public class TestProcedureRegistry + implements ProcedureRegistry +{ + private final Map> connectorProcedures = new ConcurrentHashMap<>(); + + @Override + public void addProcedures(ConnectorId connectorId, Collection procedures) + { + requireNonNull(connectorId, "connectorId is null"); + requireNonNull(procedures, "procedures is null"); + + Map proceduresByName = procedures.stream().collect(Collectors.toMap( + procedure -> new SchemaTableName(procedure.getSchema(), procedure.getName()), + Function.identity())); + if (connectorProcedures.putIfAbsent(connectorId, proceduresByName) != null) { + throw new IllegalStateException("Procedures already registered for connector: " + connectorId); + } + } + + @Override + public void removeProcedures(ConnectorId connectorId) + { + connectorProcedures.remove(connectorId); + } + + @Override + public Procedure resolve(ConnectorId connectorId, SchemaTableName name) + { + Map procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + Procedure procedure = procedures.get(name); + if (procedure != null) { + return procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Procedure not registered: " + name); + } + + @Override + public DistributedProcedure resolveDistributed(ConnectorId connectorId, SchemaTableName name) + { + Map procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + Procedure procedure = procedures.get(name); + if (procedure != null && procedure instanceof DistributedProcedure) { + return (DistributedProcedure) procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Distributed procedure not registered: " + name); + } + + @Override + public boolean isDistributedProcedure(ConnectorId connectorId, SchemaTableName name) + { + Map procedures = connectorProcedures.get(connectorId); + return procedures != null && + procedures.containsKey(name) && + procedures.get(name) instanceof DistributedProcedure; + } + + public static class TestProcedureContext + implements ConnectorProcedureContext + {} +} diff --git a/presto-common/src/main/java/com/facebook/presto/common/QualifiedObjectName.java b/presto-common/src/main/java/com/facebook/presto/common/QualifiedObjectName.java index 49f2be4c16e78..a6a7207f94709 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/QualifiedObjectName.java +++ b/presto-common/src/main/java/com/facebook/presto/common/QualifiedObjectName.java @@ -17,6 +17,7 @@ import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; import com.google.errorprone.annotations.Immutable; @@ -57,8 +58,12 @@ public static QualifiedObjectName valueOf(String catalogName, String schemaName, return new QualifiedObjectName(catalogName, schemaName, objectName.toLowerCase(ENGLISH)); } + @JsonCreator @ThriftConstructor - public QualifiedObjectName(String catalogName, String schemaName, String objectName) + public QualifiedObjectName( + @JsonProperty("catalogName") String catalogName, + @JsonProperty("schemaName") String schemaName, + @JsonProperty("objectName") String objectName) { checkLowerCase(catalogName, "catalogName"); this.catalogName = catalogName; @@ -72,18 +77,21 @@ public CatalogSchemaName getCatalogSchemaName() } @ThriftField(1) + @JsonProperty("catalogName") public String getCatalogName() { return catalogName; } @ThriftField(2) + @JsonProperty("schemaName") public String getSchemaName() { return schemaName; } @ThriftField(3) + @JsonProperty("objectName") public String getObjectName() { return objectName; diff --git a/presto-common/src/main/java/com/facebook/presto/common/resourceGroups/QueryType.java b/presto-common/src/main/java/com/facebook/presto/common/resourceGroups/QueryType.java index f2900acbaaf45..8f08f12b4e925 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/resourceGroups/QueryType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/resourceGroups/QueryType.java @@ -28,7 +28,8 @@ public enum QueryType SELECT(7), CONTROL(8), UPDATE(9), - MERGE(10) + MERGE(10), + CALL_DISTRIBUTED_PROCEDURE(11) /**/; private final int value; diff --git a/presto-docs/src/main/sphinx/connector/iceberg.rst b/presto-docs/src/main/sphinx/connector/iceberg.rst index 4339f54b6c61a..aae695b021ce6 100644 --- a/presto-docs/src/main/sphinx/connector/iceberg.rst +++ b/presto-docs/src/main/sphinx/connector/iceberg.rst @@ -1233,6 +1233,47 @@ Examples: CALL iceberg.system.set_table_property('schema_name', 'table_name', 'commit.retry.num-retries', '10'); +Rewrite Data Files +^^^^^^^^^^^^^^^^^^ + +Iceberg tracks all data files under different partition specs in a table. More data files requires +more metadata to be stored in manifest files, and small data files can cause unnecessary amount metadata and +less efficient queries from file open costs. Also, data files under different partition specs can +prevent metadata level deletion or thorough predicate push down for Presto. + +Use `rewrite_data_files` to rewrite the data files of a specified table so that they are +merged into fewer but larger files under the newest partition spec. If the table is partitioned, the data +files compaction can act separately on the selected partitions to improve read performance by reducing +metadata overhead and runtime file open cost. + +The following arguments are available: + +===================== ========== =============== ======================================================================= +Argument Name required type Description +===================== ========== =============== ======================================================================= +``schema`` ✔️ string Schema of the table to update. + +``table_name`` ✔️ string Name of the table to update. + +``filter`` string Predicate as a string used for filtering the files. Currently + only rewrite of whole partitions is supported. Filter on partition + columns. The default value is `true`. + +``options`` map Options to be used for data files rewrite. (to be expanded) +===================== ========== =============== ======================================================================= + +Examples: + +* Rewrite all the data files in table `db.sample` to the newest partition spec and combine small files to larger ones:: + + CALL iceberg.system.rewrite_data_files('db', 'sample'); + CALL iceberg.system.rewrite_data_files(schema => 'db', table_name => 'sample'); + +* Rewrite the data files in partitions specified by a filter in table `db.sample` to the newest partition spec:: + + CALL iceberg.system.rewrite_data_files('db', 'sample', 'partition_key = 1'); + CALL iceberg.system.rewrite_data_files(schema => 'db', table_name => 'sample', filter => 'partition_key = 1'); + Presto C++ Support ^^^^^^^^^^^^^^^^^^ diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InvalidateMetastoreCacheProcedure.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InvalidateMetastoreCacheProcedure.java index e14ca6eb37c6a..a1506e843b4a9 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InvalidateMetastoreCacheProcedure.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InvalidateMetastoreCacheProcedure.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; @@ -65,7 +66,7 @@ public InvalidateMetastoreCacheProcedure(ExtendedHiveMetastore extendedHiveMetas @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "invalidate_metastore_cache", ImmutableList.of( diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/CreateEmptyPartitionProcedure.java b/presto-hive/src/main/java/com/facebook/presto/hive/CreateEmptyPartitionProcedure.java index cf3e0f943e93b..255a78008dbdd 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/CreateEmptyPartitionProcedure.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/CreateEmptyPartitionProcedure.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; @@ -83,7 +84,7 @@ public CreateEmptyPartitionProcedure( @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "create_empty_partition", ImmutableList.of( diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryListCacheInvalidationProcedure.java b/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryListCacheInvalidationProcedure.java index e6a931dcd2922..06b8c0978713b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryListCacheInvalidationProcedure.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryListCacheInvalidationProcedure.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; @@ -52,7 +53,7 @@ public DirectoryListCacheInvalidationProcedure(DirectoryLister directoryLister) @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "invalidate_directory_list_cache", ImmutableList.of( diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/SyncPartitionMetadataProcedure.java b/presto-hive/src/main/java/com/facebook/presto/hive/SyncPartitionMetadataProcedure.java index 7f248470483e3..e33247c7fc5c6 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/SyncPartitionMetadataProcedure.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/SyncPartitionMetadataProcedure.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableNotFoundException; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; @@ -95,7 +96,7 @@ public SyncPartitionMetadataProcedure( @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "sync_partition_metadata", ImmutableList.of( diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/CallDistributedProcedureSplitSource.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/CallDistributedProcedureSplitSource.java new file mode 100644 index 0000000000000..e8eeda5c97477 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/CallDistributedProcedureSplitSource.java @@ -0,0 +1,137 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.iceberg.delete.DeleteFile; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.SplitWeight; +import com.facebook.presto.spi.connector.ConnectorPartitionHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.io.Closer; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +import static com.facebook.presto.hive.HiveCommonSessionProperties.getAffinitySchedulingFileSectionSize; +import static com.facebook.presto.hive.HiveCommonSessionProperties.getNodeSelectionStrategy; +import static com.facebook.presto.iceberg.FileFormat.fromIcebergFileFormat; +import static com.facebook.presto.iceberg.IcebergUtil.getDataSequenceNumber; +import static com.facebook.presto.iceberg.IcebergUtil.getPartitionKeys; +import static com.facebook.presto.iceberg.IcebergUtil.partitionDataFromStructLike; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterators.limit; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.completedFuture; + +public class CallDistributedProcedureSplitSource + implements ConnectorSplitSource +{ + private CloseableIterator fileScanTaskIterator; + private Optional> fileScanTaskConsumer; + + private final TableScan tableScan; + private final Closer closer = Closer.create(); + private final double minimumAssignedSplitWeight; + private final ConnectorSession session; + + public CallDistributedProcedureSplitSource( + ConnectorSession session, + TableScan tableScan, + CloseableIterable fileScanTaskIterable, + Optional> fileScanTaskConsumer, + double minimumAssignedSplitWeight) + { + this.session = requireNonNull(session, "session is null"); + this.tableScan = requireNonNull(tableScan, "tableScan is null"); + this.fileScanTaskIterator = fileScanTaskIterable.iterator(); + this.fileScanTaskConsumer = requireNonNull(fileScanTaskConsumer, "fileScanTaskConsumer is null"); + this.minimumAssignedSplitWeight = minimumAssignedSplitWeight; + closer.register(fileScanTaskIterator); + } + + @Override + public CompletableFuture getNextBatch(ConnectorPartitionHandle partitionHandle, int maxSize) + { + // TODO: move this to a background thread + List splits = new ArrayList<>(); + Iterator iterator = limit(fileScanTaskIterator, maxSize); + while (iterator.hasNext()) { + FileScanTask task = iterator.next(); + fileScanTaskConsumer.ifPresent(consumer -> consumer.accept(task)); + splits.add(toIcebergSplit(task)); + } + return completedFuture(new ConnectorSplitBatch(splits, isFinished())); + } + + @Override + public boolean isFinished() + { + return !fileScanTaskIterator.hasNext(); + } + + @Override + public void close() + { + try { + closer.close(); + // TODO: remove this after org.apache.iceberg.io.CloseableIterator'withClose + // correct release resources holds by iterator. + fileScanTaskIterator = CloseableIterator.empty(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private ConnectorSplit toIcebergSplit(FileScanTask task) + { + PartitionSpec spec = task.spec(); + Optional partitionData = partitionDataFromStructLike(spec, task.file().partition()); + + // TODO: We should leverage residual expression and convert that to TupleDomain. + // The predicate here is used by readers for predicate push down at reader level, + // so when we do not use residual expression, we are just wasting CPU cycles + // on reader side evaluating a condition that we know will always be true. + + return new IcebergSplit( + task.file().path().toString(), + task.start(), + task.length(), + fromIcebergFileFormat(task.file().format()), + ImmutableList.of(), + getPartitionKeys(task), + PartitionSpecParser.toJson(spec), + partitionData.map(PartitionData::toJson), + getNodeSelectionStrategy(session), + SplitWeight.fromProportion(Math.min(Math.max((double) task.length() / tableScan.targetSplitSize(), minimumAssignedSplitWeight), 1.0)), + task.deletes().stream().map(DeleteFile::fromIceberg).collect(toImmutableList()), + Optional.empty(), + getDataSequenceNumber(task.file()), + getAffinitySchedulingFileSectionSize(session).toBytes()); + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java index 1a5db48f80051..1394cc08dcff8 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java @@ -15,6 +15,7 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.predicate.TupleDomain; @@ -35,10 +36,13 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorNewTableLayout; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableLayout; import com.facebook.presto.spi.ConnectorTableLayoutHandle; @@ -59,6 +63,9 @@ import com.facebook.presto.spi.connector.ConnectorTableVersion.VersionType; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.RowExpressionService; import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; @@ -219,10 +226,12 @@ public abstract class IcebergAbstractMetadata protected static final String INFORMATION_SCHEMA = "information_schema"; protected final TypeManager typeManager; + protected final ProcedureRegistry procedureRegistry; protected final JsonCodec commitTaskCodec; protected final NodeVersion nodeVersion; protected final RowExpressionService rowExpressionService; protected final FilterStatsCalculatorService filterStatsCalculatorService; + protected Optional procedureContext = Optional.empty(); protected Transaction transaction; protected final StatisticsFileCache statisticsFileCache; protected final IcebergTableProperties tableProperties; @@ -232,6 +241,7 @@ public abstract class IcebergAbstractMetadata public IcebergAbstractMetadata( TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, @@ -241,6 +251,7 @@ public IcebergAbstractMetadata( IcebergTableProperties tableProperties) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); @@ -267,6 +278,11 @@ protected final Table getIcebergTable(ConnectorSession session, SchemaTableName public abstract void unregisterTable(ConnectorSession clientSession, SchemaTableName schemaTableName); + public Optional getSplitSourceInCurrentCallProcedureTransaction() + { + return procedureContext.flatMap(IcebergProcedureContext::getConnectorSplitSource); + } + /** * This class implements the default implementation for getTableLayoutForConstraint which will be used in the case of a Java Worker */ @@ -1041,6 +1057,48 @@ public void truncateTable(ConnectorSession session, ConnectorTableHandle tableHa removeScanFiles(icebergTable, TupleDomain.all()); } + @Override + public ConnectorDistributedProcedureHandle beginCallDistributedProcedure( + ConnectorSession session, + QualifiedObjectName procedureName, + ConnectorTableLayoutHandle tableLayoutHandle, + Object[] arguments) + { + IcebergTableHandle handle = ((IcebergTableLayoutHandle) tableLayoutHandle).getTable(); + Table icebergTable = getIcebergTable(session, handle.getSchemaTableName()); + + if (handle.isSnapshotSpecified()) { + throw new PrestoException(NOT_SUPPORTED, "This connector do not allow table execute at specified snapshot"); + } + + transaction = icebergTable.newTransaction(); + Procedure procedure = procedureRegistry.resolve( + new ConnectorId(procedureName.getCatalogName()), + new SchemaTableName( + procedureName.getSchemaName(), + procedureName.getObjectName())); + verify(procedure instanceof DistributedProcedure, "procedure must be DistributedProcedure"); + procedureContext = Optional.of((IcebergProcedureContext) ((DistributedProcedure) procedure).createContext()); + procedureContext.get().setTable(icebergTable); + procedureContext.get().setTransaction(transaction); + return ((DistributedProcedure) procedure).begin(session, procedureContext.get(), tableLayoutHandle, arguments); + } + + @Override + public void finishCallDistributedProcedure(ConnectorSession session, ConnectorDistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments) + { + Procedure procedure = procedureRegistry.resolve( + new ConnectorId(procedureName.getCatalogName()), + new SchemaTableName( + procedureName.getSchemaName(), + procedureName.getObjectName())); + verify(procedure instanceof DistributedProcedure, "procedure must be DistributedProcedure"); + verify(procedureContext.isPresent(), "procedure context must be present"); + ((DistributedProcedure) procedure).finish(procedureContext.get(), procedureHandle, fragments); + transaction.commitTransaction(); + procedureContext.get().destroy(); + } + @Override public ConnectorDeleteTableHandle beginDelete(ConnectorSession session, ConnectorTableHandle tableHandle) { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java index d0e3049e0f5ad..54f428b709285 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java @@ -185,6 +185,7 @@ protected void setup(Binder binder) procedures.addBinding().toProvider(SetTablePropertyProcedure.class).in(Scopes.SINGLETON); procedures.addBinding().toProvider(StatisticsFileCacheInvalidationProcedure.class).in(Scopes.SINGLETON); procedures.addBinding().toProvider(ManifestFileCacheInvalidationProcedure.class).in(Scopes.SINGLETON); + procedures.addBinding().toProvider(RewriteDataFilesProcedure.class).in(Scopes.SINGLETON); // for orc binder.bind(EncryptionLibrary.class).annotatedWith(HiveDwrfEncryptionProvider.ForCryptoService.class).to(UnsupportedEncryptionLibrary.class).in(Scopes.SINGLETON); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergDistributedProcedureHandle.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergDistributedProcedureHandle.java new file mode 100644 index 0000000000000..0ae38a46e8946 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergDistributedProcedureHandle.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.hive.HiveCompressionCodec; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Map; + +public class IcebergDistributedProcedureHandle + extends IcebergWritableTableHandle + implements ConnectorDistributedProcedureHandle +{ + @JsonCreator + public IcebergDistributedProcedureHandle( + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") IcebergTableName tableName, + @JsonProperty("schema") PrestoIcebergSchema schema, + @JsonProperty("partitionSpec") PrestoIcebergPartitionSpec partitionSpec, + @JsonProperty("inputColumns") List inputColumns, + @JsonProperty("outputPath") String outputPath, + @JsonProperty("fileFormat") FileFormat fileFormat, + @JsonProperty("compressionCodec") HiveCompressionCodec compressionCodec, + @JsonProperty("storageProperties") Map storageProperties) + { + super( + schemaName, + tableName, + schema, + partitionSpec, + inputColumns, + outputPath, + fileFormat, + compressionCodec, + storageProperties, + ImmutableList.of()); + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java index 199939c6b7985..92d3d0e9fdeec 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java @@ -16,6 +16,7 @@ import com.facebook.presto.hive.HiveTransactionHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; @@ -69,6 +70,12 @@ public Class getDeleteTableHandleClass() return IcebergTableHandle.class; } + @Override + public Class getDistributedProcedureHandleClass() + { + return IcebergDistributedProcedureHandle.class; + } + @Override public Class getTransactionHandleClass() { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java index 3737247d9c595..cc471631de512 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java @@ -54,6 +54,7 @@ import com.facebook.presto.spi.ViewNotFoundException; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; @@ -175,6 +176,7 @@ public IcebergHiveMetadata( ExtendedHiveMetastore metastore, HdfsEnvironment hdfsEnvironment, TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, @@ -186,7 +188,7 @@ public IcebergHiveMetadata( IcebergTableProperties tableProperties, ConnectorSystemConfig connectorSystemConfig) { - super(typeManager, functionResolution, rowExpressionService, commitTaskCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); + super(typeManager, procedureRegistry, functionResolution, rowExpressionService, commitTaskCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); this.catalogName = requireNonNull(catalogName, "catalogName is null"); this.metastore = requireNonNull(metastore, "metastore is null"); this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java index df801b6da3c88..7c6fd7ed13cce 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import jakarta.inject.Inject; @@ -35,6 +36,7 @@ public class IcebergHiveMetadataFactory final ExtendedHiveMetastore metastore; final HdfsEnvironment hdfsEnvironment; final TypeManager typeManager; + final ProcedureRegistry procedureRegistry; final JsonCodec commitTaskCodec; final StandardFunctionResolution functionResolution; final RowExpressionService rowExpressionService; @@ -52,6 +54,7 @@ public IcebergHiveMetadataFactory( ExtendedHiveMetastore metastore, HdfsEnvironment hdfsEnvironment, TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, @@ -67,6 +70,7 @@ public IcebergHiveMetadataFactory( this.metastore = requireNonNull(metastore, "metastore is null"); this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); @@ -86,6 +90,7 @@ public ConnectorMetadata create() metastore, hdfsEnvironment, typeManager, + procedureRegistry, functionResolution, rowExpressionService, commitTaskCodec, diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java index 766f0bf389528..88b0ea6cddb44 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java @@ -30,6 +30,7 @@ import com.facebook.presto.spi.SchemaTablePrefix; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -104,6 +105,7 @@ public class IcebergNativeMetadata public IcebergNativeMetadata( IcebergNativeCatalogFactory catalogFactory, TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, @@ -113,7 +115,7 @@ public IcebergNativeMetadata( StatisticsFileCache statisticsFileCache, IcebergTableProperties tableProperties) { - super(typeManager, functionResolution, rowExpressionService, commitTaskCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); + super(typeManager, procedureRegistry, functionResolution, rowExpressionService, commitTaskCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); this.catalogFactory = requireNonNull(catalogFactory, "catalogFactory is null"); this.catalogType = requireNonNull(catalogType, "catalogType is null"); this.warehouseDataDir = Optional.ofNullable(catalogFactory.getCatalogWarehouseDataDir()); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java index 59b203c623303..1ccf0daab1388 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import jakarta.inject.Inject; @@ -29,6 +30,7 @@ public class IcebergNativeMetadataFactory implements IcebergMetadataFactory { final TypeManager typeManager; + final ProcedureRegistry procedureRegistry; final JsonCodec commitTaskCodec; final IcebergNativeCatalogFactory catalogFactory; final CatalogType catalogType; @@ -44,6 +46,7 @@ public IcebergNativeMetadataFactory( IcebergConfig config, IcebergNativeCatalogFactory catalogFactory, TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, @@ -54,6 +57,7 @@ public IcebergNativeMetadataFactory( { this.catalogFactory = requireNonNull(catalogFactory, "catalogFactory is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); @@ -67,6 +71,6 @@ public IcebergNativeMetadataFactory( public ConnectorMetadata create() { - return new IcebergNativeMetadata(catalogFactory, typeManager, functionResolution, rowExpressionService, commitTaskCodec, catalogType, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); + return new IcebergNativeMetadata(catalogFactory, typeManager, procedureRegistry, functionResolution, rowExpressionService, commitTaskCodec, catalogType, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java index e8e8db1163aed..e14d0178b153d 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java @@ -16,6 +16,7 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.presto.hive.HdfsContext; import com.facebook.presto.hive.HdfsEnvironment; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorPageSink; @@ -79,6 +80,12 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa return createPageSink(session, (IcebergWritableTableHandle) insertTableHandle); } + @Override + public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorDistributedProcedureHandle procedureHandle, PageSinkContext pageSinkContext) + { + return createPageSink(session, (IcebergWritableTableHandle) procedureHandle); + } + private ConnectorPageSink createPageSink(ConnectorSession session, IcebergWritableTableHandle tableHandle) { HdfsContext hdfsContext = new HdfsContext(session, tableHandle.getSchemaName(), tableHandle.getTableName().getTableName()); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergProcedureContext.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergProcedureContext.java new file mode 100644 index 0000000000000..a0ce2e325959d --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergProcedureContext.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.connector.ConnectorProcedureContext; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.Transaction; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +public class IcebergProcedureContext + implements ConnectorProcedureContext +{ + final Set scannedDataFiles = new HashSet<>(); + final Set fullyAppliedDeleteFiles = new HashSet<>(); + final Map relevantData = new HashMap<>(); + Optional table = Optional.empty(); + Transaction transaction; + Optional connectorSplitSource = Optional.empty(); + + public void setTable(Table table) + { + this.table = Optional.of(table); + } + + public void setTransaction(Transaction transaction) + { + this.transaction = transaction; + } + + public Optional
getTable() + { + return table; + } + + public Transaction getTransaction() + { + return transaction; + } + + public void setConnectorSplitSource(ConnectorSplitSource connectorSplitSource) + { + requireNonNull(connectorSplitSource, "connectorSplitSource is null"); + this.connectorSplitSource = Optional.of(connectorSplitSource); + } + + public Optional getConnectorSplitSource() + { + return this.connectorSplitSource; + } + + public Set getScannedDataFiles() + { + return scannedDataFiles; + } + + public Set getFullyAppliedDeleteFiles() + { + return fullyAppliedDeleteFiles; + } + + public Map getRelevantData() + { + return relevantData; + } + + public void destroy() + { + this.relevantData.clear(); + this.scannedDataFiles.clear(); + this.fullyAppliedDeleteFiles.clear(); + this.connectorSplitSource.ifPresent(ConnectorSplitSource::close); + this.connectorSplitSource = null; + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java index 0ad3345b7ae9d..69141aa4c5df8 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; @@ -35,6 +36,7 @@ import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; @@ -82,6 +84,15 @@ public ConnectorSplitSource getSplits( TupleDomain predicate = getNonMetadataColumnConstraints(layoutHandle .getValidPredicate()); + ConnectorMetadata connectorMetadata = transactionManager.get(transaction); + if (connectorMetadata != null) { + IcebergAbstractMetadata icebergMetadata = (IcebergAbstractMetadata) connectorMetadata; + Optional connectorSplitSource = icebergMetadata.getSplitSourceInCurrentCallProcedureTransaction(); + if (connectorSplitSource.isPresent()) { + return connectorSplitSource.get(); + } + } + Table icebergTable = getIcebergTable(transactionManager.get(transaction), session, table.getSchemaTableName()); if (table.getIcebergTableName().getTableType() == CHANGELOG) { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java index 4f7356be14567..b4e7f2304d63d 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java @@ -49,6 +49,7 @@ import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableSet; @@ -94,6 +95,7 @@ public static Connector createConnector( binder.bind(NodeVersion.class).toInstance(new NodeVersion(context.getNodeManager().getCurrentNode().getVersion())); binder.bind(NodeManager.class).toInstance(context.getNodeManager()); binder.bind(TypeManager.class).toInstance(context.getTypeManager()); + binder.bind(ProcedureRegistry.class).toInstance(context.getProcedureRegistry()); binder.bind(PageIndexerFactory.class).toInstance(context.getPageIndexerFactory()); binder.bind(PageSorter.class).toInstance(context.getPageSorter()); binder.bind(StandardFunctionResolution.class).toInstance(context.getStandardFunctionResolution()); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/RewriteDataFilesProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/RewriteDataFilesProcedure.java new file mode 100644 index 0000000000000..b65bf64a7c3ba --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/RewriteDataFilesProcedure.java @@ -0,0 +1,206 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slice; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.RewriteFiles; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.util.TableScanUtil; + +import javax.inject.Inject; +import javax.inject.Provider; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.Consumer; + +import static com.facebook.presto.common.type.StandardTypes.VARCHAR; +import static com.facebook.presto.iceberg.ExpressionConverter.toIcebergExpression; +import static com.facebook.presto.iceberg.IcebergSessionProperties.getCompressionCodec; +import static com.facebook.presto.iceberg.IcebergSessionProperties.getMinimumAssignedSplitWeight; +import static com.facebook.presto.iceberg.IcebergUtil.getColumns; +import static com.facebook.presto.iceberg.IcebergUtil.getFileFormat; +import static com.facebook.presto.iceberg.PartitionSpecConverter.toPrestoPartitionSpec; +import static com.facebook.presto.iceberg.SchemaConverter.toPrestoSchema; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.SCHEMA; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.TABLE_NAME; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class RewriteDataFilesProcedure + implements Provider +{ + TypeManager typeManager; + JsonCodec commitTaskCodec; + + @Inject + public RewriteDataFilesProcedure( + TypeManager typeManager, + JsonCodec commitTaskCodec) + { + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); + } + + @Override + public DistributedProcedure get() + { + return new TableDataRewriteDistributedProcedure( + "system", + "rewrite_data_files", + ImmutableList.of( + new Argument(SCHEMA, VARCHAR), + new Argument(TABLE_NAME, VARCHAR), + new Argument("filter", VARCHAR, false, "TRUE"), + new Argument("options", "map(varchar, varchar)", false, null)), + (session, procedureContext, tableLayoutHandle, arguments) -> beginCallDistributedProcedure(session, (IcebergProcedureContext) procedureContext, (IcebergTableLayoutHandle) tableLayoutHandle, arguments), + ((procedureContext, tableHandle, fragments) -> finishCallDistributedProcedure((IcebergProcedureContext) procedureContext, tableHandle, fragments)), + IcebergProcedureContext::new); + } + + private ConnectorDistributedProcedureHandle beginCallDistributedProcedure(ConnectorSession session, IcebergProcedureContext procedureContext, IcebergTableLayoutHandle layoutHandle, Object[] arguments) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + Table icebergTable = procedureContext.getTable().orElseThrow(() -> new VerifyException("No partition data for partitioned table")); + IcebergTableHandle tableHandle = layoutHandle.getTable(); + + ConnectorSplitSource splitSource; + if (!tableHandle.getIcebergTableName().getSnapshotId().isPresent()) { + splitSource = new FixedSplitSource(ImmutableList.of()); + } + else { + TupleDomain predicate = layoutHandle.getValidPredicate(); + TableScan tableScan = icebergTable.newScan() + .filter(toIcebergExpression(predicate)) + .useSnapshot(tableHandle.getIcebergTableName().getSnapshotId().get()); + + Consumer fileScanTaskConsumer = (task) -> { + procedureContext.getScannedDataFiles().add(task.file()); + if (!task.deletes().isEmpty()) { + task.deletes().forEach(deleteFile -> { + if (deleteFile.content() == FileContent.EQUALITY_DELETES && + !icebergTable.specs().get(deleteFile.specId()).isPartitioned() && + !predicate.isAll()) { + // Equality files with an unpartitioned spec are applied as global deletes + // So they should not be cleaned up unless the whole table is optimized + return; + } + procedureContext.getFullyAppliedDeleteFiles().add(deleteFile); + }); + } + }; + + splitSource = new CallDistributedProcedureSplitSource( + session, + tableScan, + TableScanUtil.splitFiles(tableScan.planFiles(), tableScan.targetSplitSize()), + Optional.of(fileScanTaskConsumer), + getMinimumAssignedSplitWeight(session)); + } + procedureContext.setConnectorSplitSource(splitSource); + + return new IcebergDistributedProcedureHandle( + tableHandle.getSchemaName(), + tableHandle.getIcebergTableName(), + toPrestoSchema(icebergTable.schema(), typeManager), + toPrestoPartitionSpec(icebergTable.spec(), typeManager), + getColumns(icebergTable.schema(), icebergTable.spec(), typeManager), + icebergTable.location(), + getFileFormat(icebergTable), + getCompressionCodec(session), + icebergTable.properties()); + } + } + + private void finishCallDistributedProcedure(IcebergProcedureContext procedureContext, ConnectorDistributedProcedureHandle procedureHandle, Collection fragments) + { + if (fragments.isEmpty() && + procedureContext.getScannedDataFiles().isEmpty() && + procedureContext.getFullyAppliedDeleteFiles().isEmpty()) { + return; + } + + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + IcebergDistributedProcedureHandle handle = (IcebergDistributedProcedureHandle) procedureHandle; + Table icebergTable = procedureContext.getTransaction().table(); + + List commitTasks = fragments.stream() + .map(slice -> commitTaskCodec.fromJson(slice.getBytes())) + .collect(toImmutableList()); + + org.apache.iceberg.types.Type[] partitionColumnTypes = icebergTable.spec().fields().stream() + .map(field -> field.transform().getResultType( + icebergTable.schema().findType(field.sourceId()))) + .toArray(Type[]::new); + + Set newFiles = new HashSet<>(); + for (CommitTaskData task : commitTasks) { + DataFiles.Builder builder = DataFiles.builder(icebergTable.spec()) + .withPath(task.getPath()) + .withFileSizeInBytes(task.getFileSizeInBytes()) + .withFormat(handle.getFileFormat().name()) + .withMetrics(task.getMetrics().metrics()); + + if (!icebergTable.spec().fields().isEmpty()) { + String partitionDataJson = task.getPartitionDataJson() + .orElseThrow(() -> new VerifyException("No partition data for partitioned table")); + builder.withPartition(PartitionData.fromJson(partitionDataJson, partitionColumnTypes)); + } + newFiles.add(builder.build()); + } + + RewriteFiles rewriteFiles = procedureContext.getTransaction().newRewrite(); + Set scannedDataFiles = procedureContext.getScannedDataFiles(); + Set fullyAppliedDeleteFiles = procedureContext.getFullyAppliedDeleteFiles(); + rewriteFiles.rewriteFiles(scannedDataFiles, fullyAppliedDeleteFiles, newFiles, ImmutableSet.of()); + + // Table.snapshot method returns null if there is no matching snapshot + Snapshot snapshot = requireNonNull( + handle.getTableName() + .getSnapshotId() + .map(icebergTable::snapshot) + .orElse(null), + "snapshot is null"); + if (icebergTable.currentSnapshot() != null) { + rewriteFiles.validateFromSnapshot(snapshot.snapshotId()); + } + rewriteFiles.commit(); + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ExpireSnapshotsProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ExpireSnapshotsProcedure.java index f91bbac8f2c2c..9c5e9d5460367 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ExpireSnapshotsProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ExpireSnapshotsProcedure.java @@ -20,7 +20,9 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; import org.apache.iceberg.ExpireSnapshots; @@ -61,15 +63,15 @@ public ExpireSnapshotsProcedure(IcebergMetadataFactory metadataFactory) @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "expire_snapshots", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("older_than", TIMESTAMP, false, null), - new Procedure.Argument("retain_last", INTEGER, false, null), - new Procedure.Argument("snapshot_ids", "array(bigint)", false, null)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("older_than", TIMESTAMP, false, null), + new Argument("retain_last", INTEGER, false, null), + new Argument("snapshot_ids", "array(bigint)", false, null)), EXPIRE_SNAPSHOTS.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java index 3cadc073ef1d8..a80a90e5d05c4 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java @@ -17,7 +17,9 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; import org.apache.iceberg.Table; @@ -54,14 +56,14 @@ public FastForwardBranchProcedure(IcebergMetadataFactory metadataFactory) @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "fast_forward", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("branch", VARCHAR), - new Procedure.Argument("to", VARCHAR)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("branch", VARCHAR), + new Argument("to", VARCHAR)), FAST_FORWARD.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ManifestFileCacheInvalidationProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ManifestFileCacheInvalidationProcedure.java index 1f3eb2f708eb6..15154e4e9e07f 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ManifestFileCacheInvalidationProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ManifestFileCacheInvalidationProcedure.java @@ -15,6 +15,7 @@ import com.facebook.presto.iceberg.ManifestFileCache; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; @@ -44,7 +45,7 @@ public ManifestFileCacheInvalidationProcedure(ManifestFileCache manifestFileCach @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "invalidate_manifest_file_cache", ImmutableList.of(), diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RegisterTableProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RegisterTableProcedure.java index e00216f5a9867..129ddd81b3256 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RegisterTableProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RegisterTableProcedure.java @@ -23,7 +23,9 @@ import com.facebook.presto.spi.SchemaNotFoundException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; import org.apache.hadoop.fs.FileStatus; @@ -79,14 +81,14 @@ public RegisterTableProcedure( @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "register_table", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("metadata_location", VARCHAR), - new Procedure.Argument("metadata_file", VARCHAR, false, null)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("metadata_location", VARCHAR), + new Argument("metadata_file", VARCHAR, false, null)), REGISTER_TABLE.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RemoveOrphanFiles.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RemoveOrphanFiles.java index 2d88d01c14501..2ce3df04f0f49 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RemoveOrphanFiles.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RemoveOrphanFiles.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; @@ -91,7 +92,7 @@ public RemoveOrphanFiles(IcebergMetadataFactory metadataFactory, @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "remove_orphan_files", ImmutableList.of( diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java index f1b622a6f5977..3e6c1594c6784 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java @@ -17,7 +17,9 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; @@ -53,13 +55,13 @@ public RollbackToSnapshotProcedure(IcebergMetadataFactory metadataFactory) @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "rollback_to_snapshot", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("snapshot_id", BIGINT)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("snapshot_id", BIGINT)), ROLLBACK_TO_SNAPSHOT.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToTimestampProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToTimestampProcedure.java index 77ba3a8144a5e..c583cb91c10b3 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToTimestampProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToTimestampProcedure.java @@ -19,7 +19,9 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; @@ -55,13 +57,13 @@ public RollbackToTimestampProcedure(IcebergMetadataFactory metadataFactory) @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "rollback_to_timestamp", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("timestamp", TIMESTAMP)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("timestamp", TIMESTAMP)), ROLLBACK_TO_TIMESTAMP.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java index abf9185b7071a..9998b0b601fe7 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java @@ -17,7 +17,9 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; import org.apache.iceberg.SnapshotRef; @@ -57,14 +59,14 @@ public SetCurrentSnapshotProcedure(IcebergMetadataFactory metadataFactory) @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "set_current_snapshot", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("snapshot_id", BIGINT, false, null), - new Procedure.Argument("ref", VARCHAR, false, null)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("snapshot_id", BIGINT, false, null), + new Argument("ref", VARCHAR, false, null)), SET_CURRENT_SNAPSHOT.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetTablePropertyProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetTablePropertyProcedure.java index 76bdb75fb7acd..2f01bf2d9e2ee 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetTablePropertyProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetTablePropertyProcedure.java @@ -22,7 +22,9 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; import org.apache.iceberg.Table; @@ -66,14 +68,14 @@ public SetTablePropertyProcedure( @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "set_table_property", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("key", VARCHAR), - new Procedure.Argument("value", VARCHAR)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("key", VARCHAR), + new Argument("value", VARCHAR)), SET_TABLE_PROPERTY.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/StatisticsFileCacheInvalidationProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/StatisticsFileCacheInvalidationProcedure.java index 6e91c57b53cf3..ab407a019bd2f 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/StatisticsFileCacheInvalidationProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/StatisticsFileCacheInvalidationProcedure.java @@ -15,6 +15,7 @@ import com.facebook.presto.iceberg.statistics.StatisticsFileCache; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; @@ -44,7 +45,7 @@ public StatisticsFileCacheInvalidationProcedure(StatisticsFileCache statisticsFi @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "invalidate_statistics_file_cache", ImmutableList.of(), diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/UnregisterTableProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/UnregisterTableProcedure.java index 62709b89923c1..a8e8c4198be6f 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/UnregisterTableProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/UnregisterTableProcedure.java @@ -19,7 +19,9 @@ import com.facebook.presto.spi.SchemaNotFoundException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; @@ -52,12 +54,12 @@ public UnregisterTableProcedure(IcebergMetadataFactory metadataFactory) @Override public Procedure get() { - return new Procedure( + return new LocalProcedure( "system", "unregister_table", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR)), UNREGISTER_TABLE.bindTo(this)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java index 39e13724d44e9..1101edd14f3b5 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java @@ -73,6 +73,7 @@ import static com.facebook.presto.iceberg.procedure.RegisterTableProcedure.getFileSystem; import static com.facebook.presto.iceberg.procedure.RegisterTableProcedure.resolveLatestMetadataLocation; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.tests.sql.TestTable.randomTableSuffix; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; @@ -2013,6 +2014,66 @@ public void testMetadataDeleteOnTableWithUnsupportedSpecsWhoseDataAllDeleted(Str } } + @Test(dataProvider = "version_and_mode") + public void testMetadataDeleteOnTableAfterWholeRewriteDataFiles(String version, String mode) + { + String errorMessage = "This connector only supports delete where one or more partitions are deleted entirely.*"; + String schemaName = getSession().getSchema().get(); + String tableName = "test_rewrite_data_files_table_" + randomTableSuffix(); + try { + // Create a table with partition column `a`, and insert some data under this partition spec + assertUpdate("CREATE TABLE " + tableName + " (a INTEGER, b VARCHAR) WITH (format_version = '" + version + "', delete_mode = '" + mode + "')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, '1001'), (2, '1002')", 2); + + // Then evaluate the partition spec by adding a partition column `c`, and insert some data under the new partition spec + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN c INTEGER WITH (partitioning = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (3, '1003', 3), (4, '1004', 4), (5, '1005', 5)", 3); + + // Do not support metadata delete with filter on column `c`, because we have data with old partition spec + assertQueryFails("DELETE FROM " + tableName + " WHERE c > 3", errorMessage); + + // Call procedure rewrite_data_files without filter to rewrite all data files + assertUpdate("call system.rewrite_data_files(table_name => '" + tableName + "', schema => '" + schemaName + "')", 5); + + // Then we can do metadata delete on column `c`, because all data files are rewritten under new partition spec + assertUpdate("DELETE FROM " + tableName + " WHERE c > 3", 2); + assertQuery("SELECT * FROM " + tableName, "VALUES (1, '1001', NULL), (2, '1002', NULL), (3, '1003', 3)"); + } + finally { + dropTable(getSession(), tableName); + } + } + + @Test(dataProvider = "version_and_mode") + public void testMetadataDeleteOnTableAfterPartialRewriteDataFiles(String version, String mode) + { + String errorMessage = "This connector only supports delete where one or more partitions are deleted entirely.*"; + String schemaName = getSession().getSchema().get(); + String tableName = "test_rewrite_data_files_table_" + randomTableSuffix(); + try { + // Create a table with partition column `a`, and insert some data under this partition spec + assertUpdate("CREATE TABLE " + tableName + " (a INTEGER, b VARCHAR) WITH (format_version = '" + version + "', delete_mode = '" + mode + "', partitioning = ARRAY['a'])"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, '1001'), (2, '1002')", 2); + + // Then evaluate the partition spec by adding a partition column `c`, and insert some data under the new partition spec + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN c INTEGER WITH (partitioning = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (3, '1003', 3), (4, '1004', 4), (5, '1005', 5)", 3); + + // Do not support metadata delete with filter on column `c`, because we have data with old partition spec + assertQueryFails("DELETE FROM " + tableName + " WHERE c > 3", errorMessage); + + // Call procedure rewrite_data_files with filter to rewrite data files under the prior partition spec + assertUpdate("call system.rewrite_data_files(table_name => '" + tableName + "', schema => '" + schemaName + "', filter => 'a in (1, 2)')", 2); + + // Then we can do metadata delete on column `c`, because all data files are now under new partition spec + assertUpdate("DELETE FROM " + tableName + " WHERE c > 3", 2); + assertQuery("SELECT * FROM " + tableName, "VALUES (1, '1001', NULL), (2, '1002', NULL), (3, '1003', 3)"); + } + finally { + dropTable(getSession(), tableName); + } + } + @DataProvider(name = "version_and_mode") public Object[][] versionAndMode() { diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java index 1acc165f4199c..68b55de1f3baa 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java @@ -2039,6 +2039,62 @@ public void testDecimal(boolean decimalVectorReaderEnabled) } } + public void testMetadataDeleteOnV2MorTableWithRewriteDataFiles() + { + String tableName = "test_rewrite_data_files_table_" + randomTableSuffix(); + try { + // Create a table with partition column `a`, and insert some data under this partition spec + assertUpdate("CREATE TABLE " + tableName + " (a INTEGER, b VARCHAR) WITH (format_version = '2', delete_mode = 'merge-on-read')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, '1001'), (2, '1002')", 2); + assertUpdate("DELETE FROM " + tableName + " WHERE a = 1", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002')"); + + Table icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 1); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 1); + + // Evaluate the partition spec by adding a partition column `c`, and insert some data under the new partition spec + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN c INTEGER WITH (partitioning = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (3, '1003', 3), (4, '1004', 4), (5, '1005', 5)", 3); + + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 4); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 1); + + // Execute row level delete with filter on column `b` + assertUpdate("DELETE FROM " + tableName + " WHERE b = '1004'", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002', NULL), (3, '1003', 3), (5, '1005', 5)"); + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 4); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 2); + + assertQueryFails("call system.rewrite_data_files(table_name => '" + tableName + "', schema => 'tpch', filter => 'a > 3')", ".*"); + assertQueryFails("call system.rewrite_data_files(table_name => '" + tableName + "', schema => 'tpch', filter => 'c > 3')", ".*"); + + assertUpdate("call system.rewrite_data_files(table_name => '" + tableName + "', schema => 'tpch')", 3); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002', NULL), (3, '1003', 3), (5, '1005', 5)"); + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 3); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 0); + + // Do metadata delete on column `a`, because all partition specs contains partition column `a` + assertUpdate("DELETE FROM " + tableName + " WHERE c = 5", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002', NULL), (3, '1003', 3)"); + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 2); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 0); + + assertUpdate("call system.rewrite_data_files(table_name => '" + tableName + "', schema => 'tpch', filter => 'c > 2')", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002', NULL), (3, '1003', 3)"); + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 2); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 0); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + @Test public void testRefsTable() { @@ -2887,14 +2943,14 @@ private void testWithAllFileFormats(Session session, BiConsumer map = snapshot.summary(); int totalDataFiles = Integer.valueOf(map.get(TOTAL_DATA_FILES_PROP)); assertEquals(totalDataFiles, dataFilesCount); } - private void assertHasDeleteFiles(Snapshot snapshot, int deleteFilesCount) + protected void assertHasDeleteFiles(Snapshot snapshot, int deleteFilesCount) { Map map = snapshot.summary(); int totalDeleteFiles = Integer.valueOf(map.get(TOTAL_DELETE_FILES_PROP)); diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergLogicalPlanner.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergLogicalPlanner.java index ed274b55b7bf5..98e8521dbaf04 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergLogicalPlanner.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergLogicalPlanner.java @@ -98,6 +98,7 @@ import static com.facebook.presto.sql.planner.assertions.MatchResult.match; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.callDistributedProcedure; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; @@ -107,8 +108,13 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictTableScan; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFinish; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -730,6 +736,50 @@ public void testThoroughlyPushdownForTableWithUnsupportedSpecsWhoseDataAllDelete } } + @Test + public void testCallDistributedProcedureOnPartitionedTable() + { + String tableName = "partition_table_for_call_distributed_procedure"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar) with (partitioning = ARRAY['c1'])"); + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + + assertPlan(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s')", tableName, getSession().getSchema().get()), + output(tableFinish(exchange(REMOTE_STREAMING, GATHER, + callDistributedProcedure( + exchange(LOCAL, GATHER, + exchange(REMOTE_STREAMING, REPARTITION, + strictTableScan(tableName, identityMap("c1", "c2"))))))))); + + // Do not support the filter that couldn't be enforced totally by tableScan + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c2 > ''bar''')", tableName, getSession().getSchema().get()), + "Unexpected FilterNode found in plan; probably connector was not able to handle provided WHERE expression"); + + // Support the filter that could be enforced totally by tableScan + assertPlan(getSession(), format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c1 > 3')", tableName, getSession().getSchema().get()), + output(tableFinish(exchange(REMOTE_STREAMING, GATHER, + callDistributedProcedure( + exchange(LOCAL, GATHER, + exchange(REMOTE_STREAMING, REPARTITION, + strictTableScan(tableName, identityMap("c1", "c2")))))))), + plan -> assertTableLayout( + plan, + tableName, + withColumnDomains(ImmutableMap.of( + new Subfield( + "c1", + ImmutableList.of()), + Domain.create(ValueSet.ofRanges(greaterThan(INTEGER, 3L)), false))), + TRUE_CONSTANT, + ImmutableSet.of("c1"))); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + @DataProvider(name = "timezones") public Object[][] timezones() { diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestRewriteDataFilesProcedure.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestRewriteDataFilesProcedure.java new file mode 100644 index 0000000000000..fb79f69618fca --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestRewriteDataFilesProcedure.java @@ -0,0 +1,508 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.io.CloseableIterator; +import org.testng.annotations.Test; + +import java.io.File; +import java.nio.file.Path; +import java.util.Map; +import java.util.OptionalInt; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.facebook.presto.iceberg.CatalogType.HADOOP; +import static com.facebook.presto.iceberg.FileFormat.PARQUET; +import static com.facebook.presto.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; +import static com.facebook.presto.iceberg.IcebergQueryRunner.getIcebergDataDirectoryPath; +import static java.lang.String.format; +import static org.apache.iceberg.SnapshotSummary.TOTAL_DATA_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.TOTAL_DELETE_FILES_PROP; +import static org.apache.iceberg.expressions.Expressions.alwaysTrue; +import static org.testng.Assert.assertEquals; + +public class TestRewriteDataFilesProcedure + extends AbstractTestQueryFramework +{ + public static final String TEST_SCHEMA = "tpch"; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return IcebergQueryRunner.builder() + .setCatalogType(HADOOP) + .setFormat(PARQUET) + .setNodeCount(OptionalInt.of(1)) + .setCreateTpchTables(false) + .setAddJmxPlugin(false) + .build().getQueryRunner(); + } + + public void dropTable(String tableName) + { + assertQuerySucceeds("DROP TABLE IF EXISTS " + tableName); + } + + @Test + public void testRewriteDataFilesInEmptyTable() + { + String tableName = "default_empty_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (id integer, value integer)"); + assertUpdate(format("CALL system.rewrite_data_files('%s', '%s')", TEST_SCHEMA, tableName), 0); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesOnPartitionTable() + { + String tableName = "example_partition_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar) with (partitioning = ARRAY['c2'])"); + + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 10,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 10); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 10, 0); + + assertUpdate("DELETE from " + tableName + " WHERE c1 = 7", 1); + assertUpdate("DELETE from " + tableName + " WHERE c1 in (8, 10)", 2); + + table.refresh(); + assertHasSize(table.snapshots(), 7); + //The number of data files is 10,and the number of delete files is 3 + assertHasDataFiles(table.currentSnapshot(), 10); + assertHasDeleteFiles(table.currentSnapshot(), 3); + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(9, 'foo')"); + + assertUpdate(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s')", tableName, TEST_SCHEMA), 7); + + table.refresh(); + assertHasSize(table.snapshots(), 8); + //The number of data files is 2,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 2); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .filter(alwaysTrue()) + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 2, 0); + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(9, 'foo')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesOnNonPartitionTable() + { + String tableName = "example_non_partition_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar)"); + + // create 5 files + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 5,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 5, 0); + + assertUpdate("DELETE from " + tableName + " WHERE c1 = 7", 1); + assertUpdate("DELETE from " + tableName + " WHERE c1 in (9, 10)", 2); + + table.refresh(); + assertHasSize(table.snapshots(), 7); + //The number of data files is 5,and the number of delete files is 2 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 2); + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(8, 'bar')"); + + assertUpdate(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s')", tableName, TEST_SCHEMA), 7); + + table.refresh(); + assertHasSize(table.snapshots(), 8); + //The number of data files is 1,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 1); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .filter(alwaysTrue()) + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 1, 0); + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(8, 'bar')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesWithFilter() + { + String tableName = "example_partition_filter_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar) with (partitioning = ARRAY['c2'])"); + + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 10,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 10); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 10, 0); + + // do not support rewrite files filtered by non-identity columns + assertQueryFails(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c1 > 3')", tableName, TEST_SCHEMA), ".*"); + + // select 5 files to rewrite + assertUpdate(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c2 = ''bar''')", tableName, TEST_SCHEMA), 5); + table.refresh(); + assertHasSize(table.snapshots(), 6); + //The number of data files is 6,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 6); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 6, 0); + + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(7, 'foo'), (8, 'bar'), " + + "(9, 'foo'), (10, 'bar')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesWithDeterministicTrueFilter() + { + String tableName = "example_non_partition_true_filter_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar)"); + + // create 5 files + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 5,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 5, 0); + + // do not support rewrite files filtered by non-identity columns + assertQueryFails(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c1 > 3')", tableName, TEST_SCHEMA), ".*"); + + // the filter is `true` means select all files to rewrite + assertUpdate(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => '1 = 1')", tableName, TEST_SCHEMA), 10); + + table.refresh(); + assertHasSize(table.snapshots(), 6); + //The number of data files is 1,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 1); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 1, 0); + + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(7, 'foo'), (8, 'bar'), " + + "(9, 'foo'), (10, 'bar')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesWithDeterministicFalseFilter() + { + String tableName = "example_non_partition_false_filter_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar)"); + + // create 5 files + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 5,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 5, 0); + + // the filter is `false` means select no file to rewrite + assertUpdate(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => '1 = 0')", tableName, TEST_SCHEMA), 0); + + table.refresh(); + assertHasSize(table.snapshots(), 5); + //The number of data files is still 5,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 5, 0); + + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(7, 'foo'), (8, 'bar'), " + + "(9, 'foo'), (10, 'bar')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesWithDeleteAndPartitionEvolution() + { + String tableName = "example_partition_evolution_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (a int, b varchar)"); + assertUpdate("INSERT INTO " + tableName + " values(1, '1001'), (2, '1002')", 2); + assertUpdate("DELETE FROM " + tableName + " WHERE a = 1", 1); + assertQuery("select * from " + tableName, "values(2, '1002')"); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 2); + //The number of data files is 1,and the number of delete files is 1 + assertHasDataFiles(table.currentSnapshot(), 1); + assertHasDeleteFiles(table.currentSnapshot(), 1); + + assertUpdate("alter table " + tableName + " add column c int with (partitioning = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " values(5, '1005', 5), (6, '1006', 6), (7, '1007', 7)", 3); + assertUpdate("DELETE FROM " + tableName + " WHERE b = '1006'", 1); + assertQuery("select * from " + tableName, "values(2, '1002', NULL), (5, '1005', 5), (7, '1007', 7)"); + + table.refresh(); + assertHasSize(table.snapshots(), 4); + //The number of data files is 4,and the number of delete files is 2 + assertHasDataFiles(table.currentSnapshot(), 4); + assertHasDeleteFiles(table.currentSnapshot(), 2); + + assertQueryFails(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'a > 3')", tableName, TEST_SCHEMA), ".*"); + assertQueryFails(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c > 3')", tableName, TEST_SCHEMA), ".*"); + + assertUpdate(format("call system.rewrite_data_files(table_name => '%s', schema => '%s')", tableName, TEST_SCHEMA), 3); + table.refresh(); + assertHasSize(table.snapshots(), 5); + //The number of data files is 3,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 3); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 3, 0); + assertQuery("select * from " + tableName, "values(2, '1002', NULL), (5, '1005', 5), (7, '1007', 7)"); + + assertUpdate("delete from " + tableName + " where b = '1002'", 1); + table.refresh(); + assertHasSize(table.snapshots(), 6); + //The number of data files is 3,and the number of delete files is 1 + assertHasDataFiles(table.currentSnapshot(), 3); + assertHasDeleteFiles(table.currentSnapshot(), 1); + assertUpdate(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c is null')", tableName, TEST_SCHEMA), 0); + + table.refresh(); + assertHasSize(table.snapshots(), 7); + //The number of data files is 2,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 2); + assertHasDeleteFiles(table.currentSnapshot(), 0); + assertQuery("select * from " + tableName, "values(5, '1005', 5), (7, '1007', 7)"); + + // This is a metadata delete + assertUpdate("delete from " + tableName + " where c = 7", 1); + table.refresh(); + assertHasSize(table.snapshots(), 8); + //The number of data files is 1,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 1); + assertHasDeleteFiles(table.currentSnapshot(), 0); + assertQuery("select * from " + tableName, "values(5, '1005', 5)"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testInvalidParameterCases() + { + String tableName = "invalid_parameter_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (a int, b varchar, c int)"); + assertQueryFails("CALL system.rewrite_data_files('n', table_name => 't')", ".*Named and positional arguments cannot be mixed"); + assertQueryFails("CALL custom.rewrite_data_files('n', 't')", "Procedure not registered: custom.rewrite_data_files"); + assertQueryFails("CALL system.rewrite_data_files()", ".*Required procedure argument 'schema' is missing"); + assertQueryFails("CALL system.rewrite_data_files('s', 'n')", "Schema s does not exist"); + assertQueryFails("CALL system.rewrite_data_files('', '')", "Table name is empty"); + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => '''hello''')", tableName, TEST_SCHEMA), ".*WHERE clause must evaluate to a boolean: actual type varchar\\(5\\)"); + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => '1001')", tableName, TEST_SCHEMA), ".*WHERE clause must evaluate to a boolean: actual type integer"); + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'a')", tableName, TEST_SCHEMA), ".*WHERE clause must evaluate to a boolean: actual type integer"); + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'n')", tableName, TEST_SCHEMA), ".*Column 'n' cannot be resolved"); + } + finally { + dropTable(tableName); + } + } + + private Table loadTable(String tableName) + { + Catalog catalog = CatalogUtil.loadCatalog(HadoopCatalog.class.getName(), ICEBERG_CATALOG, getProperties(), new Configuration()); + return catalog.loadTable(TableIdentifier.of(TEST_SCHEMA, tableName)); + } + + private Map getProperties() + { + File metastoreDir = getCatalogDirectory(); + return ImmutableMap.of("warehouse", metastoreDir.toString()); + } + + private File getCatalogDirectory() + { + Path dataDirectory = getDistributedQueryRunner().getCoordinator().getDataDirectory(); + Path catalogDirectory = getIcebergDataDirectoryPath(dataDirectory, HADOOP.name(), new IcebergConfig().getFileFormat(), false); + return catalogDirectory.toFile(); + } + + private void assertHasSize(Iterable iterable, int size) + { + AtomicInteger count = new AtomicInteger(0); + iterable.forEach(obj -> count.incrementAndGet()); + assertEquals(count.get(), size); + } + + private void assertHasDataFiles(Snapshot snapshot, int dataFilesCount) + { + Map map = snapshot.summary(); + int totalDataFiles = Integer.valueOf(map.get(TOTAL_DATA_FILES_PROP)); + assertEquals(totalDataFiles, dataFilesCount); + } + + private void assertHasDeleteFiles(Snapshot snapshot, int deleteFilesCount) + { + Map map = snapshot.summary(); + int totalDeleteFiles = Integer.valueOf(map.get(TOTAL_DELETE_FILES_PROP)); + assertEquals(totalDeleteFiles, deleteFilesCount); + } + + private void assertFilesPlan(CloseableIterator iterator, int dataFileCount, int deleteFileCount) + { + AtomicInteger dataCount = new AtomicInteger(0); + AtomicInteger deleteCount = new AtomicInteger(0); + while (iterator.hasNext()) { + FileScanTask fileScanTask = iterator.next(); + dataCount.incrementAndGet(); + deleteCount.addAndGet(fileScanTask.deletes().size()); + } + assertEquals(dataCount.get(), dataFileCount); + assertEquals(deleteCount.get(), deleteFileCount); + + try { + iterator.close(); + iterator = CloseableIterator.empty(); + } + catch (Exception e) { + // do nothing + } + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java index 4e7f439cc50fc..4f6232645539c 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java @@ -48,6 +48,7 @@ import com.facebook.presto.iceberg.IcebergTableType; import com.facebook.presto.iceberg.ManifestFileCache; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; +import com.facebook.presto.metadata.BuiltInProcedureRegistry; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.ConnectorSession; @@ -410,6 +411,7 @@ private ConnectorMetadata getIcebergHiveMetadata(ExtendedHiveMetastore metastore metastore, hdfsEnvironment, FUNCTION_AND_TYPE_MANAGER, + new BuiltInProcedureRegistry(METADATA.getFunctionAndTypeManager()), FUNCTION_RESOLUTION, ROW_EXPRESSION_SERVICE, jsonCodec(CommitTaskData.class), diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRest.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRest.java index 727076b088744..42e0df564402a 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRest.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRest.java @@ -142,10 +142,4 @@ public void testSetOauth2ServerUriPropertyI() assertEquals(catalog.properties().get(OAUTH2_SERVER_URI), authEndpoint); } - - @Override - public void testDeprecatedTablePropertiesCreateTable() - { - // v1 table create fails due to Iceberg REST catalog bug (see: https://github.com/apache/iceberg/issues/8756) - } } diff --git a/presto-kudu/src/main/java/com/facebook/presto/kudu/procedures/RangePartitionProcedures.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/procedures/RangePartitionProcedures.java index 4fa52348f09fb..2144f01d1bf19 100644 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/procedures/RangePartitionProcedures.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/procedures/RangePartitionProcedures.java @@ -17,7 +17,7 @@ import com.facebook.presto.kudu.properties.KuduTableProperties; import com.facebook.presto.kudu.properties.RangePartition; import com.facebook.presto.spi.SchemaTableName; -import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; @@ -43,9 +43,9 @@ public RangePartitionProcedures(KuduClientSession clientSession) this.clientSession = requireNonNull(clientSession); } - public Procedure getAddPartitionProcedure() + public LocalProcedure getAddPartitionProcedure() { - return new Procedure( + return new LocalProcedure( "system", "add_range_partition", ImmutableList.of(new Argument("schema", VARCHAR), new Argument("table", VARCHAR), @@ -53,9 +53,9 @@ public Procedure getAddPartitionProcedure() ADD.bindTo(this)); } - public Procedure getDropPartitionProcedure() + public LocalProcedure getDropPartitionProcedure() { - return new Procedure( + return new LocalProcedure( "system", "drop_range_partition", ImmutableList.of(new Argument("schema", VARCHAR), new Argument("table", VARCHAR), diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java index c36ff87b0460b..8182e56359c9d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import static java.util.Objects.requireNonNull; @@ -32,6 +33,7 @@ public class ConnectorContextInstance { private final NodeManager nodeManager; private final TypeManager typeManager; + private final ProcedureRegistry procedureRegistry; private final FunctionMetadataManager functionMetadataManager; private final StandardFunctionResolution functionResolution; private final PageSorter pageSorter; @@ -44,6 +46,7 @@ public class ConnectorContextInstance public ConnectorContextInstance( NodeManager nodeManager, TypeManager typeManager, + ProcedureRegistry procedureRegistry, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution, PageSorter pageSorter, @@ -55,6 +58,7 @@ public ConnectorContextInstance( { this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.pageSorter = requireNonNull(pageSorter, "pageSorter is null"); @@ -77,6 +81,12 @@ public TypeManager getTypeManager() return typeManager; } + @Override + public ProcedureRegistry getProcedureRegistry() + { + return procedureRegistry; + } + @Override public FunctionMetadataManager getFunctionMetadataManager() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java index dc85eea82f0e5..5c45f6c739758 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java @@ -53,6 +53,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; @@ -110,6 +111,7 @@ public class ConnectorManager private final HandleResolver handleResolver; private final InternalNodeManager nodeManager; private final TypeManager typeManager; + private final ProcedureRegistry procedureRegistry; private final PageSorter pageSorter; private final PageIndexerFactory pageIndexerFactory; private final NodeInfo nodeInfo; @@ -146,6 +148,7 @@ public ConnectorManager( InternalNodeManager nodeManager, NodeInfo nodeInfo, TypeManager typeManager, + ProcedureRegistry procedureRegistry, PageSorter pageSorter, PageIndexerFactory pageIndexerFactory, TransactionManager transactionManager, @@ -170,6 +173,7 @@ public ConnectorManager( this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.pageSorter = requireNonNull(pageSorter, "pageSorter is null"); this.pageIndexerFactory = requireNonNull(pageIndexerFactory, "pageIndexerFactory is null"); this.nodeInfo = requireNonNull(nodeInfo, "nodeInfo is null"); @@ -374,6 +378,7 @@ private Connector createConnector(ConnectorId connectorId, ConnectorFactory fact ConnectorContext context = new ConnectorContextInstance( new ConnectorAwareNodeManager(nodeManager, nodeInfo.getEnvironment(), connectorId), typeManager, + procedureRegistry, metadataManager.getFunctionAndTypeManager(), new FunctionResolution(metadataManager.getFunctionAndTypeManager().getFunctionAndTypeResolver()), pageSorter, diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java index e236e51c43b33..bd5f5f8495ff2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java @@ -18,7 +18,7 @@ import com.facebook.presto.execution.QueryState; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.QueryId; -import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import jakarta.inject.Inject; @@ -73,9 +73,9 @@ public void killQuery(String queryId, String message) } } - public Procedure getProcedure() + public LocalProcedure getProcedure() { - return new Procedure( + return new LocalProcedure( "runtime", "kill_query", ImmutableList.builder() diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/CallTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/CallTask.java index da84f56afbed0..9b3b08e4f88c5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/CallTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/CallTask.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.facebook.presto.spi.security.AccessControl; @@ -58,6 +59,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; import static com.facebook.presto.sql.analyzer.utils.ParameterUtils.parameterExtractor; import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression; +import static com.facebook.presto.util.Failures.checkArgument; import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; @@ -85,6 +87,48 @@ public ListenableFuture execute(Call call, TransactionManager transactionMana ConnectorId connectorId = getConnectorIdOrThrow(session, metadata, procedureName.getCatalogName(), call, catalogError); Procedure procedure = metadata.getProcedureRegistry().resolve(connectorId, toSchemaTableName(procedureName)); + Map, Expression> parameterLookup = parameterExtractor(call, parameters); + checkArgument(procedure instanceof LocalProcedure, "Must call an inner procedure in CallTask"); + LocalProcedure innerProcedure = (LocalProcedure) procedure; + Object[] values = extractParameterValuesInOrder(call, innerProcedure, metadata, session, parameterLookup); + + // validate arguments + MethodType methodType = innerProcedure.getMethodHandle().type(); + for (int i = 0; i < innerProcedure.getArguments().size(); i++) { + if ((values[i] == null) && methodType.parameterType(i).isPrimitive()) { + String name = innerProcedure.getArguments().get(i).getName(); + throw new PrestoException(INVALID_PROCEDURE_ARGUMENT, "Procedure argument cannot be null: " + name); + } + } + + // insert session argument + List arguments = new ArrayList<>(); + Iterator valuesIterator = asList(values).iterator(); + for (Class type : methodType.parameterList()) { + if (ConnectorSession.class.isAssignableFrom(type)) { + arguments.add(session.toConnectorSession(connectorId)); + } + else { + arguments.add(valuesIterator.next()); + } + } + + try { + innerProcedure.getMethodHandle().invokeWithArguments(arguments); + } + catch (Throwable t) { + if (t instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throwIfInstanceOf(t, PrestoException.class); + throw new PrestoException(PROCEDURE_CALL_FAILED, t); + } + + return immediateFuture(null); + } + + public static Object[] extractParameterValuesInOrder(Call call, Procedure procedure, Metadata metadata, Session session, Map, Expression> parameterLookup) + { // map declared argument names to positions Map positions = new HashMap<>(); for (int i = 0; i < procedure.getArguments().size(); i++) { @@ -131,7 +175,6 @@ else if (i < procedure.getArguments().size()) { // get argument values Object[] values = new Object[procedure.getArguments().size()]; - Map, Expression> parameterLookup = parameterExtractor(call, parameters); for (Entry entry : names.entrySet()) { CallArgument callArgument = entry.getValue(); int index = positions.get(entry.getKey()); @@ -156,39 +199,7 @@ else if (i < procedure.getArguments().size()) { } } - // validate arguments - MethodType methodType = procedure.getMethodHandle().type(); - for (int i = 0; i < procedure.getArguments().size(); i++) { - if ((values[i] == null) && methodType.parameterType(i).isPrimitive()) { - String name = procedure.getArguments().get(i).getName(); - throw new PrestoException(INVALID_PROCEDURE_ARGUMENT, "Procedure argument cannot be null: " + name); - } - } - - // insert session argument - List arguments = new ArrayList<>(); - Iterator valuesIterator = asList(values).iterator(); - for (Class type : methodType.parameterList()) { - if (ConnectorSession.class.isAssignableFrom(type)) { - arguments.add(session.toConnectorSession(connectorId)); - } - else { - arguments.add(valuesIterator.next()); - } - } - - try { - procedure.getMethodHandle().invokeWithArguments(arguments); - } - catch (Throwable t) { - if (t instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throwIfInstanceOf(t, PrestoException.class); - throw new PrestoException(PROCEDURE_CALL_FAILED, t); - } - - return immediateFuture(null); + return values; } private static Object toTypeObjectValue(Session session, Type type, Object value) diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTarget.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTarget.java index 2704ab5ab8960..97175e428d2a6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTarget.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTarget.java @@ -17,7 +17,9 @@ import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.metadata.DeleteTableHandle; +import com.facebook.presto.metadata.DistributedProcedureHandle; import com.facebook.presto.metadata.InsertTableHandle; import com.facebook.presto.metadata.OutputTableHandle; import com.facebook.presto.spi.SchemaTableName; @@ -35,7 +37,9 @@ @JsonSubTypes.Type(value = ExecutionWriterTarget.InsertHandle.class, name = "InsertHandle"), @JsonSubTypes.Type(value = ExecutionWriterTarget.DeleteHandle.class, name = "DeleteHandle"), @JsonSubTypes.Type(value = ExecutionWriterTarget.RefreshMaterializedViewHandle.class, name = "RefreshMaterializedViewHandle"), - @JsonSubTypes.Type(value = ExecutionWriterTarget.UpdateHandle.class, name = "UpdateHandle")}) + @JsonSubTypes.Type(value = ExecutionWriterTarget.UpdateHandle.class, name = "UpdateHandle"), + @JsonSubTypes.Type(value = ExecutionWriterTarget.ExecuteProcedureHandle.class, name = "ExecuteProcedureHandle") +}) @SuppressWarnings({"EmptyClass", "ClassMayBeInterface"}) public abstract class ExecutionWriterTarget { @@ -228,4 +232,47 @@ public String toString() return handle.toString(); } } + + public static class ExecuteProcedureHandle + extends ExecutionWriterTarget + { + private final DistributedProcedureHandle handle; + private final SchemaTableName schemaTableName; + private final QualifiedObjectName procedureName; + + @JsonCreator + public ExecuteProcedureHandle( + @JsonProperty("handle") DistributedProcedureHandle handle, + @JsonProperty("schemaTableName") SchemaTableName schemaTableName, + @JsonProperty("procedureName") QualifiedObjectName procedureName) + { + this.handle = requireNonNull(handle, "handle is null"); + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + this.procedureName = requireNonNull(procedureName, "procedureName is null"); + } + + @JsonProperty + public DistributedProcedureHandle getHandle() + { + return handle; + } + + @JsonProperty + public SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @JsonProperty + public QualifiedObjectName getProcedureName() + { + return procedureName; + } + + @Override + public String toString() + { + return handle.toString(); + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java index 2d2fa96b6882e..d1925e1937942 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java @@ -18,11 +18,13 @@ import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.Session; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.ExecuteProcedureHandle; import com.facebook.presto.metadata.AnalyzeTableHandle; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.fasterxml.jackson.annotation.JsonCreator; @@ -101,6 +103,17 @@ private static Optional createWriterTarget(Optional
> connectorProcedures = new ConcurrentHashMap<>(); private final TypeManager typeManager; - public ProcedureRegistry(TypeManager typeManager) + @Inject + public BuiltInProcedureRegistry(TypeManager typeManager) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); } + @Override public void addProcedures(ConnectorId connectorId, Collection procedures) { requireNonNull(connectorId, "connectorId is null"); @@ -71,11 +78,13 @@ public void addProcedures(ConnectorId connectorId, Collection procedu checkState(connectorProcedures.putIfAbsent(connectorId, proceduresByName) == null, "Procedures already registered for connector: %s", connectorId); } + @Override public void removeProcedures(ConnectorId connectorId) { connectorProcedures.remove(connectorId); } + @Override public Procedure resolve(ConnectorId connectorId, SchemaTableName name) { Map procedures = connectorProcedures.get(connectorId); @@ -88,14 +97,41 @@ public Procedure resolve(ConnectorId connectorId, SchemaTableName name) throw new PrestoException(PROCEDURE_NOT_FOUND, "Procedure not registered: " + name); } + @Override + public DistributedProcedure resolveDistributed(ConnectorId connectorId, SchemaTableName name) + { + Map procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + Procedure procedure = procedures.get(name); + if (procedure != null && procedure instanceof DistributedProcedure) { + return (DistributedProcedure) procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Distributed procedure not registered: " + name); + } + + @Override + public boolean isDistributedProcedure(ConnectorId connectorId, SchemaTableName name) + { + Map procedures = connectorProcedures.get(connectorId); + return procedures != null && + procedures.containsKey(name) && + procedures.get(name) instanceof DistributedProcedure; + } + private void validateProcedure(Procedure procedure) { - List> parameters = procedure.getMethodHandle().type().parameterList().stream() + if (procedure instanceof DistributedProcedure) { + return; + } + + LocalProcedure innerProcedure = (LocalProcedure) procedure; + List> parameters = innerProcedure.getMethodHandle().type().parameterList().stream() .filter(type -> !ConnectorSession.class.isAssignableFrom(type)) .collect(toList()); for (int i = 0; i < procedure.getArguments().size(); i++) { - Argument argument = procedure.getArguments().get(i); + Argument argument = innerProcedure.getArguments().get(i); Type type = typeManager.getType(argument.getType()); Class argumentType = Primitives.unwrap(parameters.get(i)); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java index 45c1cc7af319b..3d45b27e77e58 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java @@ -38,6 +38,7 @@ import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; @@ -400,6 +401,18 @@ public Optional finishDeleteWithOutput(Session session, return delegate.finishDeleteWithOutput(session, tableHandle, fragments); } + @Override + public DistributedProcedureHandle beginCallDistributedProcedure(Session session, QualifiedObjectName procedureName, TableHandle tableHandle, Object[] arguments) + { + return delegate.beginCallDistributedProcedure(session, procedureName, tableHandle, arguments); + } + + @Override + public void finishCallDistributedProcedure(Session session, DistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments) + { + delegate.finishCallDistributedProcedure(session, procedureHandle, procedureName, fragments); + } + @Override public TableHandle beginUpdate(Session session, TableHandle tableHandle, List updatedColumns) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandle.java new file mode 100644 index 0000000000000..1d3776b3ecca8 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandle.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public final class DistributedProcedureHandle +{ + private final ConnectorId connectorId; + private final ConnectorTransactionHandle transactionHandle; + private final ConnectorDistributedProcedureHandle connectorHandle; + + @JsonCreator + public DistributedProcedureHandle( + @JsonProperty("connectorId") ConnectorId connectorId, + @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, + @JsonProperty("connectorHandle") ConnectorDistributedProcedureHandle connectorHandle) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); + this.connectorHandle = requireNonNull(connectorHandle, "connectorHandle is null"); + } + + @JsonProperty + public ConnectorId getConnectorId() + { + return connectorId; + } + + @JsonProperty + public ConnectorTransactionHandle getTransactionHandle() + { + return transactionHandle; + } + + @JsonProperty + public ConnectorDistributedProcedureHandle getConnectorHandle() + { + return connectorHandle; + } + + @Override + public int hashCode() + { + return Objects.hash(connectorId, transactionHandle, connectorHandle); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + DistributedProcedureHandle o = (DistributedProcedureHandle) obj; + return Objects.equals(this.connectorId, o.connectorId) && + Objects.equals(this.transactionHandle, o.transactionHandle) && + Objects.equals(this.connectorHandle, o.connectorHandle); + } + + @Override + public String toString() + { + return connectorId + ":" + connectorHandle; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandleJacksonModule.java new file mode 100644 index 0000000000000..ff9cd22861ad5 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandleJacksonModule.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; + +import javax.inject.Inject; + +public class DistributedProcedureHandleJacksonModule + extends AbstractTypedJacksonModule +{ + @Inject + public DistributedProcedureHandleJacksonModule(HandleResolver handleResolver) + { + super(ConnectorDistributedProcedureHandle.class, + handleResolver::getId, + handleResolver::getDistributedProcedureHandleClass); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java index 952fc38be9426..c1b80caa2ec27 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java @@ -33,6 +33,7 @@ public void configure(Binder binder) jsonBinder(binder).addModuleBinding().to(OutputTableHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(InsertTableHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(DeleteTableHandleJacksonModule.class); + jsonBinder(binder).addModuleBinding().to(DistributedProcedureHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(IndexHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(TransactionHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(PartitioningHandleJacksonModule.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java index 30630f0a7bff7..d9bbf61d95e71 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java @@ -17,6 +17,7 @@ import com.facebook.presto.connector.system.SystemHandleResolver; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.ConnectorIndexHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; @@ -119,6 +120,11 @@ public String getId(ConnectorDeleteTableHandle deleteHandle) return getId(deleteHandle, MaterializedHandleResolver::getDeleteTableHandleClass); } + public String getId(ConnectorDistributedProcedureHandle distributedProcedureHandle) + { + return getId(distributedProcedureHandle, MaterializedHandleResolver::getDistributedProcedureHandleClass); + } + public String getId(ConnectorPartitioningHandle partitioningHandle) { return getId(partitioningHandle, MaterializedHandleResolver::getPartitioningHandleClass); @@ -174,6 +180,11 @@ public Class getDeleteTableHandleClass(Str return resolverFor(id).getDeleteTableHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } + public Class getDistributedProcedureHandleClass(String id) + { + return resolverFor(id).getDistributedProcedureHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); + } + public Class getPartitioningHandleClass(String id) { return resolverFor(id).getPartitioningHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); @@ -241,6 +252,7 @@ private static class MaterializedHandleResolver private final Optional> outputTableHandle; private final Optional> insertTableHandle; private final Optional> deleteTableHandle; + private final Optional> distributedProcedureHandle; private final Optional> partitioningHandle; private final Optional> transactionHandle; @@ -256,6 +268,7 @@ public MaterializedHandleResolver(ConnectorHandleResolver resolver) deleteTableHandle = getHandleClass(resolver::getDeleteTableHandleClass); partitioningHandle = getHandleClass(resolver::getPartitioningHandleClass); transactionHandle = getHandleClass(resolver::getTransactionHandleClass); + distributedProcedureHandle = getHandleClass(resolver::getDistributedProcedureHandleClass); } private static Optional> getHandleClass(Supplier> callable) @@ -308,6 +321,11 @@ public Optional> getDeleteTableHandl return deleteTableHandle; } + public Optional> getDistributedProcedureHandleClass() + { + return distributedProcedureHandle; + } + public Optional> getPartitioningHandleClass() { return partitioningHandle; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java index 050702ff6ad38..18e33ca8954cb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java @@ -45,6 +45,7 @@ import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.security.GrantInfo; @@ -341,6 +342,16 @@ public interface Metadata */ Optional finishDeleteWithOutput(Session session, DeleteTableHandle tableHandle, Collection fragments); + /** + * Begin call distributed procedure + */ + DistributedProcedureHandle beginCallDistributedProcedure(Session session, QualifiedObjectName procedureName, TableHandle tableHandle, Object[] arguments); + + /** + * Finish call distributed procedure + */ + void finishCallDistributedProcedure(Session session, DistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments); + /** * Begin update query */ diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java index a1be6217dbb60..18badb5ec600f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java @@ -30,6 +30,7 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; @@ -65,6 +66,7 @@ import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.security.GrantInfo; @@ -172,6 +174,30 @@ public MetadataManager( ColumnPropertyManager columnPropertyManager, AnalyzePropertyManager analyzePropertyManager, TransactionManager transactionManager) + { + this( + functionAndTypeManager, + blockEncodingSerde, + sessionPropertyManager, + schemaPropertyManager, + tablePropertyManager, + columnPropertyManager, + analyzePropertyManager, + transactionManager, + new BuiltInProcedureRegistry(functionAndTypeManager)); + } + + @VisibleForTesting + public MetadataManager( + FunctionAndTypeManager functionAndTypeManager, + BlockEncodingSerde blockEncodingSerde, + SessionPropertyManager sessionPropertyManager, + SchemaPropertyManager schemaPropertyManager, + TablePropertyManager tablePropertyManager, + ColumnPropertyManager columnPropertyManager, + AnalyzePropertyManager analyzePropertyManager, + TransactionManager transactionManager, + ProcedureRegistry procedureRegistry) { this( createTestingViewCodec(functionAndTypeManager), @@ -182,7 +208,8 @@ public MetadataManager( columnPropertyManager, analyzePropertyManager, transactionManager, - functionAndTypeManager); + functionAndTypeManager, + procedureRegistry); } @Inject @@ -195,7 +222,8 @@ public MetadataManager( ColumnPropertyManager columnPropertyManager, AnalyzePropertyManager analyzePropertyManager, TransactionManager transactionManager, - FunctionAndTypeManager functionAndTypeManager) + FunctionAndTypeManager functionAndTypeManager, + ProcedureRegistry procedureRegistry) { this.viewCodec = requireNonNull(viewCodec, "viewCodec is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); @@ -206,7 +234,7 @@ public MetadataManager( this.analyzePropertyManager = requireNonNull(analyzePropertyManager, "analyzePropertyManager is null"); this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionManager is null"); - this.procedures = new ProcedureRegistry(functionAndTypeManager); + this.procedures = requireNonNull(procedureRegistry, "procedureRegistry is null"); verifyComparableOrderableContract(); } @@ -255,6 +283,21 @@ public static MetadataManager createTestMetadataManager(TransactionManager trans transactionManager); } + public static MetadataManager createTestMetadataManager(TransactionManager transactionManager, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, ProcedureRegistry procedureRegistry) + { + BlockEncodingManager blockEncodingManager = new BlockEncodingManager(); + return new MetadataManager( + new FunctionAndTypeManager(transactionManager, new TableFunctionRegistry(), blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()), + blockEncodingManager, + createTestingSessionPropertyManager(), + new SchemaPropertyManager(), + new TablePropertyManager(), + new ColumnPropertyManager(), + new AnalyzePropertyManager(), + transactionManager, + procedureRegistry); + } + @Override public final void verifyComparableOrderableContract() { @@ -960,6 +1003,40 @@ public Optional finishDeleteWithOutput(Session session, return metadata.finishDeleteWithOutput(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), fragments); } + @Override + public DistributedProcedureHandle beginCallDistributedProcedure(Session session, QualifiedObjectName procedureName, TableHandle tableHandle, Object[] arguments) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, connectorId); + + ConnectorTableLayoutHandle layout; + if (!tableHandle.getLayout().isPresent()) { + TableLayoutResult result = getLayout(session, tableHandle, Constraint.alwaysTrue(), Optional.empty()); + layout = result.getLayout().getLayoutHandle(); + } + else { + layout = tableHandle.getLayout().get(); + } + + ConnectorDistributedProcedureHandle procedureHandle = catalogMetadata.getMetadata().beginCallDistributedProcedure( + session.toConnectorSession(connectorId), + procedureName, + layout, + arguments); + return new DistributedProcedureHandle( + tableHandle.getConnectorId(), + tableHandle.getTransaction(), + procedureHandle); + } + + @Override + public void finishCallDistributedProcedure(Session session, DistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments) + { + ConnectorId connectorId = procedureHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + metadata.finishCallDistributedProcedure(session.toConnectorSession(connectorId), procedureHandle.getConnectorHandle(), procedureName, fragments); + } + @Override public TableHandle beginUpdate(Session session, TableHandle tableHandle, List updatedColumns) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUtil.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUtil.java index 40372406e167c..d970d12216012 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUtil.java @@ -21,13 +21,13 @@ import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableMetadata; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.connector.ConnectorTableVersion; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.analyzer.utils.MetadataUtils; import com.facebook.presto.sql.tree.GrantorSpecification; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.Node; @@ -37,13 +37,12 @@ import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import java.util.List; import java.util.Optional; +import java.util.function.BiFunction; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.INFORMATION_SCHEMA; -import static com.facebook.presto.spi.StandardErrorCode.SYNTAX_ERROR; import static com.facebook.presto.spi.security.PrincipalType.ROLE; import static com.facebook.presto.spi.security.PrincipalType.USER; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED; @@ -153,24 +152,8 @@ public static CatalogSchemaName createCatalogSchemaName(Session session, Node no public static QualifiedObjectName createQualifiedObjectName(Session session, Node node, QualifiedName name, Metadata metadata) { - requireNonNull(session, "session is null"); - requireNonNull(name, "name is null"); - if (name.getOriginalParts().size() > 3) { - throw new PrestoException(SYNTAX_ERROR, format("Too many dots in table name: %s", name)); - } - - List parts = Lists.reverse(name.getOriginalParts()); - String objectName = parts.get(0).getValue(); - String schemaName = (parts.size() > 1) ? parts.get(1).getValue() : session.getSchema().orElseThrow(() -> - new SemanticException(SCHEMA_NOT_SPECIFIED, node, "Schema must be specified when session schema is not set")); - String catalogName = (parts.size() > 2) ? parts.get(2).getValue() : session.getCatalog().orElseThrow(() -> - new SemanticException(CATALOG_NOT_SPECIFIED, node, "Catalog must be specified when session catalog is not set")); - - catalogName = catalogName.toLowerCase(ENGLISH); - schemaName = metadata.normalizeIdentifier(session, catalogName, schemaName); - objectName = metadata.normalizeIdentifier(session, catalogName, objectName); - - return new QualifiedObjectName(catalogName, schemaName, objectName); + BiFunction normalizer = (catalogName, objectName) -> metadata.normalizeIdentifier(session, catalogName, objectName); + return MetadataUtils.createQualifiedObjectName(session.getCatalog(), session.getSchema(), node, name, normalizer); } public static Optional getOptionalCatalogMetadata(Session session, TransactionManager transactionManager, String catalogName) diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterOperator.java index 5deda5f3ffb94..781b822dab73d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterOperator.java @@ -28,6 +28,7 @@ import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.CreateHandle; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.ExecuteProcedureHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.InsertHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.RefreshMaterializedViewHandle; import com.facebook.presto.memory.context.LocalMemoryContext; @@ -112,8 +113,11 @@ public TableWriterOperatorFactory( this.notNullChannelColumnNames = requireNonNull(notNullChannelColumnNames, "notNullChannelColumnNames is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); checkArgument( - writerTarget instanceof CreateHandle || writerTarget instanceof InsertHandle || writerTarget instanceof RefreshMaterializedViewHandle, - "writerTarget must be CreateHandle or InsertHandle or RefreshMaterializedViewHandle"); + writerTarget instanceof CreateHandle || + writerTarget instanceof InsertHandle || + writerTarget instanceof RefreshMaterializedViewHandle || + writerTarget instanceof ExecuteProcedureHandle, + "writerTarget must be CreateHandle or InsertHandle or RefreshMaterializedViewHandle or TableExecuteHandle"); this.target = requireNonNull(writerTarget, "writerTarget is null"); this.session = session; this.statisticsAggregationOperatorFactory = requireNonNull(statisticsAggregationOperatorFactory, "statisticsAggregationOperatorFactory is null"); @@ -157,6 +161,9 @@ private ConnectorPageSink createPageSink(OperatorContext operatorContext) if (target instanceof RefreshMaterializedViewHandle) { return pageSinkManager.createPageSink(session, ((RefreshMaterializedViewHandle) target).getHandle(), pageSinkContextBuilder.build(), runtimeStats); } + if (target instanceof ExecuteProcedureHandle) { + return pageSinkManager.createPageSink(session, ((ExecuteProcedureHandle) target).getHandle(), pageSinkContextBuilder.build()); + } throw new UnsupportedOperationException("Unhandled target type: " + target.getClass().getName()); } @@ -174,6 +181,9 @@ private static ConnectorId getConnectorId(ExecutionWriterTarget handle) return ((RefreshMaterializedViewHandle) handle).getHandle().getConnectorId(); } + if (handle instanceof ExecuteProcedureHandle) { + return ((ExecuteProcedureHandle) handle).getHandle().getConnectorId(); + } throw new UnsupportedOperationException("Unhandled target type: " + handle.getClass().getName()); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkManager.java b/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkManager.java index 39f6a860ac1d5..ad02188c7814c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkManager.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.metadata.DistributedProcedureHandle; import com.facebook.presto.metadata.InsertTableHandle; import com.facebook.presto.metadata.OutputTableHandle; import com.facebook.presto.spi.ConnectorId; @@ -73,6 +74,14 @@ public ConnectorPageSink createPageSink(Session session, InsertTableHandle table return createPageSink(session, tableHandle, pageSinkContext, null); } + @Override + public ConnectorPageSink createPageSink(Session session, DistributedProcedureHandle procedureHandle, PageSinkContext pageSinkContext) + { + // assumes connectorId and catalog are the same + ConnectorSession connectorSession = session.toConnectorSession(procedureHandle.getConnectorId()); + return providerFor(procedureHandle.getConnectorId()).createPageSink(procedureHandle.getTransactionHandle(), connectorSession, procedureHandle.getConnectorHandle(), pageSinkContext); + } + private ConnectorPageSinkProvider providerFor(ConnectorId connectorId) { ConnectorPageSinkProvider provider = pageSinkProviders.get(connectorId); diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkProvider.java index 3e46127870194..8da7105c7c045 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkProvider.java @@ -14,6 +14,7 @@ package com.facebook.presto.split; import com.facebook.presto.Session; +import com.facebook.presto.metadata.DistributedProcedureHandle; import com.facebook.presto.metadata.InsertTableHandle; import com.facebook.presto.metadata.OutputTableHandle; import com.facebook.presto.spi.ConnectorPageSink; @@ -24,4 +25,6 @@ public interface PageSinkProvider ConnectorPageSink createPageSink(Session session, OutputTableHandle tableHandle, PageSinkContext pageSinkContext); ConnectorPageSink createPageSink(Session session, InsertTableHandle tableHandle, PageSinkContext pageSinkContext); + + ConnectorPageSink createPageSink(Session session, DistributedProcedureHandle procedureHandle, PageSinkContext pageSinkContext); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java index 3ecc50bfa7fb1..8d2f51bf6ab2d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AccessControlReferences; @@ -116,11 +117,20 @@ public Analysis analyze(Statement statement, boolean isDescribe) } public Analysis analyzeSemantic(Statement statement, boolean isDescribe) + { + return analyzeSemantic(statement, Optional.empty(), isDescribe); + } + + public Analysis analyzeSemantic( + Statement statement, + Optional procedureName, + boolean isDescribe) { Statement rewrittenStatement = StatementRewrite.rewrite(session, metadata, sqlParser, queryExplainer, statement, parameters, parameterLookup, accessControl, warningCollector, query); Analysis analysis = new Analysis(rewrittenStatement, parameterLookup, isDescribe); metadataExtractor.populateMetadataHandle(session, rewrittenStatement, analysis.getMetadataHandle()); + analysis.setProcedureName(procedureName); StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session, warningCollector); analyzer.analyze(rewrittenStatement, Optional.empty()); analyzeForUtilizedColumns(analysis, analysis.getStatement(), warningCollector); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalyzer.java index 9308927b49c6f..d5ab8c544c14c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalyzer.java @@ -92,7 +92,10 @@ public QueryAnalysis analyze(AnalyzerContext analyzerContext, PreparedQuery prep Optional.of(metadataExtractorExecutor), analyzerContext.getQuery()); - Analysis analysis = analyzer.analyzeSemantic(((BuiltInQueryPreparer.BuiltInPreparedQuery) preparedQuery).getStatement(), false); + Analysis analysis = analyzer.analyzeSemantic( + ((BuiltInQueryPreparer.BuiltInPreparedQuery) preparedQuery).getStatement(), + ((BuiltInQueryPreparer.BuiltInPreparedQuery) preparedQuery).getDistributedProcedureName(), + false); return new BuiltInQueryAnalysis(analysis); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 23dabe8ef8da0..35cb0c287c999 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -70,6 +70,8 @@ import com.facebook.presto.spi.function.table.TableArgument; import com.facebook.presto.spi.function.table.TableArgumentSpecification; import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessControl; @@ -239,6 +241,7 @@ import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.execution.CallTask.extractParameterValuesInOrder; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; @@ -295,6 +298,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISMATCHED_COLUMN_ALIASES; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISMATCHED_SET_COLUMN_TYPES; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ATTRIBUTE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_CATALOG; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_COLUMN; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_MATERIALIZED_VIEW; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_SCHEMA; @@ -305,6 +309,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NON_NUMERIC_SAMPLE_PERCENTAGE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.ORDER_BY_MUST_BE_IN_SELECT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.PROCEDURE_NOT_FOUND; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_ALREADY_EXISTS; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_AMBIGUOUS_RETURN_TYPE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_COLUMN_NOT_FOUND; @@ -404,7 +409,7 @@ public Scope analyze(Node node, Scope outerQueryScope) public Scope analyze(Node node, Optional outerQueryScope) { - return new Visitor(outerQueryScope, warningCollector).process(node, Optional.empty()); + return new Visitor(metadata, session, outerQueryScope, warningCollector).process(node, Optional.empty()); } /** @@ -415,11 +420,19 @@ public Scope analyze(Node node, Optional outerQueryScope) private class Visitor extends DefaultTraversalVisitor> { + private final Metadata metadata; + private final Session session; private final Optional outerQueryScope; private final WarningCollector warningCollector; - private Visitor(Optional outerQueryScope, WarningCollector warningCollector) + private Visitor( + Metadata metadata, + Session session, + Optional outerQueryScope, + WarningCollector warningCollector) { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.session = requireNonNull(session, "session is null"); this.outerQueryScope = requireNonNull(outerQueryScope, "outerQueryScope is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); } @@ -1201,9 +1214,60 @@ protected Scope visitRevoke(Revoke node, Optional scope) } @Override - protected Scope visitCall(Call node, Optional scope) + protected Scope visitCall(Call call, Optional scope) { - return createAndAssignScope(node, scope); + if (analysis.isDescribe()) { + return createAndAssignScope(call, scope); + } + Optional procedureNameOptional = analysis.getProcedureName(); + QualifiedObjectName procedureName; + if (!procedureNameOptional.isPresent()) { + procedureName = createQualifiedObjectName(session, call, call.getName(), metadata); + analysis.setProcedureName(Optional.of(procedureName)); + } + else { + procedureName = procedureNameOptional.get(); + } + ConnectorId connectorId = metadata.getCatalogHandle(session, procedureName.getCatalogName()) + .orElseThrow(() -> new SemanticException(MISSING_CATALOG, call, "Catalog %s does not exist", procedureName.getCatalogName())); + + if (!metadata.getProcedureRegistry().isDistributedProcedure(connectorId, toSchemaTableName(procedureName))) { + throw new SemanticException(PROCEDURE_NOT_FOUND, "Distributed procedure not registered: " + procedureName); + } + DistributedProcedure procedure = metadata.getProcedureRegistry().resolveDistributed(connectorId, toSchemaTableName(procedureName)); + Object[] values = extractParameterValuesInOrder(call, procedure, metadata, session, analysis.getParameters()); + + analysis.setUpdateType("CALL"); + analysis.setDistributedProcedureType(Optional.of(procedure.getType())); + analysis.setProcedureArguments(Optional.of(values)); + switch (procedure.getType()) { + case TABLE_DATA_REWRITE: + TableDataRewriteDistributedProcedure tableDataRewriteDistributedProcedure = (TableDataRewriteDistributedProcedure) procedure; + QualifiedName qualifiedName = QualifiedName.of(tableDataRewriteDistributedProcedure.getSchema(values), tableDataRewriteDistributedProcedure.getTableName(values)); + QualifiedObjectName tableName = createQualifiedObjectName(session, call, qualifiedName, metadata); + + String filter = tableDataRewriteDistributedProcedure.getFilter(values); + Expression filterExpression = sqlParser.createExpression(filter); + QuerySpecification querySpecification = new QuerySpecification( + selectList(new AllColumns()), + Optional.of(new Table(qualifiedName)), + Optional.of(filterExpression), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + analyze(querySpecification, scope); + analysis.setTargetQuery(querySpecification); + + TableHandle tableHandle = metadata.getHandleVersion(session, tableName, Optional.empty()) + .orElseThrow(() -> (new SemanticException(MISSING_TABLE, call, "Table '%s' does not exist", tableName))); + analysis.setCallTarget(tableHandle); + break; + default: + throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Unsupported distributed procedure type: " + procedure.getType()); + } + return createAndAssignScope(call, scope, Field.newUnqualified(Optional.empty(), "rows", BIGINT)); } private void validateProperties(List properties, Optional scope) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index bc77373a5ecbd..ee4c0d41b178d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -41,6 +41,7 @@ import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; @@ -263,6 +264,15 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) + { + if (node.getPartitioningScheme().isPresent()) { + context.get().setDistribution(node.getPartitioningScheme().get().getPartitioning().getHandle(), metadata, session); + } + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitValues(ValuesNode node, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java index 71d1199bf9da8..5c96e0c69177c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java @@ -27,7 +27,9 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; @@ -231,6 +233,22 @@ public GroupedExecutionTagger.GroupedExecutionProperties visitMarkDistinct(MarkD return GroupedExecutionTagger.GroupedExecutionProperties.notCapable(); } + @Override + public GroupedExecutionTagger.GroupedExecutionProperties visitCallDistributedProcedure(CallDistributedProcedureNode node, Void context) + { + GroupedExecutionTagger.GroupedExecutionProperties properties = node.getSource().accept(this, null); + boolean recoveryEligible = properties.isRecoveryEligible(); + CallDistributedProcedureTarget target = node.getTarget().orElseThrow(() -> new VerifyException("target is absent")); + recoveryEligible &= metadata.getConnectorCapabilities(session, target.getConnectorId()).contains(SUPPORTS_PAGE_SINK_COMMIT); + + return new GroupedExecutionTagger.GroupedExecutionProperties( + properties.isCurrentNodeCapable(), + properties.isSubTreeUseful(), + properties.getCapableTableScanNodes(), + properties.getTotalLifespans(), + recoveryEligible); + } + @Override public GroupedExecutionTagger.GroupedExecutionProperties visitTableWriter(TableWriterNode node, Void context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index b5a460f033370..063174a954eae 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -36,6 +36,7 @@ import com.facebook.presto.execution.scheduler.ExecutionWriterTarget; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.CreateHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.DeleteHandle; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.ExecuteProcedureHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.InsertHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.RefreshMaterializedViewHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.UpdateHandle; @@ -201,6 +202,7 @@ import com.facebook.presto.sql.gen.PageFunctionCompiler; import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -2672,6 +2674,46 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont return new PhysicalOperation(operator, outputMappings, context, probeSource); } + @Override + public PhysicalOperation visitCallDistributedProcedure(CallDistributedProcedureNode node, LocalExecutionPlanContext context) + { + // Set table writer count + if (node.getPartitioningScheme().isPresent()) { + context.setDriverInstanceCount(getTaskPartitionedWriterCount(session)); + } + else { + context.setDriverInstanceCount(getTaskWriterCount(session)); + } + + PhysicalOperation source = node.getSource().accept(this, context); + + ImmutableMap.Builder outputMapping = ImmutableMap.builder(); + outputMapping.put(node.getRowCountVariable(), ROW_COUNT_CHANNEL); + outputMapping.put(node.getFragmentVariable(), FRAGMENT_CHANNEL); + outputMapping.put(node.getTableCommitContextVariable(), CONTEXT_CHANNEL); + + List inputChannels = node.getColumns().stream() + .map(source::variableToChannel) + .collect(toImmutableList()); + List notNullChannelColumnNames = node.getColumns().stream() + .map(variable -> node.getNotNullColumnVariables().contains(variable) ? node.getColumnNames().get(source.variableToChannel(variable)) : null) + .collect(Collectors.toList()); + + OperatorFactory operatorFactory = new TableWriterOperatorFactory( + context.getNextOperatorId(), + node.getId(), + pageSinkManager, + context.getTableWriteInfo().getWriterTarget().orElseThrow(() -> new VerifyException("writerTarget is absent")), + inputChannels, + notNullChannelColumnNames, + session, + new DevNullOperatorFactory(context.getNextOperatorId(), node.getId()), // statistics are not calculated + getVariableTypes(node.getOutputVariables()), + tableCommitContextCodec, + getPageSinkCommitStrategy()); + return new PhysicalOperation(operatorFactory, outputMapping.build(), context, source); + } + @Override public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPlanContext context) { @@ -3494,6 +3536,10 @@ else if (target instanceof UpdateHandle) { metadata.finishUpdate(session, ((UpdateHandle) target).getHandle(), fragments); return Optional.empty(); } + else if (target instanceof ExecuteProcedureHandle) { + metadata.finishCallDistributedProcedure(session, ((ExecuteProcedureHandle) target).getHandle(), ((ExecuteProcedureHandle) target).getProcedureName(), fragments); + return Optional.empty(); + } else { throw new AssertionError("Unhandled target type: " + target.getClass().getName()); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index bbc8aaca7f48a..414d2d9b47bd9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -18,6 +18,8 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableLayout; +import com.facebook.presto.metadata.TableLayout.TablePartitioning; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; @@ -33,6 +35,7 @@ import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; +import com.facebook.presto.spi.plan.PartitioningHandle; import com.facebook.presto.spi.plan.PartitioningScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; @@ -41,6 +44,7 @@ import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.spi.plan.TableWriterNode.DeleteHandle; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.RowExpression; @@ -53,10 +57,12 @@ import com.facebook.presto.sql.analyzer.Scope; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.StatisticsAggregationPlanner.TableStatisticAggregation; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.facebook.presto.sql.tree.Analyze; +import com.facebook.presto.sql.tree.Call; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CreateTableAsSelect; import com.facebook.presto.sql.tree.Delete; @@ -70,6 +76,7 @@ import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.Query; +import com.facebook.presto.sql.tree.QuerySpecification; import com.facebook.presto.sql.tree.RefreshMaterializedView; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.Update; @@ -85,12 +92,14 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; import static com.facebook.presto.spi.PartitionedTableWritePolicy.MULTIPLE_WRITERS_PER_PARTITION_ALLOWED; +import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.LimitNode.Step.FINAL; @@ -171,6 +180,15 @@ private RelationPlan planStatementWithoutOutput(Analysis analysis, Statement sta else if (statement instanceof Analyze) { return createAnalyzePlan(analysis, (Analyze) statement); } + else if (statement instanceof Call) { + checkState(analysis.getDistributedProcedureType().isPresent(), "Call distributed procedure analysis is missing"); + switch (analysis.getDistributedProcedureType().get()) { + case TABLE_DATA_REWRITE: + return createCallDistributedProcedurePlanForTableDataRewrite(analysis, (Call) statement); + default: + throw new PrestoException(NOT_SUPPORTED, "Unsupported distributed procedure type: " + analysis.getDistributedProcedureType().get()); + } + } else if (statement instanceof Insert) { checkState(analysis.getInsert().isPresent(), "Insert handle is missing"); return createInsertPlan(analysis, (Insert) statement); @@ -211,6 +229,82 @@ private RelationPlan createExplainAnalyzePlan(Analysis analysis, Explain stateme return new RelationPlan(root, scope, ImmutableList.of(outputVariable)); } + private RelationPlan createCallDistributedProcedurePlanForTableDataRewrite(Analysis analysis, Call statement) + { + TableHandle targetTable = analysis.getCallTarget() + .orElseThrow(() -> new PrestoException(NOT_FOUND, "Target table does not exist")); + Optional procedureName = analysis.getProcedureName(); + Optional procedureArguments = analysis.getProcedureArguments(); + + QuerySpecification querySpecification = analysis.getTargetQuery() + .orElseThrow(() -> new PrestoException(NOT_FOUND, "The query for target table does not exist")); + RelationPlan plan = createRelationPlan(analysis, querySpecification, new SqlPlannerContext(0)); + + ConnectorTableMetadata tableMetadata = metadata.getTableMetadata(session, targetTable).getMetadata(); + List columnNames = tableMetadata.getColumns().stream() + .filter(column -> !column.isHidden()) + .map(ColumnMetadata::getName) + .collect(toImmutableList()); + + Map columnHandleMap = metadata.getColumnHandles(session, targetTable); + TableLayout tableLayout = metadata.getLayout(session, targetTable); + List columnHandles = columnNames.stream().map(columnHandleMap::get).collect(Collectors.toList()); + List outputLayout = plan.getRoot().getOutputVariables(); + + Optional partitioningScheme = Optional.empty(); + Optional partitioningHandle = tableLayout.getTablePartitioning().map(TablePartitioning::getPartitioningHandle); + if (partitioningHandle.isPresent()) { + List partitionFunctionArguments = new ArrayList<>(); + tableLayout.getTablePartitioning().get().getPartitioningColumns().stream() + .mapToInt(columnHandles::indexOf) + .mapToObj(outputLayout::get) + .forEach(partitionFunctionArguments::add); + partitioningScheme = Optional.of(new PartitioningScheme( + Partitioning.create(partitioningHandle.get(), partitionFunctionArguments), + outputLayout)); + } + + verify(columnNames.size() == outputLayout.size(), "columnNames.size() != outputLayout.size(): %s and %s", columnNames, outputLayout); + List variables = plan.getFieldMappings(); + verify(columnNames.size() == variables.size(), "columnNames.size() != variables.size(): %s and %s", columnNames, variables); + Map columnToVariableMap = zip(columnNames.stream(), plan.getFieldMappings().stream(), SimpleImmutableEntry::new) + .collect(toImmutableMap(Entry::getKey, Entry::getValue)); + + Set notNullColumnVariables = tableMetadata.getColumns().stream() + .filter(column -> !column.isNullable()) + .map(ColumnMetadata::getName) + .map(columnToVariableMap::get) + .collect(toImmutableSet()); + + CallDistributedProcedureTarget callDistributedProcedureTarget = new CallDistributedProcedureTarget( + procedureName.get(), + procedureArguments.get(), + Optional.of(targetTable), + tableMetadata.getTable()); + TableFinishNode commitNode = new TableFinishNode( + Optional.empty(), + idAllocator.getNextId(), + new CallDistributedProcedureNode( + Optional.empty(), + idAllocator.getNextId(), + Optional.empty(), + plan.getRoot(), + Optional.of(callDistributedProcedureTarget), + variableAllocator.newVariable("rows", BIGINT), + variableAllocator.newVariable("fragment", VARBINARY), + variableAllocator.newVariable("commitcontext", VARBINARY), + plan.getRoot().getOutputVariables(), + columnNames, + notNullColumnVariables, + partitioningScheme), + Optional.of(callDistributedProcedureTarget), + variableAllocator.newVariable("rows", BIGINT), + Optional.empty(), + Optional.empty(), + Optional.empty()); + return new RelationPlan(commitNode, analysis.getScope(statement), commitNode.getOutputVariables()); + } + private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStatement) { TableHandle targetTable = analysis.getAnalyzeTarget().get(); @@ -564,6 +658,12 @@ private RelationPlan createRelationPlan(Analysis analysis, Query query, SqlPlann .process(query, context); } + private RelationPlan createRelationPlan(Analysis analysis, QuerySpecification query, SqlPlannerContext context) + { + return new RelationPlanner(analysis, variableAllocator, idAllocator, buildLambdaDeclarationToVariableMap(analysis, variableAllocator), metadata, session, sqlParser) + .process(query, context); + } + private ConnectorTableMetadata createTableMetadata(QualifiedObjectName table, List columns, Map propertyExpressions, Map, Expression> parameters, Optional comment) { Map properties = metadata.getTablePropertyManager().getProperties( diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java index 1c6f151d70efc..4e7ee0bed5865 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.google.common.base.VerifyException; @@ -82,6 +83,17 @@ public Void visitTableFinish(TableFinishNode node, Void context) return super.visitTableFinish(node, context); } + @Override + public Void visitCallDistributedProcedure(CallDistributedProcedureNode node, Void context) + { + TableWriterNode.WriterTarget writerTarget = node.getTarget().orElseThrow(() -> new VerifyException("target is absent")); + connectorId = writerTarget.getConnectorId(); + checkState(schemaTableName == null || schemaTableName.equals(writerTarget.getSchemaTableName()), + "cannot have more than a single create, insert or delete in a query"); + schemaTableName = writerTarget.getSchemaTableName(); + return null; + } + public Void visitSequence(SequenceNode node, Void context) { // Left children of sequence are ignored since they don't output anything diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 2d7c8be053645..538ff31e038b8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -186,6 +186,7 @@ import com.facebook.presto.sql.planner.optimizations.ReplaceConstantVariableReferencesWithConstants; import com.facebook.presto.sql.planner.optimizations.ReplicateSemiJoinInDelete; import com.facebook.presto.sql.planner.optimizations.RewriteIfOverAggregation; +import com.facebook.presto.sql.planner.optimizations.RewriteWriterTarget; import com.facebook.presto.sql.planner.optimizations.SetFlatteningOptimizer; import com.facebook.presto.sql.planner.optimizations.ShardJoins; import com.facebook.presto.sql.planner.optimizations.SimplifyPlanWithEmptyInput; @@ -1023,6 +1024,8 @@ public PlanOptimizers( featuresConfig.isPrestoSparkExecutionEnvironment())))); builder.add(new MetadataDeleteOptimizer(metadata)); + builder.add(new RewriteWriterTarget()); + // TODO: consider adding a formal final plan sanitization optimizer that prepares the plan for transmission/execution/logging // TODO: figure out how to improve the set flattening optimizer so that it can run at any point this.planningTimeOptimizers = builder.build(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java index 03522fabd5303..1feda17089865 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java @@ -49,6 +49,7 @@ import com.facebook.presto.split.SplitSource; import com.facebook.presto.split.SplitSourceProvider; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -346,6 +347,12 @@ public Map visitTableWriter(TableWriterNode node, Conte return node.getSource().accept(this, context); } + @Override + public Map visitCallDistributedProcedure(CallDistributedProcedureNode node, Context context) + { + return node.getSource().accept(this, context); + } + @Override public Map visitTableWriteMerge(TableWriterMergeNode node, Context context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 61b53b70905f0..5cc78f07af09d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -66,6 +66,7 @@ import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.PreferredProperties.PartitioningProperties; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ChildReplacer; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -749,12 +750,32 @@ public PlanWithProperties visitMetadataDelete(MetadataDeleteNode node, Preferred } @Override - public PlanWithProperties visitTableWriter(TableWriterNode node, PreferredProperties preferredProperties) + public PlanWithProperties visitCallDistributedProcedure(CallDistributedProcedureNode node, PreferredProperties preferredProperties) { - PlanWithProperties source = accept(node.getSource(), preferredProperties); + Optional partitioningScheme = node.getPartitioningScheme(); + boolean isSingleWriterPerPartitionRequired = partitioningScheme.isPresent(); + return getTableWriterPlanWithProperties(node, preferredProperties, partitioningScheme, partitioningScheme, isSingleWriterPerPartitionRequired); + } + @Override + public PlanWithProperties visitTableWriter(TableWriterNode node, PreferredProperties preferredProperties) + { Optional shufflePartitioningScheme = node.getTablePartitioningScheme(); - if (!node.isSingleWriterPerPartitionRequired()) { + return getTableWriterPlanWithProperties(node, preferredProperties, node.getTablePartitioningScheme(), + shufflePartitioningScheme, node.isSingleWriterPerPartitionRequired()); + } + + private PlanWithProperties getTableWriterPlanWithProperties( + PlanNode node, + PreferredProperties preferredProperties, + Optional tablePartitioningScheme, + Optional shufflePartitioningScheme, + boolean isSingleWriterPerPartitionRequired) + { + checkArgument(node instanceof TableWriterNode || node instanceof CallDistributedProcedureNode); + PlanWithProperties source = accept(node.getSources().get(0), preferredProperties); + + if (!isSingleWriterPerPartitionRequired) { // prefer scale writers if single writer per partition is not required // TODO: take into account partitioning scheme in scale writer tasks implementation if (scaleWriters) { @@ -774,9 +795,9 @@ else if (redistributeWrites) { !(source.getProperties().isRefinedPartitioningOver(shufflePartitioningScheme.get().getPartitioning(), false, metadata, session) && canPushdownPartialMerge(source.getNode(), partialMergePushdownStrategy))) { PartitioningScheme exchangePartitioningScheme = shufflePartitioningScheme.get(); - if (node.getTablePartitioningScheme().isPresent() && isPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(session)) { + if (tablePartitioningScheme.isPresent() && isPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(session)) { int writerThreadsPerNode = getTaskPartitionedWriterCount(session); - int bucketCount = getBucketCount(node.getTablePartitioningScheme().get().getPartitioning().getHandle()); + int bucketCount = getBucketCount(tablePartitioningScheme.get().getPartitioning().getHandle()); int[] bucketToPartition = new int[bucketCount]; for (int i = 0; i < bucketCount; i++) { bucketToPartition[i] = i / writerThreadsPerNode; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index c76c6b6252771..769315405ca41 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -50,6 +50,7 @@ import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -597,6 +598,20 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, StreamPrefe return planAndEnforceChildren(node, requiredProperties, requiredProperties); } + @Override + public PlanWithProperties visitCallDistributedProcedure(CallDistributedProcedureNode node, StreamPreferredProperties parentPreferences) + { + if (node.getPartitioningScheme().isPresent() && getTaskPartitionedWriterCount(session) == 1) { + return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); + } + + if (!node.getPartitioningScheme().isPresent() && getTaskWriterCount(session) == 1) { + return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); + } + + return planAndEnforceChildren(node, fixedParallelism(), fixedParallelism()); + } + // // Table Writer // diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index b8b0089153b7b..21a0699b03c65 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -59,6 +59,7 @@ import com.facebook.presto.sql.planner.optimizations.ActualProperties.Global; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -798,6 +799,21 @@ public ActualProperties visitTableWriter(TableWriterNode node, List inputProperties) + { + ActualProperties properties = Iterables.getOnlyElement(inputProperties); + + if (properties.isCoordinatorOnly()) { + return ActualProperties.builder() + .global(coordinatorSingleStreamPartition()) + .build(); + } + return ActualProperties.builder() + .global(properties.isSingleNode() ? singleStreamPartition() : arbitraryPartition()) + .build(); + } + @Override public ActualProperties visitTableWriteMerge(TableWriterMergeNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index d9fd049555be3..c74a76dcf41f4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -57,6 +57,7 @@ import com.facebook.presto.sql.planner.VariablesExtractor; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; @@ -807,6 +808,25 @@ public PlanNode visitTableWriteMerge(TableWriterMergeNode node, RewriteContext> context) + { + PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputVariables())); + return new CallDistributedProcedureNode( + node.getSourceLocation(), + node.getId(), + node.getStatsEquivalentPlanNode(), + source, + node.getTarget(), + node.getRowCountVariable(), + node.getFragmentVariable(), + node.getTableCommitContextVariable(), + node.getColumns(), + node.getColumnNames(), + node.getNotNullColumnVariables(), + node.getPartitioningScheme()); + } + @Override public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext> context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java index 748e97998c968..65f33486c9a15 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java @@ -69,6 +69,7 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; @@ -467,6 +468,13 @@ public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext c node.getCurrentConstraint()); } + @Override + public PlanNode visitCallDistributedProcedure(CallDistributedProcedureNode node, RewriteContext context) + { + context.get().variables.addAll(node.getColumns()); + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RewriteWriterTarget.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RewriteWriterTarget.java new file mode 100644 index 0000000000000..1a9103d54aef7 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RewriteWriterTarget.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.TableFinishNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; +import com.facebook.presto.spi.plan.TableWriterNode.WriterTarget; +import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter.RewriteContext; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.stream.Collectors.toSet; + +public class RewriteWriterTarget + implements PlanOptimizer +{ + public RewriteWriterTarget() + {} + + @Override + public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + Rewriter rewriter = new Rewriter(); + PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, Optional.empty()); + return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); + } + + private class Rewriter + extends SimplePlanRewriter> + { + private boolean planChanged; + + public Rewriter() + {} + + @Override + public PlanNode visitCallDistributedProcedure(CallDistributedProcedureNode node, RewriteContext> context) + { + CallDistributedProcedureTarget callDistributedProcedureTarget = (CallDistributedProcedureTarget) getContextTarget(context); + return new CallDistributedProcedureNode( + node.getSourceLocation(), + node.getId(), + node.getSource(), + Optional.of(callDistributedProcedureTarget), + node.getRowCountVariable(), + node.getFragmentVariable(), + node.getTableCommitContextVariable(), + node.getColumns(), + node.getColumnNames(), + node.getNotNullColumnVariables(), + node.getPartitioningScheme()); + } + + @Override + public PlanNode visitTableFinish(TableFinishNode node, RewriteContext> context) + { + PlanNode child = node.getSource(); + + Optional newTarget = getWriterTarget(child); + if (!newTarget.isPresent()) { + return node; + } + + planChanged = true; + child = context.rewrite(child, newTarget); + + return new TableFinishNode( + node.getSourceLocation(), + node.getId(), + child, + newTarget, + node.getRowCountVariable(), + node.getStatisticsAggregation(), + node.getStatisticsAggregationDescriptor(), + Optional.empty()); + } + + public Optional getWriterTarget(PlanNode node) + { + if (node instanceof CallDistributedProcedureNode) { + Optional tableHandle = findTableHandleForCallDistributedProcedure(((CallDistributedProcedureNode) node).getSource()); + Optional callDistributedProcedureTarget = ((CallDistributedProcedureNode) node).getTarget(); + return !tableHandle.isPresent() ? callDistributedProcedureTarget.map(WriterTarget.class::cast) : + callDistributedProcedureTarget.map(target -> new CallDistributedProcedureTarget( + target.getProcedureName(), + target.getProcedureArguments(), + tableHandle, + target.getSchemaTableName())); + } + + if (node instanceof ExchangeNode || node instanceof UnionNode) { + Set> writerTargets = node.getSources().stream() + .map(this::getWriterTarget) + .collect(toSet()); + return getOnlyElement(writerTargets); + } + + return Optional.empty(); + } + + private Optional findTableHandleForCallDistributedProcedure(PlanNode startNode) + { + List tableScanNodes = PlanNodeSearcher.searchFrom(startNode) + .where(node -> node instanceof TableScanNode) + .findAll(); + + if (tableScanNodes.size() == 1) { + return Optional.of(((TableScanNode) tableScanNodes.get(0)).getTable()); + } + + List valuesNodes = PlanNodeSearcher.searchFrom(startNode) + .where(node -> node instanceof ValuesNode) + .findAll(); + + if (valuesNodes.size() == 1) { + return Optional.empty(); + } + + throw new IllegalArgumentException("Expected to find exactly one update target TableScanNode in plan but found: " + tableScanNodes); + } + + public boolean isPlanChanged() + { + return planChanged; + } + } + + private static WriterTarget getContextTarget(RewriteContext> context) + { + return context.get().orElseThrow(() -> new IllegalStateException("WriterTarget not present")); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index a56b3c773d6ed..2515e42fbee58 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -49,6 +49,7 @@ import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -467,6 +468,14 @@ public StreamProperties visitDelete(DeleteNode node, List inpu return properties.withUnspecifiedPartitioning(); } + @Override + public StreamProperties visitCallDistributedProcedure(CallDistributedProcedureNode node, List inputProperties) + { + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + // call distributed procedure only outputs the row count + return properties.withUnspecifiedPartitioning(); + } + @Override public StreamProperties visitTableWriter(TableWriterNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 9805efad17939..cb66746b0c633 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -36,6 +36,7 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.tree.Expression; @@ -262,6 +263,29 @@ public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId new node.getIsTemporaryTableWriter()); } + public CallDistributedProcedureNode map(CallDistributedProcedureNode node, PlanNode source) + { + ImmutableList columns = node.getColumns().stream() + .map(this::map) + .collect(toImmutableList()); + Set notNullColumnVariables = node.getNotNullColumnVariables().stream() + .map(this::map) + .collect(toImmutableSet()); + + return new CallDistributedProcedureNode( + node.getSourceLocation(), + node.getId(), + source, + node.getTarget(), + node.getRowCountVariable(), + node.getFragmentVariable(), + node.getTableCommitContextVariable(), + columns, + columns.stream().map(VariableReferenceExpression::getName).collect(toImmutableList()), + notNullColumnVariables, + node.getPartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source))); + } + public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source) { return new StatisticsWriterNode( diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 91a92107f9f3c..201bd54c8c937 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -62,6 +62,7 @@ import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -721,6 +722,14 @@ private static ImmutableList.Builder rewriteSources(SetOperationNode n return rewrittenSources; } + @Override + public PlanNode visitCallDistributedProcedure(CallDistributedProcedureNode node, RewriteContext context) + { + PlanNode source = context.rewrite(node.getSource()); + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + return mapper.map(node, source); + } + @Override public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/CallDistributedProcedureNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/CallDistributedProcedureNode.java new file mode 100644 index 0000000000000..b07ec4ca4c2e6 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/CallDistributedProcedureNode.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.plan.PartitioningScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class CallDistributedProcedureNode + extends InternalPlanNode +{ + private final PlanNode source; + private final Optional target; + private final VariableReferenceExpression rowCountVariable; + private final VariableReferenceExpression fragmentVariable; + private final VariableReferenceExpression tableCommitContextVariable; + private final List columns; + private final List columnNames; + private final Set notNullColumnVariables; + private final Optional partitioningScheme; + private final List outputs; + + @JsonCreator + public CallDistributedProcedureNode( + Optional sourceLocation, + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("target") Optional target, + @JsonProperty("rowCountVariable") VariableReferenceExpression rowCountVariable, + @JsonProperty("fragmentVariable") VariableReferenceExpression fragmentVariable, + @JsonProperty("tableCommitContextVariable") VariableReferenceExpression tableCommitContextVariable, + @JsonProperty("columns") List columns, + @JsonProperty("columnNames") List columnNames, + @JsonProperty("notNullColumnVariables") Set notNullColumnVariables, + @JsonProperty("partitioningScheme") Optional partitioningScheme) + { + this(sourceLocation, id, Optional.empty(), source, target, rowCountVariable, fragmentVariable, tableCommitContextVariable, columns, columnNames, notNullColumnVariables, partitioningScheme); + } + + public CallDistributedProcedureNode( + Optional sourceLocation, + PlanNodeId id, + Optional statsEquivalentPlanNode, + PlanNode source, + Optional target, + VariableReferenceExpression rowCountVariable, + VariableReferenceExpression fragmentVariable, + VariableReferenceExpression tableCommitContextVariable, + List columns, + List columnNames, + Set notNullColumnVariables, + Optional partitioningScheme) + { + super(sourceLocation, id, statsEquivalentPlanNode); + + requireNonNull(columns, "columns is null"); + requireNonNull(columnNames, "columnNames is null"); + checkArgument(columns.size() == columnNames.size(), "columns and columnNames sizes don't match"); + + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + this.rowCountVariable = requireNonNull(rowCountVariable, "rowCountVariable is null"); + this.fragmentVariable = requireNonNull(fragmentVariable, "fragmentVariable is null"); + this.tableCommitContextVariable = requireNonNull(tableCommitContextVariable, "tableCommitContextVariable is null"); + this.columns = ImmutableList.copyOf(columns); + this.columnNames = ImmutableList.copyOf(columnNames); + this.notNullColumnVariables = ImmutableSet.copyOf(requireNonNull(notNullColumnVariables, "notNullColumns is null")); + this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null"); + + ImmutableList.Builder outputs = ImmutableList.builder() + .add(rowCountVariable) + .add(fragmentVariable) + .add(tableCommitContextVariable); + this.outputs = outputs.build(); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @JsonIgnore + public Optional getTarget() + { + return target; + } + + @JsonProperty + public VariableReferenceExpression getRowCountVariable() + { + return rowCountVariable; + } + + @JsonProperty + public VariableReferenceExpression getFragmentVariable() + { + return fragmentVariable; + } + + @JsonProperty + public VariableReferenceExpression getTableCommitContextVariable() + { + return tableCommitContextVariable; + } + + @JsonProperty + public Optional getPartitioningScheme() + { + return partitioningScheme; + } + + @JsonProperty + public List getColumns() + { + return columns; + } + + @JsonProperty + public List getColumnNames() + { + return columnNames; + } + + @JsonProperty + public Set getNotNullColumnVariables() + { + return notNullColumnVariables; + } + + @JsonProperty + public List getOutputs() + { + return outputs; + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @Override + public List getOutputVariables() + { + return outputs; + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitCallDistributedProcedure(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new CallDistributedProcedureNode( + this.getSourceLocation(), + getId(), + this.getStatsEquivalentPlanNode(), + Iterables.getOnlyElement(newChildren), + target, + rowCountVariable, + fragmentVariable, + tableCommitContextVariable, + columns, + columnNames, + notNullColumnVariables, + partitioningScheme); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new CallDistributedProcedureNode( + this.getSourceLocation(), + getId(), + statsEquivalentPlanNode, + source, + target, + rowCountVariable, + fragmentVariable, + tableCommitContextVariable, + columns, + columnNames, + notNullColumnVariables, + partitioningScheme); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java index b33dfc48938d7..4add380c33bc2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java @@ -57,6 +57,11 @@ public R visitStatisticsWriterNode(StatisticsWriterNode node, C context) return visitPlan(node, context); } + public R visitCallDistributedProcedure(CallDistributedProcedureNode node, C context) + { + return visitPlan(node, context); + } + public R visitGroupId(GroupIdNode node, C context) { return visitPlan(node, context); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 2b7059d12e02a..ed06aa5170dbd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -70,6 +70,7 @@ import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; import com.facebook.presto.spi.plan.UnnestNode; @@ -86,6 +87,7 @@ import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -1199,6 +1201,13 @@ public Void visitStatisticsWriterNode(StatisticsWriterNode node, Void context) return processChildren(node, context); } + @Override + public Void visitCallDistributedProcedure(CallDistributedProcedureNode node, Void context) + { + addNode(node, "CallDistributedProcedure", format("[%s]", node.getTarget().map(CallDistributedProcedureTarget::getProcedureName).orElse(null))); + return processChildren(node, context); + } + @Override public Void visitTableFinish(TableFinishNode node, Void context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CallDistributedProcedureValidator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CallDistributedProcedureValidator.java new file mode 100644 index 0000000000000..084cb6e7048b7 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CallDistributedProcedureValidator.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.sanity; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.OutputNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TableFinishNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; +import com.facebook.presto.sql.planner.plan.ExchangeNode; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; + +public final class CallDistributedProcedureValidator + implements PlanChecker.Checker +{ + @Override + public void validate(PlanNode planNode, Session session, Metadata metadata, WarningCollector warningCollector) + { + Optional callDistributedProcedureNode = searchFrom(planNode) + .where(node -> node instanceof CallDistributedProcedureNode) + .findFirst(); + + if (!callDistributedProcedureNode.isPresent()) { + // not a call distributed procedure plan + return; + } + + searchFrom(planNode) + .findAll() + .forEach(node -> { + if (!isAllowedNode(node)) { + throw new IllegalStateException("Unexpected " + node.getClass().getSimpleName() + " found in plan; probably connector was not able to handle provided WHERE expression"); + } + }); + } + + private boolean isAllowedNode(PlanNode node) + { + return node instanceof TableScanNode + || node instanceof ValuesNode + || node instanceof ProjectNode + || node instanceof CallDistributedProcedureNode + || node instanceof OutputNode + || node instanceof ExchangeNode + || node instanceof TableFinishNode; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java index d0a85b8632e35..0f38868fc5c4d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java @@ -73,7 +73,8 @@ public PlanChecker(FeaturesConfig featuresConfig, boolean noExchange, PlanChecke new VerifyNoIntermediateFormExpression(), new VerifyProjectionLocality(), new DynamicFiltersChecker(), - new WarnOnScanWithoutPartitionPredicate(featuresConfig)); + new WarnOnScanWithoutPartitionPredicate(featuresConfig), + new CallDistributedProcedureValidator()); if (featuresConfig.isNativeExecutionEnabled()) { if (featuresConfig.isDisableTimeStampWithTimeZoneForNative() || featuresConfig.isDisableIPAddressForNative()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index 7bb6f516b9171..b8706219983fb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -57,6 +57,7 @@ import com.facebook.presto.sql.planner.optimizations.WindowNodeUtil; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -595,6 +596,15 @@ public Void visitExchange(ExchangeNode node, Set bo return null; } + @Override + public Void visitCallDistributedProcedure(CallDistributedProcedureNode node, Set boundVariables) + { + PlanNode source = node.getSource(); + source.accept(this, boundVariables); // visit child + + return null; + } + @Override public Void visitTableWriter(TableWriterNode node, Set boundVariables) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java index 1b5ee3c1a47dd..279d41376e999 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java @@ -17,6 +17,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AnalyzerOptions; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.analyzer.BuiltInQueryPreparer; import com.facebook.presto.sql.analyzer.BuiltInQueryPreparer.BuiltInPreparedQuery; @@ -63,7 +64,8 @@ public Statement rewrite( WarningCollector warningCollector, String query) { - return (Statement) new Visitor(session, parser, queryExplainer, warningCollector, query).process(node, null); + return (Statement) new Visitor(session, parser, queryExplainer, metadata.getProcedureRegistry(), warningCollector, query) + .process(node, null); } private static final class Visitor @@ -79,11 +81,12 @@ public Visitor( Session session, SqlParser parser, Optional queryExplainer, + ProcedureRegistry procedureRegistry, WarningCollector warningCollector, String query) { this.session = requireNonNull(session, "session is null"); - this.queryPreparer = new BuiltInQueryPreparer(requireNonNull(parser, "queryPreparer is null")); + this.queryPreparer = new BuiltInQueryPreparer(requireNonNull(parser, "queryPreparer is null"), procedureRegistry); this.queryExplainer = requireNonNull(queryExplainer, "queryExplainer is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.query = requireNonNull(query, "query is null"); diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index f7fd052f02c3a..ae88a073dfe4b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -98,6 +98,7 @@ import com.facebook.presto.memory.MemoryManagerConfig; import com.facebook.presto.memory.NodeMemoryConfig; import com.facebook.presto.metadata.AnalyzePropertyManager; +import com.facebook.presto.metadata.BuiltInProcedureRegistry; import com.facebook.presto.metadata.CatalogManager; import com.facebook.presto.metadata.ColumnPropertyManager; import com.facebook.presto.metadata.FunctionAndTypeManager; @@ -151,6 +152,7 @@ import com.facebook.presto.spi.plan.SimplePlanFragment; import com.facebook.presto.spi.plan.StageExecutionDescriptor; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; import com.facebook.presto.spiller.GenericPartitioningSpillerFactory; import com.facebook.presto.spiller.GenericSpillerFactory; @@ -317,6 +319,7 @@ public class LocalQueryRunner private final PageSorter pageSorter; private final PageIndexerFactory pageIndexerFactory; private final MetadataManager metadata; + private final ProcedureRegistry procedureRegistry; private final ScalarStatsCalculator scalarStatsCalculator; private final StatsNormalizer statsNormalizer; private final FilterStatsCalculator filterStatsCalculator; @@ -433,8 +436,10 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, this.blockEncodingManager = new BlockEncodingManager(); featuresConfig.setIgnoreStatsCalculatorFailures(false); + FunctionAndTypeManager functionAndTypeManager = new FunctionAndTypeManager(transactionManager, new TableFunctionRegistry(), blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()); + this.procedureRegistry = new BuiltInProcedureRegistry(functionAndTypeManager); this.metadata = new MetadataManager( - new FunctionAndTypeManager(transactionManager, new TableFunctionRegistry(), blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()), + functionAndTypeManager, blockEncodingManager, createTestingSessionPropertyManager( new SystemSessionProperties( @@ -456,7 +461,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new TablePropertyManager(), new ColumnPropertyManager(), new AnalyzePropertyManager(), - transactionManager); + transactionManager, + procedureRegistry); this.splitManager = new SplitManager(metadata, new QueryManagerConfig(), nodeSchedulerConfig); this.planCheckerProviderManager = new PlanCheckerProviderManager(new JsonCodecSimplePlanFragmentSerde(jsonCodec(SimplePlanFragment.class)), new PlanCheckerProviderManagerConfig()); this.distributedPlanChecker = new PlanChecker(featuresConfig, false, planCheckerProviderManager); @@ -500,6 +506,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, nodeManager, nodeInfo, metadata.getFunctionAndTypeManager(), + procedureRegistry, pageSorter, pageIndexerFactory, transactionManager, @@ -524,7 +531,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, BuiltInQueryAnalyzer queryAnalyzer = new BuiltInQueryAnalyzer(metadata, sqlParser, accessControl, Optional.empty(), metadataExtractorExecutor); BuiltInAnalyzerProvider analyzerProvider = new BuiltInAnalyzerProvider(queryAnalyzer); - BuiltInQueryPreparer queryPreparer = new BuiltInQueryPreparer(sqlParser); + BuiltInQueryPreparer queryPreparer = new BuiltInQueryPreparer(sqlParser, procedureRegistry); BuiltInQueryPreparerProvider queryPreparerProvider = new BuiltInQueryPreparerProvider(queryPreparer); this.pluginManager = new PluginManager( @@ -913,7 +920,7 @@ private MaterializedResultWithPlan executeInternal(Session session, @Language("S private MaterializedResultWithPlan executeExplainTypeValidate(String sql, Session session, WarningCollector warningCollector) { AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, warningCollector); - BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); + BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser, procedureRegistry).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); assertFormattedSql(sqlParser, createParsingOptions(session), preparedQuery.getStatement()); PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); @@ -945,7 +952,7 @@ private MaterializedResultWithPlan executeExplainTypeValidate(String sql, Sessio private boolean isExplainTypeValidate(String sql, Session session, WarningCollector warningCollector) { AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, warningCollector); - PreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); + PreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser, procedureRegistry).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); return preparedQuery.isExplainTypeValidate(); } @@ -1125,7 +1132,7 @@ public Plan createPlan(Session session, @Language("SQL") String sql, Optimizer.P public Plan createPlan(Session session, @Language("SQL") String sql, Optimizer.PlanStage stage, boolean noExchange, boolean nativeExecutionEnabled, WarningCollector warningCollector) { AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, warningCollector); - BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); + BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser, procedureRegistry).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); assertFormattedSql(sqlParser, createParsingOptions(session), preparedQuery.getStatement()); return createPlan(session, sql, getPlanOptimizers(noExchange, nativeExecutionEnabled), stage, warningCollector); @@ -1179,7 +1186,7 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List optimizers, Optimizer.PlanStage stage, WarningCollector warningCollector) { AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, warningCollector); - BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); + BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser, procedureRegistry).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); assertFormattedSql(sqlParser, createParsingOptions(session), preparedQuery.getStatement()); PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestProcedureRegistry.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestProcedureRegistry.java new file mode 100644 index 0000000000000..5c0a82f4ee75a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestProcedureRegistry.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.testing; + +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.connector.ConnectorProcedureContext; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; + +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.facebook.presto.spi.StandardErrorCode.PROCEDURE_NOT_FOUND; +import static java.util.Objects.requireNonNull; + +public class TestProcedureRegistry + implements ProcedureRegistry +{ + private final Map> connectorProcedures = new ConcurrentHashMap<>(); + + @Override + public void addProcedures(ConnectorId connectorId, Collection procedures) + { + requireNonNull(connectorId, "connectorId is null"); + requireNonNull(procedures, "procedures is null"); + + Map proceduresByName = procedures.stream().collect(Collectors.toMap( + procedure -> new SchemaTableName(procedure.getSchema(), procedure.getName()), + Function.identity())); + if (connectorProcedures.putIfAbsent(connectorId, proceduresByName) != null) { + throw new IllegalStateException("Procedures already registered for connector: " + connectorId); + } + } + + @Override + public void removeProcedures(ConnectorId connectorId) + { + connectorProcedures.remove(connectorId); + } + + @Override + public Procedure resolve(ConnectorId connectorId, SchemaTableName name) + { + Map procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + Procedure procedure = procedures.get(name); + if (procedure != null) { + return procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Procedure not registered: " + name); + } + + @Override + public DistributedProcedure resolveDistributed(ConnectorId connectorId, SchemaTableName name) + { + Map procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + Procedure procedure = procedures.get(name); + if (procedure != null && procedure instanceof DistributedProcedure) { + return (DistributedProcedure) procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Distributed procedure not registered: " + name); + } + + @Override + public boolean isDistributedProcedure(ConnectorId connectorId, SchemaTableName name) + { + Map procedures = connectorProcedures.get(connectorId); + return procedures != null && + procedures.containsKey(name) && + procedures.get(name) instanceof DistributedProcedure; + } + + public static class TestProcedureContext + implements ConnectorProcedureContext + {} +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandle.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandle.java index 34de904e9068e..40294ba5dcb0f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandle.java @@ -13,12 +13,13 @@ */ package com.facebook.presto.testing; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorTableLayoutHandle; public enum TestingHandle - implements ConnectorOutputTableHandle, ConnectorInsertTableHandle, ConnectorTableLayoutHandle + implements ConnectorOutputTableHandle, ConnectorInsertTableHandle, ConnectorTableLayoutHandle, ConnectorDistributedProcedureHandle { INSTANCE } diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandleResolver.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandleResolver.java index 8566f9d1ecac3..0421d2c316f58 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandleResolver.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandleResolver.java @@ -14,6 +14,7 @@ package com.facebook.presto.testing; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; @@ -63,6 +64,12 @@ public Class getInsertTableHandleClass() return TestingHandle.class; } + @Override + public Class getDistributedProcedureHandleClass() + { + return TestingHandle.class; + } + @Override public Class getTransactionHandleClass() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/AnalyzerUtil.java b/presto-main-base/src/main/java/com/facebook/presto/util/AnalyzerUtil.java index ba91087dcc245..752e5616dbc33 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/AnalyzerUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/AnalyzerUtil.java @@ -80,6 +80,8 @@ public static AnalyzerOptions createAnalyzerOptions(Session session, WarningColl .setLogFormattedQueryEnabled(isLogFormattedQueryEnabled(session)) .setWarningHandlingLevel(getWarningHandlingLevel(session)) .setWarningCollector(warningCollector) + .setSessionCatalogName(session.getCatalog()) + .setSessionSchemaName(session.getSchema()) .build(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index 6c210e9e0848c..3a94ed3c2d864 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -43,6 +43,7 @@ import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; import com.facebook.presto.spi.plan.UnnestNode; @@ -55,6 +56,7 @@ import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -296,6 +298,13 @@ public Void visitSequence(SequenceNode node, Void context) return null; } + @Override + public Void visitCallDistributedProcedure(CallDistributedProcedureNode node, Void context) + { + printNode(node, format("CallDistributedProcedure[%s]", node.getTarget().map(CallDistributedProcedureTarget::getProcedureName).orElse(null)), NODE_COLORS.get(NodeType.TABLE_WRITER)); + return node.getSource().accept(this, context); + } + @Override public Void visitTableWriter(TableWriterNode node, Void context) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestCreateMaterializedViewTask.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestCreateMaterializedViewTask.java index 95dd8cdfbd3df..2044053bad053 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestCreateMaterializedViewTask.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestCreateMaterializedViewTask.java @@ -40,12 +40,14 @@ import com.facebook.presto.spi.analyzer.MetadataResolver; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.AllowAllAccessControl; import com.facebook.presto.sql.parser.ParsingOptions; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.CreateMaterializedView; +import com.facebook.presto.testing.TestProcedureRegistry; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableList; import org.testng.annotations.BeforeMethod; @@ -119,6 +121,7 @@ public void setUp() metadata = new MockMetadata( functionAndTypeManager, + new TestProcedureRegistry(), tablePropertyManager, columnPropertyManager, testCatalog.getConnectorId()); @@ -190,6 +193,7 @@ private static class MockMetadata extends AbstractMockMetadata { private final FunctionAndTypeManager functionAndTypeManager; + private final ProcedureRegistry procedureRegistry; private final TablePropertyManager tablePropertyManager; private final ColumnPropertyManager columnPropertyManager; private final ConnectorId catalogHandle; @@ -198,11 +202,13 @@ private static class MockMetadata public MockMetadata( FunctionAndTypeManager functionAndTypeManager, + ProcedureRegistry procedureRegistry, TablePropertyManager tablePropertyManager, ColumnPropertyManager columnPropertyManager, ConnectorId catalogHandle) { this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.tablePropertyManager = requireNonNull(tablePropertyManager, "tablePropertyManager is null"); this.columnPropertyManager = requireNonNull(columnPropertyManager, "columnPropertyManager is null"); this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); @@ -240,6 +246,12 @@ public FunctionAndTypeManager getFunctionAndTypeManager() return functionAndTypeManager; } + @Override + public ProcedureRegistry getProcedureRegistry() + { + return procedureRegistry; + } + @Override public Type getType(TypeSignature signature) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestExecuteProcedureHandle.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestExecuteProcedureHandle.java new file mode 100644 index 0000000000000..3eba2734147e3 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestExecuteProcedureHandle.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.ExecuteProcedureHandle; +import com.facebook.presto.metadata.DistributedProcedureHandle; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.server.SliceDeserializer; +import com.facebook.presto.server.SliceSerializer; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.Serialization; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.testing.TestingHandle; +import com.facebook.presto.testing.TestingHandleResolver; +import com.facebook.presto.testing.TestingTransactionHandle; +import com.facebook.presto.type.TypeDeserializer; +import com.google.common.collect.ImmutableList; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Module; +import io.airlift.slice.Slice; +import org.testng.annotations.Test; + +import java.util.UUID; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static org.testng.Assert.assertEquals; + +public class TestExecuteProcedureHandle +{ + @Test + public void testExecuteProcedureHandleRoundTrip() + { + String catalogName = "test_catalog"; + JsonCodec codec = createJsonCodec(catalogName); + UUID uuid = UUID.randomUUID(); + ExecuteProcedureHandle expected = createExecuteProcedureHandle(catalogName, uuid); + ExecuteProcedureHandle actual = codec.fromJson(codec.toJson(expected)); + + assertEquals(actual.getProcedureName(), expected.getProcedureName()); + assertEquals(actual.getSchemaTableName(), expected.getSchemaTableName()); + assertEquals(actual.getHandle().getClass(), expected.getHandle().getClass()); + assertEquals(actual.getHandle().getConnectorId(), expected.getHandle().getConnectorId()); + assertEquals(actual.getHandle().getTransactionHandle(), expected.getHandle().getTransactionHandle()); + assertEquals(actual.getHandle().getConnectorHandle(), expected.getHandle().getConnectorHandle()); + } + + private static JsonCodec createJsonCodec(String catalogName) + { + Module module = binder -> { + SqlParser sqlParser = new SqlParser(); + FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + binder.install(new JsonModule()); + binder.install(new HandleJsonModule()); + binder.bind(SqlParser.class).toInstance(sqlParser); + binder.bind(TypeManager.class).toInstance(functionAndTypeManager); + configBinder(binder).bindConfig(FeaturesConfig.class); + newSetBinder(binder, Type.class); + jsonBinder(binder).addSerializerBinding(Slice.class).to(SliceSerializer.class); + jsonBinder(binder).addDeserializerBinding(Slice.class).to(SliceDeserializer.class); + jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + jsonBinder(binder).addSerializerBinding(Expression.class).to(Serialization.ExpressionSerializer.class); + jsonBinder(binder).addDeserializerBinding(Expression.class).to(Serialization.ExpressionDeserializer.class); + jsonBinder(binder).addDeserializerBinding(FunctionCall.class).to(Serialization.FunctionCallDeserializer.class); + jsonBinder(binder).addKeySerializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionSerializer.class); + jsonBinder(binder).addKeyDeserializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionDeserializer.class); + jsonCodecBinder(binder).bindJsonCodec(ExecuteProcedureHandle.class); + }; + Bootstrap app = new Bootstrap(ImmutableList.of(module)); + Injector injector = app + .doNotInitializeLogging() + .quiet() + .initialize(); + injector.getInstance(HandleResolver.class) + .addConnectorName(catalogName, new TestingHandleResolver()); + return injector.getInstance(new Key>() {}); + } + + private static ExecuteProcedureHandle createExecuteProcedureHandle(String catalogName, UUID uuid) + { + DistributedProcedureHandle distributedProcedureHandle = new DistributedProcedureHandle( + new ConnectorId(catalogName), + new TestingTransactionHandle(uuid), + TestingHandle.INSTANCE); + return new ExecuteProcedureHandle(distributedProcedureHandle, + new SchemaTableName("schema1", "table1"), + QualifiedObjectName.valueOf(catalogName, "schema1", "table1")); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java index c5ba12a7e8938..9cf5cb9a0217d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java @@ -39,6 +39,7 @@ import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; @@ -435,6 +436,18 @@ public Optional finishDeleteWithOutput(Session session, throw new UnsupportedOperationException(); } + @Override + public DistributedProcedureHandle beginCallDistributedProcedure(Session session, QualifiedObjectName procedureName, TableHandle tableHandle, Object[] arguments) + { + throw new UnsupportedOperationException(); + } + + @Override + public void finishCallDistributedProcedure(Session session, DistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments) + { + throw new UnsupportedOperationException(); + } + @Override public TableHandle beginUpdate(Session session, TableHandle tableHandle, List updatedColumns) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java index 03909e16bc1c3..a2c5a387ec59e 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java @@ -57,6 +57,10 @@ import com.facebook.presto.spi.function.Parameter; import com.facebook.presto.spi.function.RoutineCharacteristics; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.procedure.LocalProcedure; +import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.AllowAllAccessControl; import com.facebook.presto.spi.session.PropertyMetadata; @@ -64,6 +68,7 @@ import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.NodeLocation; import com.facebook.presto.sql.tree.Statement; +import com.facebook.presto.testing.TestProcedureRegistry; import com.facebook.presto.testing.TestingAccessControlManager; import com.facebook.presto.testing.TestingMetadata; import com.facebook.presto.testing.TestingWarningCollector; @@ -74,6 +79,7 @@ import org.intellij.lang.annotations.Language; import org.testng.annotations.BeforeClass; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.function.Consumer; @@ -92,6 +98,8 @@ import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.SQL; import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.SCHEMA; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.TABLE_NAME; import static com.facebook.presto.spi.session.PropertyMetadata.integerProperty; import static com.facebook.presto.spi.session.PropertyMetadata.stringProperty; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -145,8 +153,7 @@ public void setup() CatalogManager catalogManager = new CatalogManager(); transactionManager = createTestTransactionManager(catalogManager); accessControl = new TestingAccessControlManager(transactionManager); - - metadata = createTestMetadataManager(transactionManager); + metadata = createTestMetadataManager(transactionManager, new FeaturesConfig(), new FunctionsConfig(), new TestProcedureRegistry()); metadata.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); @@ -175,6 +182,19 @@ public void setup() new PassThroughFunction(), new RequiredColumnsFunction())); + List arguments = new ArrayList<>(); + arguments.add(new Argument(SCHEMA, StandardTypes.VARCHAR)); + arguments.add(new Argument(TABLE_NAME, StandardTypes.VARCHAR)); + + List procedures = new ArrayList<>(); + procedures.add(new LocalProcedure("system", "procedure", arguments)); + procedures.add(new TableDataRewriteDistributedProcedure("system", "distributed_procedure", + arguments, + (session, transactionContext, procedureHandle, fragments) -> null, + (transactionContext, procedureHandle, fragments) -> {}, + TestProcedureRegistry.TestProcedureContext::new)); + metadata.getProcedureRegistry().addProcedures(SECOND_CONNECTOR_ID, procedures); + Catalog tpchTestCatalog = createTestingCatalog(TPCH_CATALOG, TPCH_CONNECTOR_ID); catalogManager.registerCatalog(tpchTestCatalog); metadata.getAnalyzePropertyManager().addProperties(TPCH_CONNECTOR_ID, tpchTestCatalog.getConnector(TPCH_CONNECTOR_ID).getAnalyzeProperties()); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index 2e3f4d2534806..d3a28af7b6339 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -75,6 +75,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.ORDER_BY_MUST_BE_IN_AGGREGATE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.ORDER_BY_MUST_BE_IN_SELECT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.PROCEDURE_NOT_FOUND; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_GROUPING; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SAMPLE_PERCENTAGE_OUT_OF_RANGE; @@ -445,6 +446,20 @@ public void testWindowsNotAllowed() assertFails(NESTED_WINDOW, "SELECT 1 FROM (VALUES 1) HAVING count(*) OVER () > 1"); } + @Test + public void testCallProcedure() + { + Session session = testSessionBuilder() + .setCatalog("c2") + .setSchema("t4") + .build(); + assertFails(session, PROCEDURE_NOT_FOUND, "call system.not_exist_procedure('a', 'b')"); + assertFails(session, PROCEDURE_NOT_FOUND, "call system.procedure('a', 'b')"); + assertFails(session, MISSING_SCHEMA, "call system.distributed_procedure('s1', 't4')"); + assertFails(session, MISSING_TABLE, "call system.distributed_procedure('s2', 't9')"); + analyze(session, "call system.distributed_procedure('s2', 't4')"); + } + @Test public void testGrouping() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index d9a83a1dc955b..d91b16d116649 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -15,9 +15,18 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.execution.TestingPageSourceProvider; import com.facebook.presto.functionNamespace.FunctionNamespaceManagerPlugin; import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory; +import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.FilterNode; @@ -28,10 +37,15 @@ import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.SemiJoinNode; import com.facebook.presto.spi.plan.SortNode; +import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.transaction.IsolationLevel; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.ExpressionMatcher; @@ -40,12 +54,17 @@ import com.facebook.presto.sql.planner.optimizations.AddLocalExchanges; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.testing.TestProcedureRegistry; +import com.facebook.presto.testing.TestingHandleResolver; +import com.facebook.presto.testing.TestingMetadata; +import com.facebook.presto.testing.TestingSplitManager; import com.facebook.presto.tests.QueryTemplate; import com.facebook.presto.util.MorePredicates; import com.google.common.collect.ImmutableList; @@ -53,8 +72,12 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Consumer; import java.util.function.Predicate; @@ -78,6 +101,7 @@ import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; import static com.facebook.presto.common.predicate.Domain.singleValue; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.StandardTypes.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.spi.StandardErrorCode.INVALID_LIMIT_CLAUSE; import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; @@ -88,6 +112,8 @@ import static com.facebook.presto.spi.plan.JoinType.INNER; import static com.facebook.presto.spi.plan.JoinType.LEFT; import static com.facebook.presto.spi.plan.JoinType.RIGHT; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.SCHEMA; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.TABLE_NAME; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED; import static com.facebook.presto.sql.TestExpressionInterpreter.AVG_UDAF_CPP; @@ -155,8 +181,105 @@ public class TestLogicalPlanner public void setup() { setupJsonFunctionNamespaceManager(this.getQueryRunner()); + + // Register catalog `test` with a distributed procedure `distributed_fun` + this.getQueryRunner().createCatalog("test", + new ConnectorFactory() + { + @Override + public String getName() + { + return "test"; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new TestingHandleResolver(); + } + + @Override + public Connector create(String catalogName, Map config, ConnectorContext context) + { + List arguments = new ArrayList<>(); + arguments.add(new Argument(SCHEMA, VARCHAR)); + arguments.add(new Argument(TABLE_NAME, VARCHAR)); + Set procedures = new HashSet<>(); + procedures.add(new TableDataRewriteDistributedProcedure("system", "distributed_fun", + arguments, + (session, transactionContext, procedureHandle, fragments) -> null, + (transactionContext, procedureHandle, fragments) -> {}, + TestProcedureRegistry.TestProcedureContext::new)); + + return new Connector() + { + private final ConnectorMetadata metadata = new TestingMetadata(); + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return new ConnectorTransactionHandle() + {}; + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return new TestingPageSourceProvider(); + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transaction) + { + return metadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return new TestingSplitManager(ImmutableList.of()); + } + + @Override + public Set getProcedures() + { + return procedures; + } + }; + } + }, ImmutableMap.of()); } + @Test + public void testCallDistributedProcedure() + { + Session session = getQueryRunner().getDefaultSession(); + + // Call non-existed distributed procedure + assertPlanFailedWithException("call test.system.no_fun('a', 'b')", session, + format("Distributed procedure not registered: test.system.no_fun", "test", "system", "no_fun")); + + // Call distributed procedure on non-existed target table + assertPlanFailedWithException("call test.system.distributed_fun('tiny', 'notable')", session, + format("Table %s.%s.%s does not exist", session.getCatalog().get(), "tiny", "notable")); + + // Call distributed procedure on partitioned target table + assertDistributedPlan("call test.system.distributed_fun('tiny', 'orders')", + anyTree(node(TableFinishNode.class, + exchange(REMOTE_STREAMING, GATHER, + node(CallDistributedProcedureNode.class, + exchange(LOCAL, GATHER, + tableScan("orders"))))))); + + // Call distributed procedure on unPartitioned target table + assertDistributedPlan("call test.system.distributed_fun('tiny', 'customer')", + anyTree(node(TableFinishNode.class, + exchange(REMOTE_STREAMING, GATHER, + node(CallDistributedProcedureNode.class, + exchange(LOCAL, GATHER, + exchange(REMOTE_STREAMING, REPARTITION, + tableScan("customer")))))))); + } @Test public void testAnalyze() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index e5838185f495f..d882fe7e54a5f 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -43,6 +43,7 @@ import com.facebook.presto.spi.plan.SemiJoinNode; import com.facebook.presto.spi.plan.SortNode; import com.facebook.presto.spi.plan.SpatialJoinNode; +import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; @@ -60,6 +61,7 @@ import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; @@ -691,6 +693,16 @@ public static PlanMatchPattern enforceSingleRow(PlanMatchPattern source) return node(EnforceSingleRowNode.class, source); } + public static PlanMatchPattern callDistributedProcedure(PlanMatchPattern source) + { + return node(CallDistributedProcedureNode.class, source); + } + + public static PlanMatchPattern tableFinish(PlanMatchPattern source) + { + return node(TableFinishNode.class, source); + } + public static PlanMatchPattern tableWriter(List columns, List columnNames, PlanMatchPattern source) { return node(TableWriterNode.class, source).with(new TableWriterMatcher(columns, columnNames)); diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index f5d2543048a46..b69a83ecb828a 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -99,6 +99,7 @@ import com.facebook.presto.memory.NodeMemoryConfig; import com.facebook.presto.memory.ReservedSystemMemoryConfig; import com.facebook.presto.metadata.AnalyzePropertyManager; +import com.facebook.presto.metadata.BuiltInProcedureRegistry; import com.facebook.presto.metadata.CatalogManager; import com.facebook.presto.metadata.ColumnPropertyManager; import com.facebook.presto.metadata.DiscoveryNodeManager; @@ -171,6 +172,7 @@ import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.plan.SimplePlanFragment; import com.facebook.presto.spi.plan.SimplePlanFragmentSerde; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; @@ -654,6 +656,8 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon binder.bind(FunctionAndTypeManager.class).in(Scopes.SINGLETON); binder.bind(TableFunctionRegistry.class).in(Scopes.SINGLETON); binder.bind(MetadataManager.class).in(Scopes.SINGLETON); + binder.bind(BuiltInProcedureRegistry.class).in(Scopes.SINGLETON); + binder.bind(ProcedureRegistry.class).to(BuiltInProcedureRegistry.class).in(Scopes.SINGLETON); if (serverConfig.isCatalogServerEnabled() && serverConfig.isCoordinator()) { binder.bind(RemoteMetadataManager.class).in(Scopes.SINGLETON); diff --git a/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.cpp index 303ee763d4a60..6b363536fb544 100644 --- a/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.cpp @@ -270,6 +270,34 @@ IcebergPrestoToVeloxConnector::toVeloxTableHandle( typeParser); } +std::unique_ptr +IcebergPrestoToVeloxConnector::toVeloxInsertTableHandle( + const protocol::ExecuteProcedureHandle* executeProcedureHandle, + const TypeParser& typeParser) const { + auto icebergDistributedProcedureHandle = std::dynamic_pointer_cast< + protocol::iceberg::IcebergDistributedProcedureHandle>( + executeProcedureHandle->handle.connectorHandle); + + VELOX_CHECK_NOT_NULL( + icebergDistributedProcedureHandle, + "Unexpected call distributed procedure handle type {}", + executeProcedureHandle->handle.connectorHandle->_type); + + const auto inputColumns = toHiveColumns( + icebergDistributedProcedureHandle->inputColumns, typeParser); + + return std::make_unique< + velox::connector::hive::iceberg::IcebergInsertTableHandle>( + inputColumns, + std::make_shared( + fmt::format("{}/data", icebergDistributedProcedureHandle->outputPath), + fmt::format("{}/data", icebergDistributedProcedureHandle->outputPath), + velox::connector::hive::LocationHandle::TableType::kExisting), + toVeloxFileFormat(icebergDistributedProcedureHandle->fileFormat), + std::optional(toFileCompressionKind( + icebergDistributedProcedureHandle->compressionCodec))); +} + std::unique_ptr IcebergPrestoToVeloxConnector::createConnectorProtocol() const { return std::make_unique(); diff --git a/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.h index c6b6b8850fa53..07567b6b78d67 100644 --- a/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.h @@ -39,6 +39,11 @@ class IcebergPrestoToVeloxConnector final : public PrestoToVeloxConnector { const TypeParser& typeParser, velox::connector::ColumnHandleMap& assignments) const final; + std::unique_ptr + toVeloxInsertTableHandle( + const protocol::ExecuteProcedureHandle* executeProcedureHandle, + const TypeParser& typeParser) const final; + std::unique_ptr createConnectorProtocol() const final; diff --git a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h index a3fce5054dfce..bd48bcf3c5e41 100644 --- a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h @@ -85,6 +85,14 @@ class PrestoToVeloxConnector { return {}; } + [[nodiscard]] virtual std::unique_ptr< + velox::connector::ConnectorInsertTableHandle> + toVeloxInsertTableHandle( + const protocol::ExecuteProcedureHandle* executeProcedureHandle, + const TypeParser& typeParser) const { + return {}; + } + [[nodiscard]] std::unique_ptr createVeloxPartitionFunctionSpec( const protocol::ConnectorPartitioningHandle* partitioningHandle, diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index 0285c16ea48c7..bac6fa9cc16fa 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -1539,6 +1539,59 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( sourceVeloxPlan); } +std::shared_ptr +VeloxQueryPlanConverterBase::toVeloxQueryPlan( + const std::shared_ptr& node, + const std::shared_ptr& tableWriteInfo, + const protocol::TaskId& taskId) { + const auto executeProcedureHandle = + std::dynamic_pointer_cast( + tableWriteInfo->writerTarget); + + if (!executeProcedureHandle) { + VELOX_UNSUPPORTED( + "Unsupported execute procedure handle: {}", + toJsonString(tableWriteInfo->writerTarget)); + } + + std::string connectorId = executeProcedureHandle->handle.connectorId; + auto& connector = getPrestoToVeloxConnector( + executeProcedureHandle->handle.connectorHandle->_type); + auto veloxHandle = connector.toVeloxInsertTableHandle( + executeProcedureHandle.get(), typeParser_); + auto connectorInsertHandle = std::shared_ptr(std::move(veloxHandle)); + + if (!connectorInsertHandle) { + VELOX_UNSUPPORTED( + "Unsupported execute procedure handle: {}", + toJsonString(tableWriteInfo->writerTarget)); + } + + auto insertTableHandle = std::make_shared( + connectorId, connectorInsertHandle); + + const auto outputType = toRowType( + generateOutputVariables( + {node->rowCountVariable, + node->fragmentVariable, + node->tableCommitContextVariable}, + nullptr), + typeParser_); + const auto sourceVeloxPlan = + toVeloxQueryPlan(node->source, tableWriteInfo, taskId); + + return std::make_shared( + node->id, + toRowType(node->columns, typeParser_), + node->columnNames, + std::nullopt, + std::move(insertTableHandle), + node->partitioningScheme != nullptr, + outputType, + getCommitStrategy(), + sourceVeloxPlan); +} + std::shared_ptr VeloxQueryPlanConverterBase::toVeloxQueryPlan( const std::shared_ptr& node, @@ -1927,6 +1980,10 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( std::dynamic_pointer_cast(node)) { return toVeloxQueryPlan(tableWriter, tableWriteInfo, taskId); } + if (auto callDistributedProcedure = std::dynamic_pointer_cast< + const protocol::CallDistributedProcedureNode>(node)) { + return toVeloxQueryPlan(callDistributedProcedure, tableWriteInfo, taskId); + } if (auto deleteNode = std::dynamic_pointer_cast(node)) { return toVeloxQueryPlan(deleteNode, tableWriteInfo, taskId); diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h index c21e6f0ea35d8..00a13f811e5fe 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h @@ -160,6 +160,11 @@ class VeloxQueryPlanConverterBase { const std::shared_ptr& tableWriteInfo, const protocol::TaskId& taskId); + std::shared_ptr toVeloxQueryPlan( + const std::shared_ptr& node, + const std::shared_ptr& tableWriteInfo, + const protocol::TaskId& taskId); + std::shared_ptr toVeloxQueryPlan( const std::shared_ptr& node, const std::shared_ptr& tableWriteInfo, diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/HiveConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/HiveConnectorProtocol.h index be4f411c061e1..ee3280cf8a43d 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/HiveConnectorProtocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/HiveConnectorProtocol.h @@ -26,5 +26,6 @@ using HiveConnectorProtocol = ConnectorProtocolTemplate< HivePartitioningHandle, HiveTransactionHandle, NotImplemented, + NotImplemented, NotImplemented>; } // namespace facebook::presto::protocol::hive diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/IcebergConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/IcebergConnectorProtocol.h index 00ce0c64ca654..e8cfab306dfd4 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/IcebergConnectorProtocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/IcebergConnectorProtocol.h @@ -27,6 +27,7 @@ using IcebergConnectorProtocol = ConnectorProtocolTemplate< IcebergSplit, NotImplemented, hive::HiveTransactionHandle, + IcebergDistributedProcedureHandle, NotImplemented, NotImplemented>; diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp index 6d03a5ce52b12..f9a3d754430d6 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp @@ -741,6 +741,147 @@ void from_json(const json& j, PrestoIcebergPartitionSpec& p) { } } // namespace facebook::presto::protocol::iceberg namespace facebook::presto::protocol::iceberg { +IcebergDistributedProcedureHandle:: + IcebergDistributedProcedureHandle() noexcept { + _type = "hive-iceberg"; +} + +void to_json(json& j, const IcebergDistributedProcedureHandle& p) { + j = json::object(); + j["@type"] = "hive-iceberg"; + to_json_key( + j, + "schemaName", + p.schemaName, + "IcebergDistributedProcedureHandle", + "String", + "schemaName"); + to_json_key( + j, + "tableName", + p.tableName, + "IcebergDistributedProcedureHandle", + "IcebergTableName", + "tableName"); + to_json_key( + j, + "schema", + p.schema, + "IcebergDistributedProcedureHandle", + "PrestoIcebergSchema", + "schema"); + to_json_key( + j, + "partitionSpec", + p.partitionSpec, + "IcebergDistributedProcedureHandle", + "PrestoIcebergPartitionSpec", + "partitionSpec"); + to_json_key( + j, + "inputColumns", + p.inputColumns, + "IcebergDistributedProcedureHandle", + "List", + "inputColumns"); + to_json_key( + j, + "outputPath", + p.outputPath, + "IcebergDistributedProcedureHandle", + "String", + "outputPath"); + to_json_key( + j, + "fileFormat", + p.fileFormat, + "IcebergDistributedProcedureHandle", + "FileFormat", + "fileFormat"); + to_json_key( + j, + "compressionCodec", + p.compressionCodec, + "IcebergDistributedProcedureHandle", + "HiveCompressionCodec", + "compressionCodec"); + to_json_key( + j, + "storageProperties", + p.storageProperties, + "IcebergDistributedProcedureHandle", + "Map", + "storageProperties"); +} + +void from_json(const json& j, IcebergDistributedProcedureHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "schemaName", + p.schemaName, + "IcebergDistributedProcedureHandle", + "String", + "schemaName"); + from_json_key( + j, + "tableName", + p.tableName, + "IcebergDistributedProcedureHandle", + "IcebergTableName", + "tableName"); + from_json_key( + j, + "schema", + p.schema, + "IcebergDistributedProcedureHandle", + "PrestoIcebergSchema", + "schema"); + from_json_key( + j, + "partitionSpec", + p.partitionSpec, + "IcebergDistributedProcedureHandle", + "PrestoIcebergPartitionSpec", + "partitionSpec"); + from_json_key( + j, + "inputColumns", + p.inputColumns, + "IcebergDistributedProcedureHandle", + "List", + "inputColumns"); + from_json_key( + j, + "outputPath", + p.outputPath, + "IcebergDistributedProcedureHandle", + "String", + "outputPath"); + from_json_key( + j, + "fileFormat", + p.fileFormat, + "IcebergDistributedProcedureHandle", + "FileFormat", + "fileFormat"); + from_json_key( + j, + "compressionCodec", + p.compressionCodec, + "IcebergDistributedProcedureHandle", + "HiveCompressionCodec", + "compressionCodec"); + from_json_key( + j, + "storageProperties", + p.storageProperties, + "IcebergDistributedProcedureHandle", + "Map", + "storageProperties"); +} +} // namespace facebook::presto::protocol::iceberg +namespace facebook::presto::protocol::iceberg { void to_json(json& j, const SortField& p) { j = json::object(); diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h index d82e660ab42a8..a04d9acf0d2c9 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h @@ -183,6 +183,27 @@ struct PrestoIcebergPartitionSpec { void to_json(json& j, const PrestoIcebergPartitionSpec& p); void from_json(const json& j, PrestoIcebergPartitionSpec& p); } // namespace facebook::presto::protocol::iceberg +// IcebergDistributedProcedureHandle is special since it needs an usage of +// hive::. + +namespace facebook::presto::protocol::iceberg { +struct IcebergDistributedProcedureHandle + : public ConnectorDistributedProcedureHandle { + String schemaName = {}; + IcebergTableName tableName = {}; + PrestoIcebergSchema schema = {}; + PrestoIcebergPartitionSpec partitionSpec = {}; + List inputColumns = {}; + String outputPath = {}; + FileFormat fileFormat = {}; + hive::HiveCompressionCodec compressionCodec = {}; + Map storageProperties = {}; + + IcebergDistributedProcedureHandle() noexcept; +}; +void to_json(json& j, const IcebergDistributedProcedureHandle& p); +void from_json(const json& j, IcebergDistributedProcedureHandle& p); +} // namespace facebook::presto::protocol::iceberg namespace facebook::presto::protocol::iceberg { struct SortField { int sourceColumnId = {}; diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.yml b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.yml index 1a8be3d90b3b6..9ceec008fc7c7 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.yml @@ -37,6 +37,11 @@ AbstractClasses: subclasses: - { name: IcebergInsertTableHandle, key: hive-iceberg } + ConnectorDistributedProcedureHandle: + super: JsonEncodedSubclass + subclasses: + - { name: IcebergDistributedProcedureHandle, key: hive-iceberg } + ConnectorTableLayoutHandle: super: JsonEncodedSubclass subclasses: @@ -62,6 +67,7 @@ JavaClasses: - presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableLayoutHandle.java - presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergOutputTableHandle.java - presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergInsertTableHandle.java + - presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergDistributedProcedureHandle.java - presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergColumnHandle.java - presto-iceberg/src/main/java/com/facebook/presto/iceberg/ColumnIdentity.java - presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPartitionField.java diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/special/IcebergDistributedProcedureHandle.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/special/IcebergDistributedProcedureHandle.hpp.inc new file mode 100644 index 0000000000000..c5ab91c0416f8 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/special/IcebergDistributedProcedureHandle.hpp.inc @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// IcebergDistributedProcedureHandle is special since it needs an usage of +// hive::. + +namespace facebook::presto::protocol::iceberg { +struct IcebergDistributedProcedureHandle + : public ConnectorDistributedProcedureHandle { + String schemaName = {}; + IcebergTableName tableName = {}; + PrestoIcebergSchema schema = {}; + PrestoIcebergPartitionSpec partitionSpec = {}; + List inputColumns = {}; + String outputPath = {}; + FileFormat fileFormat = {}; + hive::HiveCompressionCodec compressionCodec = {}; + Map storageProperties = {}; + + IcebergDistributedProcedureHandle() noexcept; +}; +void to_json(json& j, const IcebergDistributedProcedureHandle& p); +void from_json(const json& j, IcebergDistributedProcedureHandle& p); +} // namespace facebook::presto::protocol::iceberg diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h index bca3818f33cd5..372e124f8e8ae 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h @@ -29,6 +29,7 @@ using TpchConnectorProtocol = ConnectorProtocolTemplate< TpchPartitioningHandle, TpchTransactionHandle, NotImplemented, + NotImplemented, NotImplemented>; } // namespace facebook::presto::protocol::tpch diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h index d51a20dc496a0..130944584c90f 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h @@ -75,6 +75,13 @@ class ConnectorProtocol { const std::string& thrift, std::shared_ptr& proto) const = 0; + virtual void to_json( + json& j, + const std::shared_ptr& p) const = 0; + virtual void from_json( + const json& j, + std::shared_ptr& p) const = 0; + virtual void to_json( json& j, const std::shared_ptr& p) const = 0; @@ -152,6 +159,7 @@ template < typename ConnectorSplitType = NotImplemented, typename ConnectorPartitioningHandleType = NotImplemented, typename ConnectorTransactionHandleType = NotImplemented, + typename ConnectorDistributedProcedureHandleType = NotImplemented, typename ConnectorDeleteTableHandleType = NotImplemented, typename ConnectorIndexHandleType = NotImplemented> class ConnectorProtocolTemplate final : public ConnectorProtocol { @@ -220,6 +228,18 @@ class ConnectorProtocolTemplate final : public ConnectorProtocol { deserializeTemplate(thrift, proto); } + void to_json( + json& j, + const std::shared_ptr& p) + const final { + to_json_template(j, p); + } + void from_json( + const json& j, + std::shared_ptr& p) const final { + from_json_template(j, p); + } + void to_json(json& j, const std::shared_ptr& p) const final { to_json_template(j, p); @@ -403,6 +423,7 @@ using SystemConnectorProtocol = ConnectorProtocolTemplate< SystemPartitioningHandle, SystemTransactionHandle, NotImplemented, + NotImplemented, NotImplemented>; } // namespace facebook::presto::protocol diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 874e2275577c1..264aab91700ac 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -774,6 +774,11 @@ void to_json(json& j, const std::shared_ptr& p) { j = *std::static_pointer_cast(p); return; } + if (type == + "com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode") { + j = *std::static_pointer_cast(p); + return; + } throw TypeError(type + " no abstract type PlanNode "); } @@ -970,6 +975,14 @@ void from_json(const json& j, std::shared_ptr& p) { p = std::static_pointer_cast(k); return; } + if (type == + "com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } throw TypeError(type + " no abstract type PlanNode "); } @@ -2290,6 +2303,10 @@ void to_json(json& j, const std::shared_ptr& p) { j = *std::static_pointer_cast(p); return; } + if (type == "ExecuteProcedureHandle") { + j = *std::static_pointer_cast(p); + return; + } throw TypeError(type + " no abstract type ExecutionWriterTarget "); } @@ -2334,6 +2351,13 @@ void from_json(const json& j, std::shared_ptr& p) { p = std::static_pointer_cast(k); return; } + if (type == "ExecuteProcedureHandle") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } throw TypeError(type + " no abstract type ExecutionWriterTarget "); } @@ -2955,209 +2979,595 @@ void from_json(const json& j, CacheQuotaRequirement& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +// Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() -void to_json(json& j, const Column& p) { - j = json::object(); - to_json_key(j, "name", p.name, "Column", "String", "name"); - to_json_key(j, "type", p.type, "Column", "String", "type"); -} - -void from_json(const json& j, Column& p) { - from_json_key(j, "name", p.name, "Column", "String", "name"); - from_json_key(j, "type", p.type, "Column", "String", "type"); -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { - -void to_json(json& j, const Block& p) { - j = p.data; -} - -void from_json(const json& j, Block& p) { - p.data = std::string(j); -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { -ConstantExpression::ConstantExpression() noexcept { - _type = "constant"; -} - -void to_json(json& j, const ConstantExpression& p) { - j = json::object(); - j["@type"] = "constant"; - to_json_key( - j, - "valueBlock", - p.valueBlock, - "ConstantExpression", - "Block", - "valueBlock"); - to_json_key(j, "type", p.type, "ConstantExpression", "Type", "type"); +// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays +static const std::pair ExchangeEncoding_enum_table[] = + { // NOLINT: cert-err58-cpp + {ExchangeEncoding::COLUMNAR, "COLUMNAR"}, + {ExchangeEncoding::ROW_WISE, "ROW_WISE"}}; +void to_json(json& j, const ExchangeEncoding& e) { + static_assert( + std::is_enum::value, + "ExchangeEncoding must be an enum!"); + const auto* it = std::find_if( + std::begin(ExchangeEncoding_enum_table), + std::end(ExchangeEncoding_enum_table), + [e](const std::pair& ej_pair) -> bool { + return ej_pair.first == e; + }); + j = ((it != std::end(ExchangeEncoding_enum_table)) + ? it + : std::begin(ExchangeEncoding_enum_table)) + ->second; } - -void from_json(const json& j, ConstantExpression& p) { - p._type = j["@type"]; - from_json_key( - j, - "valueBlock", - p.valueBlock, - "ConstantExpression", - "Block", - "valueBlock"); - from_json_key(j, "type", p.type, "ConstantExpression", "Type", "type"); +void from_json(const json& j, ExchangeEncoding& e) { + static_assert( + std::is_enum::value, + "ExchangeEncoding must be an enum!"); + const auto* it = std::find_if( + std::begin(ExchangeEncoding_enum_table), + std::end(ExchangeEncoding_enum_table), + [&j](const std::pair& ej_pair) -> bool { + return ej_pair.second == j; + }); + e = ((it != std::end(ExchangeEncoding_enum_table)) + ? it + : std::begin(ExchangeEncoding_enum_table)) + ->first; } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -void to_json(json& j, const std::shared_ptr& p) { +void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { return; } String type = p->_type; + + if (type == "$remote") { + j = *std::static_pointer_cast(p); + return; + } getConnectorProtocol(type).to_json(j, p); } -void from_json(const json& j, std::shared_ptr& p) { +void from_json(const json& j, std::shared_ptr& p) { String type; try { type = p->getSubclassKey(j); } catch (json::parse_error& e) { - throw ParseError( - std::string(e.what()) + - " ConnectorOutputTableHandle ConnectorOutputTableHandle"); + throw ParseError(std::string(e.what()) + " ConnectorPartitioningHandle"); + } + + if (type == "$remote") { + auto k = std::make_shared(); + j.get_to(*k); + p = k; + return; } getConnectorProtocol(type).from_json(j, p); } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -void to_json(json& j, const OutputTableHandle& p) { +void to_json(json& j, const PartitioningHandle& p) { j = json::object(); to_json_key( j, "connectorId", p.connectorId, - "OutputTableHandle", + "PartitioningHandle", "ConnectorId", "connectorId"); to_json_key( j, "transactionHandle", p.transactionHandle, - "OutputTableHandle", + "PartitioningHandle", "ConnectorTransactionHandle", "transactionHandle"); to_json_key( j, "connectorHandle", p.connectorHandle, - "OutputTableHandle", - "ConnectorOutputTableHandle", + "PartitioningHandle", + "ConnectorPartitioningHandle", "connectorHandle"); } -void from_json(const json& j, OutputTableHandle& p) { +void from_json(const json& j, PartitioningHandle& p) { from_json_key( j, "connectorId", p.connectorId, - "OutputTableHandle", + "PartitioningHandle", "ConnectorId", "connectorId"); from_json_key( j, "transactionHandle", p.transactionHandle, - "OutputTableHandle", + "PartitioningHandle", "ConnectorTransactionHandle", "transactionHandle"); from_json_key( j, "connectorHandle", p.connectorHandle, - "OutputTableHandle", - "ConnectorOutputTableHandle", + "PartitioningHandle", + "ConnectorPartitioningHandle", "connectorHandle"); } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -void to_json(json& j, const SchemaTableName& p) { - j = json::object(); - to_json_key(j, "schema", p.schema, "SchemaTableName", "String", "schema"); - to_json_key(j, "table", p.table, "SchemaTableName", "String", "table"); -} - -void from_json(const json& j, SchemaTableName& p) { - from_json_key(j, "schema", p.schema, "SchemaTableName", "String", "schema"); - from_json_key(j, "table", p.table, "SchemaTableName", "String", "table"); -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { -CreateHandle::CreateHandle() noexcept { - _type = "CreateHandle"; -} - -void to_json(json& j, const CreateHandle& p) { +void to_json(json& j, const Partitioning& p) { j = json::object(); - j["@type"] = "CreateHandle"; to_json_key( - j, "handle", p.handle, "CreateHandle", "OutputTableHandle", "handle"); + j, "handle", p.handle, "Partitioning", "PartitioningHandle", "handle"); to_json_key( j, - "schemaTableName", - p.schemaTableName, - "CreateHandle", - "SchemaTableName", - "schemaTableName"); + "arguments", + p.arguments, + "Partitioning", + "List>", + "arguments"); } -void from_json(const json& j, CreateHandle& p) { - p._type = j["@type"]; +void from_json(const json& j, Partitioning& p) { from_json_key( - j, "handle", p.handle, "CreateHandle", "OutputTableHandle", "handle"); + j, "handle", p.handle, "Partitioning", "PartitioningHandle", "handle"); from_json_key( j, - "schemaTableName", - p.schemaTableName, - "CreateHandle", - "SchemaTableName", - "schemaTableName"); + "arguments", + p.arguments, + "Partitioning", + "List>", + "arguments"); } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -void to_json(json& j, const DataOrganizationSpecification& p) { +void to_json(json& j, const PartitioningScheme& p) { j = json::object(); to_json_key( j, - "partitionBy", - p.partitionBy, - "DataOrganizationSpecification", + "partitioning", + p.partitioning, + "PartitioningScheme", + "Partitioning", + "partitioning"); + to_json_key( + j, + "outputLayout", + p.outputLayout, + "PartitioningScheme", "List", - "partitionBy"); + "outputLayout"); to_json_key( j, - "orderingScheme", - p.orderingScheme, - "DataOrganizationSpecification", - "OrderingScheme", - "orderingScheme"); + "hashColumn", + p.hashColumn, + "PartitioningScheme", + "VariableReferenceExpression", + "hashColumn"); + to_json_key( + j, + "replicateNullsAndAny", + p.replicateNullsAndAny, + "PartitioningScheme", + "bool", + "replicateNullsAndAny"); + to_json_key( + j, + "scaleWriters", + p.scaleWriters, + "PartitioningScheme", + "bool", + "scaleWriters"); + to_json_key( + j, + "encoding", + p.encoding, + "PartitioningScheme", + "ExchangeEncoding", + "encoding"); + to_json_key( + j, + "bucketToPartition", + p.bucketToPartition, + "PartitioningScheme", + "List", + "bucketToPartition"); } -void from_json(const json& j, DataOrganizationSpecification& p) { +void from_json(const json& j, PartitioningScheme& p) { from_json_key( j, - "partitionBy", - p.partitionBy, - "DataOrganizationSpecification", + "partitioning", + p.partitioning, + "PartitioningScheme", + "Partitioning", + "partitioning"); + from_json_key( + j, + "outputLayout", + p.outputLayout, + "PartitioningScheme", "List", - "partitionBy"); + "outputLayout"); from_json_key( j, - "orderingScheme", - p.orderingScheme, - "DataOrganizationSpecification", - "OrderingScheme", + "hashColumn", + p.hashColumn, + "PartitioningScheme", + "VariableReferenceExpression", + "hashColumn"); + from_json_key( + j, + "replicateNullsAndAny", + p.replicateNullsAndAny, + "PartitioningScheme", + "bool", + "replicateNullsAndAny"); + from_json_key( + j, + "scaleWriters", + p.scaleWriters, + "PartitioningScheme", + "bool", + "scaleWriters"); + from_json_key( + j, + "encoding", + p.encoding, + "PartitioningScheme", + "ExchangeEncoding", + "encoding"); + from_json_key( + j, + "bucketToPartition", + p.bucketToPartition, + "PartitioningScheme", + "List", + "bucketToPartition"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +CallDistributedProcedureNode::CallDistributedProcedureNode() noexcept { + _type = "com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode"; +} + +void to_json(json& j, const CallDistributedProcedureNode& p) { + j = json::object(); + j["@type"] = + "com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode"; + to_json_key( + j, "id", p.id, "CallDistributedProcedureNode", "PlanNodeId", "id"); + to_json_key( + j, + "source", + p.source, + "CallDistributedProcedureNode", + "PlanNode", + "source"); + to_json_key( + j, + "rowCountVariable", + p.rowCountVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "rowCountVariable"); + to_json_key( + j, + "fragmentVariable", + p.fragmentVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "fragmentVariable"); + to_json_key( + j, + "tableCommitContextVariable", + p.tableCommitContextVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "tableCommitContextVariable"); + to_json_key( + j, + "columns", + p.columns, + "CallDistributedProcedureNode", + "List", + "columns"); + to_json_key( + j, + "columnNames", + p.columnNames, + "CallDistributedProcedureNode", + "List", + "columnNames"); + to_json_key( + j, + "notNullColumnVariables", + p.notNullColumnVariables, + "CallDistributedProcedureNode", + "List", + "notNullColumnVariables"); + to_json_key( + j, + "partitioningScheme", + p.partitioningScheme, + "CallDistributedProcedureNode", + "PartitioningScheme", + "partitioningScheme"); +} + +void from_json(const json& j, CallDistributedProcedureNode& p) { + p._type = j["@type"]; + from_json_key( + j, "id", p.id, "CallDistributedProcedureNode", "PlanNodeId", "id"); + from_json_key( + j, + "source", + p.source, + "CallDistributedProcedureNode", + "PlanNode", + "source"); + from_json_key( + j, + "rowCountVariable", + p.rowCountVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "rowCountVariable"); + from_json_key( + j, + "fragmentVariable", + p.fragmentVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "fragmentVariable"); + from_json_key( + j, + "tableCommitContextVariable", + p.tableCommitContextVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "tableCommitContextVariable"); + from_json_key( + j, + "columns", + p.columns, + "CallDistributedProcedureNode", + "List", + "columns"); + from_json_key( + j, + "columnNames", + p.columnNames, + "CallDistributedProcedureNode", + "List", + "columnNames"); + from_json_key( + j, + "notNullColumnVariables", + p.notNullColumnVariables, + "CallDistributedProcedureNode", + "List", + "notNullColumnVariables"); + from_json_key( + j, + "partitioningScheme", + p.partitioningScheme, + "CallDistributedProcedureNode", + "PartitioningScheme", + "partitioningScheme"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const Column& p) { + j = json::object(); + to_json_key(j, "name", p.name, "Column", "String", "name"); + to_json_key(j, "type", p.type, "Column", "String", "type"); +} + +void from_json(const json& j, Column& p) { + from_json_key(j, "name", p.name, "Column", "String", "name"); + from_json_key(j, "type", p.type, "Column", "String", "type"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const Block& p) { + j = p.data; +} + +void from_json(const json& j, Block& p) { + p.data = std::string(j); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +ConstantExpression::ConstantExpression() noexcept { + _type = "constant"; +} + +void to_json(json& j, const ConstantExpression& p) { + j = json::object(); + j["@type"] = "constant"; + to_json_key( + j, + "valueBlock", + p.valueBlock, + "ConstantExpression", + "Block", + "valueBlock"); + to_json_key(j, "type", p.type, "ConstantExpression", "Type", "type"); +} + +void from_json(const json& j, ConstantExpression& p) { + p._type = j["@type"]; + from_json_key( + j, + "valueBlock", + p.valueBlock, + "ConstantExpression", + "Block", + "valueBlock"); + from_json_key(j, "type", p.type, "ConstantExpression", "Type", "type"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +void to_json(json& j, const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + getConnectorProtocol(type).to_json(j, p); +} + +void from_json(const json& j, std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError( + std::string(e.what()) + + " ConnectorOutputTableHandle ConnectorOutputTableHandle"); + } + getConnectorProtocol(type).from_json(j, p); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const OutputTableHandle& p) { + j = json::object(); + to_json_key( + j, + "connectorId", + p.connectorId, + "OutputTableHandle", + "ConnectorId", + "connectorId"); + to_json_key( + j, + "transactionHandle", + p.transactionHandle, + "OutputTableHandle", + "ConnectorTransactionHandle", + "transactionHandle"); + to_json_key( + j, + "connectorHandle", + p.connectorHandle, + "OutputTableHandle", + "ConnectorOutputTableHandle", + "connectorHandle"); +} + +void from_json(const json& j, OutputTableHandle& p) { + from_json_key( + j, + "connectorId", + p.connectorId, + "OutputTableHandle", + "ConnectorId", + "connectorId"); + from_json_key( + j, + "transactionHandle", + p.transactionHandle, + "OutputTableHandle", + "ConnectorTransactionHandle", + "transactionHandle"); + from_json_key( + j, + "connectorHandle", + p.connectorHandle, + "OutputTableHandle", + "ConnectorOutputTableHandle", + "connectorHandle"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const SchemaTableName& p) { + j = json::object(); + to_json_key(j, "schema", p.schema, "SchemaTableName", "String", "schema"); + to_json_key(j, "table", p.table, "SchemaTableName", "String", "table"); +} + +void from_json(const json& j, SchemaTableName& p) { + from_json_key(j, "schema", p.schema, "SchemaTableName", "String", "schema"); + from_json_key(j, "table", p.table, "SchemaTableName", "String", "table"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +CreateHandle::CreateHandle() noexcept { + _type = "CreateHandle"; +} + +void to_json(json& j, const CreateHandle& p) { + j = json::object(); + j["@type"] = "CreateHandle"; + to_json_key( + j, "handle", p.handle, "CreateHandle", "OutputTableHandle", "handle"); + to_json_key( + j, + "schemaTableName", + p.schemaTableName, + "CreateHandle", + "SchemaTableName", + "schemaTableName"); +} + +void from_json(const json& j, CreateHandle& p) { + p._type = j["@type"]; + from_json_key( + j, "handle", p.handle, "CreateHandle", "OutputTableHandle", "handle"); + from_json_key( + j, + "schemaTableName", + p.schemaTableName, + "CreateHandle", + "SchemaTableName", + "schemaTableName"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const DataOrganizationSpecification& p) { + j = json::object(); + to_json_key( + j, + "partitionBy", + p.partitionBy, + "DataOrganizationSpecification", + "List", + "partitionBy"); + to_json_key( + j, + "orderingScheme", + p.orderingScheme, + "DataOrganizationSpecification", + "OrderingScheme", + "orderingScheme"); +} + +void from_json(const json& j, DataOrganizationSpecification& p) { + from_json_key( + j, + "partitionBy", + p.partitionBy, + "DataOrganizationSpecification", + "List", + "partitionBy"); + from_json_key( + j, + "orderingScheme", + p.orderingScheme, + "DataOrganizationSpecification", + "OrderingScheme", "orderingScheme"); } } // namespace facebook::presto::protocol @@ -3419,11 +3829,87 @@ void from_json(const json& j, DistinctLimitNode& p) { "hashVariable"); from_json_key( j, - "timeoutMillis", - p.timeoutMillis, - "DistinctLimitNode", - "int", - "timeoutMillis"); + "timeoutMillis", + p.timeoutMillis, + "DistinctLimitNode", + "int", + "timeoutMillis"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +void to_json( + json& j, + const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + getConnectorProtocol(type).to_json(j, p); +} + +void from_json( + const json& j, + std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError( + std::string(e.what()) + + " ConnectorDistributedProcedureHandle ConnectorDistributedProcedureHandle"); + } + getConnectorProtocol(type).from_json(j, p); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const DistributedProcedureHandle& p) { + j = json::object(); + to_json_key( + j, + "connectorId", + p.connectorId, + "DistributedProcedureHandle", + "ConnectorId", + "connectorId"); + to_json_key( + j, + "transactionHandle", + p.transactionHandle, + "DistributedProcedureHandle", + "ConnectorTransactionHandle", + "transactionHandle"); + to_json_key( + j, + "connectorHandle", + p.connectorHandle, + "DistributedProcedureHandle", + "ConnectorDistributedProcedureHandle", + "connectorHandle"); +} + +void from_json(const json& j, DistributedProcedureHandle& p) { + from_json_key( + j, + "connectorId", + p.connectorId, + "DistributedProcedureHandle", + "ConnectorId", + "connectorId"); + from_json_key( + j, + "transactionHandle", + p.transactionHandle, + "DistributedProcedureHandle", + "ConnectorTransactionHandle", + "transactionHandle"); + from_json_key( + j, + "connectorHandle", + p.connectorHandle, + "DistributedProcedureHandle", + "ConnectorDistributedProcedureHandle", + "connectorHandle"); } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { @@ -4695,388 +5181,133 @@ void from_json(const json& j, ErrorType& e) { }); e = ((it != std::end(ErrorType_enum_table)) ? it - : std::begin(ErrorType_enum_table)) - ->first; -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { - -void to_json(json& j, const ErrorCode& p) { - j = json::object(); - to_json_key(j, "code", p.code, "ErrorCode", "int", "code"); - to_json_key(j, "name", p.name, "ErrorCode", "String", "name"); - to_json_key(j, "type", p.type, "ErrorCode", "ErrorType", "type"); - to_json_key(j, "retriable", p.retriable, "ErrorCode", "bool", "retriable"); -} - -void from_json(const json& j, ErrorCode& p) { - from_json_key(j, "code", p.code, "ErrorCode", "int", "code"); - from_json_key(j, "name", p.name, "ErrorCode", "String", "name"); - from_json_key(j, "type", p.type, "ErrorCode", "ErrorType", "type"); - from_json_key(j, "retriable", p.retriable, "ErrorCode", "bool", "retriable"); -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { - -void to_json(json& j, const ErrorLocation& p) { - j = json::object(); - to_json_key( - j, "lineNumber", p.lineNumber, "ErrorLocation", "int", "lineNumber"); - to_json_key( - j, - "columnNumber", - p.columnNumber, - "ErrorLocation", - "int", - "columnNumber"); -} - -void from_json(const json& j, ErrorLocation& p) { - from_json_key( - j, "lineNumber", p.lineNumber, "ErrorLocation", "int", "lineNumber"); - from_json_key( - j, - "columnNumber", - p.columnNumber, - "ErrorLocation", - "int", - "columnNumber"); -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { -// Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() - -// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays -static const std::pair ExchangeNodeScope_enum_table[] = - { // NOLINT: cert-err58-cpp - {ExchangeNodeScope::LOCAL, "LOCAL"}, - {ExchangeNodeScope::REMOTE_STREAMING, "REMOTE_STREAMING"}, - {ExchangeNodeScope::REMOTE_MATERIALIZED, "REMOTE_MATERIALIZED"}}; -void to_json(json& j, const ExchangeNodeScope& e) { - static_assert( - std::is_enum::value, - "ExchangeNodeScope must be an enum!"); - const auto* it = std::find_if( - std::begin(ExchangeNodeScope_enum_table), - std::end(ExchangeNodeScope_enum_table), - [e](const std::pair& ej_pair) -> bool { - return ej_pair.first == e; - }); - j = ((it != std::end(ExchangeNodeScope_enum_table)) - ? it - : std::begin(ExchangeNodeScope_enum_table)) - ->second; -} -void from_json(const json& j, ExchangeNodeScope& e) { - static_assert( - std::is_enum::value, - "ExchangeNodeScope must be an enum!"); - const auto* it = std::find_if( - std::begin(ExchangeNodeScope_enum_table), - std::end(ExchangeNodeScope_enum_table), - [&j](const std::pair& ej_pair) -> bool { - return ej_pair.second == j; - }); - e = ((it != std::end(ExchangeNodeScope_enum_table)) - ? it - : std::begin(ExchangeNodeScope_enum_table)) - ->first; -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { -// Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() - -// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays -static const std::pair ExchangeNodeType_enum_table[] = { - // NOLINT: cert-err58-cpp - {ExchangeNodeType::GATHER, "GATHER"}, - {ExchangeNodeType::REPARTITION, "REPARTITION"}, - {ExchangeNodeType::REPLICATE, "REPLICATE"}, -}; -void to_json(json& j, const ExchangeNodeType& e) { - static_assert( - std::is_enum::value, - "ExchangeNodeType must be an enum!"); - const auto* it = std::find_if( - std::begin(ExchangeNodeType_enum_table), - std::end(ExchangeNodeType_enum_table), - [e](const std::pair& ej_pair) -> bool { - return ej_pair.first == e; - }); - j = ((it != std::end(ExchangeNodeType_enum_table)) - ? it - : std::begin(ExchangeNodeType_enum_table)) - ->second; -} -void from_json(const json& j, ExchangeNodeType& e) { - static_assert( - std::is_enum::value, - "ExchangeNodeType must be an enum!"); - const auto* it = std::find_if( - std::begin(ExchangeNodeType_enum_table), - std::end(ExchangeNodeType_enum_table), - [&j](const std::pair& ej_pair) -> bool { - return ej_pair.second == j; - }); - e = ((it != std::end(ExchangeNodeType_enum_table)) - ? it - : std::begin(ExchangeNodeType_enum_table)) - ->first; -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { -// Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() - -// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays -static const std::pair ExchangeEncoding_enum_table[] = - { // NOLINT: cert-err58-cpp - {ExchangeEncoding::COLUMNAR, "COLUMNAR"}, - {ExchangeEncoding::ROW_WISE, "ROW_WISE"}}; -void to_json(json& j, const ExchangeEncoding& e) { - static_assert( - std::is_enum::value, - "ExchangeEncoding must be an enum!"); - const auto* it = std::find_if( - std::begin(ExchangeEncoding_enum_table), - std::end(ExchangeEncoding_enum_table), - [e](const std::pair& ej_pair) -> bool { - return ej_pair.first == e; - }); - j = ((it != std::end(ExchangeEncoding_enum_table)) - ? it - : std::begin(ExchangeEncoding_enum_table)) - ->second; -} -void from_json(const json& j, ExchangeEncoding& e) { - static_assert( - std::is_enum::value, - "ExchangeEncoding must be an enum!"); - const auto* it = std::find_if( - std::begin(ExchangeEncoding_enum_table), - std::end(ExchangeEncoding_enum_table), - [&j](const std::pair& ej_pair) -> bool { - return ej_pair.second == j; - }); - e = ((it != std::end(ExchangeEncoding_enum_table)) - ? it - : std::begin(ExchangeEncoding_enum_table)) - ->first; -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { -void to_json(json& j, const std::shared_ptr& p) { - if (p == nullptr) { - return; - } - String type = p->_type; - - if (type == "$remote") { - j = *std::static_pointer_cast(p); - return; - } - getConnectorProtocol(type).to_json(j, p); -} - -void from_json(const json& j, std::shared_ptr& p) { - String type; - try { - type = p->getSubclassKey(j); - } catch (json::parse_error& e) { - throw ParseError(std::string(e.what()) + " ConnectorPartitioningHandle"); - } - - if (type == "$remote") { - auto k = std::make_shared(); - j.get_to(*k); - p = k; - return; - } - getConnectorProtocol(type).from_json(j, p); -} -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { - -void to_json(json& j, const PartitioningHandle& p) { - j = json::object(); - to_json_key( - j, - "connectorId", - p.connectorId, - "PartitioningHandle", - "ConnectorId", - "connectorId"); - to_json_key( - j, - "transactionHandle", - p.transactionHandle, - "PartitioningHandle", - "ConnectorTransactionHandle", - "transactionHandle"); - to_json_key( - j, - "connectorHandle", - p.connectorHandle, - "PartitioningHandle", - "ConnectorPartitioningHandle", - "connectorHandle"); -} - -void from_json(const json& j, PartitioningHandle& p) { - from_json_key( - j, - "connectorId", - p.connectorId, - "PartitioningHandle", - "ConnectorId", - "connectorId"); - from_json_key( - j, - "transactionHandle", - p.transactionHandle, - "PartitioningHandle", - "ConnectorTransactionHandle", - "transactionHandle"); - from_json_key( - j, - "connectorHandle", - p.connectorHandle, - "PartitioningHandle", - "ConnectorPartitioningHandle", - "connectorHandle"); + : std::begin(ErrorType_enum_table)) + ->first; } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -void to_json(json& j, const Partitioning& p) { +void to_json(json& j, const ErrorCode& p) { j = json::object(); - to_json_key( - j, "handle", p.handle, "Partitioning", "PartitioningHandle", "handle"); - to_json_key( - j, - "arguments", - p.arguments, - "Partitioning", - "List>", - "arguments"); + to_json_key(j, "code", p.code, "ErrorCode", "int", "code"); + to_json_key(j, "name", p.name, "ErrorCode", "String", "name"); + to_json_key(j, "type", p.type, "ErrorCode", "ErrorType", "type"); + to_json_key(j, "retriable", p.retriable, "ErrorCode", "bool", "retriable"); } -void from_json(const json& j, Partitioning& p) { - from_json_key( - j, "handle", p.handle, "Partitioning", "PartitioningHandle", "handle"); - from_json_key( - j, - "arguments", - p.arguments, - "Partitioning", - "List>", - "arguments"); +void from_json(const json& j, ErrorCode& p) { + from_json_key(j, "code", p.code, "ErrorCode", "int", "code"); + from_json_key(j, "name", p.name, "ErrorCode", "String", "name"); + from_json_key(j, "type", p.type, "ErrorCode", "ErrorType", "type"); + from_json_key(j, "retriable", p.retriable, "ErrorCode", "bool", "retriable"); } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -void to_json(json& j, const PartitioningScheme& p) { +void to_json(json& j, const ErrorLocation& p) { j = json::object(); to_json_key( - j, - "partitioning", - p.partitioning, - "PartitioningScheme", - "Partitioning", - "partitioning"); - to_json_key( - j, - "outputLayout", - p.outputLayout, - "PartitioningScheme", - "List", - "outputLayout"); - to_json_key( - j, - "hashColumn", - p.hashColumn, - "PartitioningScheme", - "VariableReferenceExpression", - "hashColumn"); - to_json_key( - j, - "replicateNullsAndAny", - p.replicateNullsAndAny, - "PartitioningScheme", - "bool", - "replicateNullsAndAny"); - to_json_key( - j, - "scaleWriters", - p.scaleWriters, - "PartitioningScheme", - "bool", - "scaleWriters"); - to_json_key( - j, - "encoding", - p.encoding, - "PartitioningScheme", - "ExchangeEncoding", - "encoding"); + j, "lineNumber", p.lineNumber, "ErrorLocation", "int", "lineNumber"); to_json_key( j, - "bucketToPartition", - p.bucketToPartition, - "PartitioningScheme", - "List", - "bucketToPartition"); + "columnNumber", + p.columnNumber, + "ErrorLocation", + "int", + "columnNumber"); } -void from_json(const json& j, PartitioningScheme& p) { - from_json_key( - j, - "partitioning", - p.partitioning, - "PartitioningScheme", - "Partitioning", - "partitioning"); - from_json_key( - j, - "outputLayout", - p.outputLayout, - "PartitioningScheme", - "List", - "outputLayout"); - from_json_key( - j, - "hashColumn", - p.hashColumn, - "PartitioningScheme", - "VariableReferenceExpression", - "hashColumn"); - from_json_key( - j, - "replicateNullsAndAny", - p.replicateNullsAndAny, - "PartitioningScheme", - "bool", - "replicateNullsAndAny"); - from_json_key( - j, - "scaleWriters", - p.scaleWriters, - "PartitioningScheme", - "bool", - "scaleWriters"); +void from_json(const json& j, ErrorLocation& p) { from_json_key( - j, - "encoding", - p.encoding, - "PartitioningScheme", - "ExchangeEncoding", - "encoding"); + j, "lineNumber", p.lineNumber, "ErrorLocation", "int", "lineNumber"); from_json_key( j, - "bucketToPartition", - p.bucketToPartition, - "PartitioningScheme", - "List", - "bucketToPartition"); + "columnNumber", + p.columnNumber, + "ErrorLocation", + "int", + "columnNumber"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +// Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() + +// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays +static const std::pair ExchangeNodeScope_enum_table[] = + { // NOLINT: cert-err58-cpp + {ExchangeNodeScope::LOCAL, "LOCAL"}, + {ExchangeNodeScope::REMOTE_STREAMING, "REMOTE_STREAMING"}, + {ExchangeNodeScope::REMOTE_MATERIALIZED, "REMOTE_MATERIALIZED"}}; +void to_json(json& j, const ExchangeNodeScope& e) { + static_assert( + std::is_enum::value, + "ExchangeNodeScope must be an enum!"); + const auto* it = std::find_if( + std::begin(ExchangeNodeScope_enum_table), + std::end(ExchangeNodeScope_enum_table), + [e](const std::pair& ej_pair) -> bool { + return ej_pair.first == e; + }); + j = ((it != std::end(ExchangeNodeScope_enum_table)) + ? it + : std::begin(ExchangeNodeScope_enum_table)) + ->second; +} +void from_json(const json& j, ExchangeNodeScope& e) { + static_assert( + std::is_enum::value, + "ExchangeNodeScope must be an enum!"); + const auto* it = std::find_if( + std::begin(ExchangeNodeScope_enum_table), + std::end(ExchangeNodeScope_enum_table), + [&j](const std::pair& ej_pair) -> bool { + return ej_pair.second == j; + }); + e = ((it != std::end(ExchangeNodeScope_enum_table)) + ? it + : std::begin(ExchangeNodeScope_enum_table)) + ->first; +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +// Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() + +// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays +static const std::pair ExchangeNodeType_enum_table[] = { + // NOLINT: cert-err58-cpp + {ExchangeNodeType::GATHER, "GATHER"}, + {ExchangeNodeType::REPARTITION, "REPARTITION"}, + {ExchangeNodeType::REPLICATE, "REPLICATE"}, +}; +void to_json(json& j, const ExchangeNodeType& e) { + static_assert( + std::is_enum::value, + "ExchangeNodeType must be an enum!"); + const auto* it = std::find_if( + std::begin(ExchangeNodeType_enum_table), + std::end(ExchangeNodeType_enum_table), + [e](const std::pair& ej_pair) -> bool { + return ej_pair.first == e; + }); + j = ((it != std::end(ExchangeNodeType_enum_table)) + ? it + : std::begin(ExchangeNodeType_enum_table)) + ->second; +} +void from_json(const json& j, ExchangeNodeType& e) { + static_assert( + std::is_enum::value, + "ExchangeNodeType must be an enum!"); + const auto* it = std::find_if( + std::begin(ExchangeNodeType_enum_table), + std::end(ExchangeNodeType_enum_table), + [&j](const std::pair& ej_pair) -> bool { + return ej_pair.second == j; + }); + e = ((it != std::end(ExchangeNodeType_enum_table)) + ? it + : std::begin(ExchangeNodeType_enum_table)) + ->first; } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { @@ -5172,6 +5403,62 @@ void from_json(const json& j, ExchangeNode& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +ExecuteProcedureHandle::ExecuteProcedureHandle() noexcept { + _type = "ExecuteProcedureHandle"; +} + +void to_json(json& j, const ExecuteProcedureHandle& p) { + j = json::object(); + j["@type"] = "ExecuteProcedureHandle"; + to_json_key( + j, + "handle", + p.handle, + "ExecuteProcedureHandle", + "DistributedProcedureHandle", + "handle"); + to_json_key( + j, + "schemaTableName", + p.schemaTableName, + "ExecuteProcedureHandle", + "SchemaTableName", + "schemaTableName"); + to_json_key( + j, + "procedureName", + p.procedureName, + "ExecuteProcedureHandle", + "QualifiedObjectName", + "procedureName"); +} + +void from_json(const json& j, ExecuteProcedureHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "handle", + p.handle, + "ExecuteProcedureHandle", + "DistributedProcedureHandle", + "handle"); + from_json_key( + j, + "schemaTableName", + p.schemaTableName, + "ExecuteProcedureHandle", + "SchemaTableName", + "schemaTableName"); + from_json_key( + j, + "procedureName", + p.procedureName, + "ExecuteProcedureHandle", + "QualifiedObjectName", + "procedureName"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { // Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index 2b1e4eb66c14e..5cb57d7287c87 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -291,6 +291,11 @@ void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct ConnectorPartitioningHandle : public JsonEncodedSubclass {}; +void to_json(json& j, const std::shared_ptr& p); +void from_json(const json& j, std::shared_ptr& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct InputDistribution : public JsonEncodedSubclass {}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); @@ -301,11 +306,6 @@ void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -struct ConnectorPartitioningHandle : public JsonEncodedSubclass {}; -void to_json(json& j, const std::shared_ptr& p); -void from_json(const json& j, std::shared_ptr& p); -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { struct ConnectorIndexHandle : public JsonEncodedSubclass {}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); @@ -857,6 +857,57 @@ void to_json(json& j, const CacheQuotaRequirement& p); void from_json(const json& j, CacheQuotaRequirement& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +enum class ExchangeEncoding { COLUMNAR, ROW_WISE }; +extern void to_json(json& j, const ExchangeEncoding& e); +extern void from_json(const json& j, ExchangeEncoding& e); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct PartitioningHandle { + std::shared_ptr connectorId = {}; + std::shared_ptr transactionHandle = {}; + std::shared_ptr connectorHandle = {}; +}; +void to_json(json& j, const PartitioningHandle& p); +void from_json(const json& j, PartitioningHandle& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct Partitioning { + PartitioningHandle handle = {}; + List> arguments = {}; +}; +void to_json(json& j, const Partitioning& p); +void from_json(const json& j, Partitioning& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct PartitioningScheme { + Partitioning partitioning = {}; + List outputLayout = {}; + std::shared_ptr hashColumn = {}; + bool replicateNullsAndAny = {}; + bool scaleWriters = {}; + ExchangeEncoding encoding = {}; + std::shared_ptr> bucketToPartition = {}; +}; +void to_json(json& j, const PartitioningScheme& p); +void from_json(const json& j, PartitioningScheme& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct CallDistributedProcedureNode : public PlanNode { + std::shared_ptr source = {}; + VariableReferenceExpression rowCountVariable = {}; + VariableReferenceExpression fragmentVariable = {}; + VariableReferenceExpression tableCommitContextVariable = {}; + List columns = {}; + List columnNames = {}; + List notNullColumnVariables = {}; + std::shared_ptr partitioningScheme = {}; + + CallDistributedProcedureNode() noexcept; +}; +void to_json(json& j, const CallDistributedProcedureNode& p); +void from_json(const json& j, CallDistributedProcedureNode& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct Column { String name; @@ -1002,6 +1053,24 @@ void to_json(json& j, const DistinctLimitNode& p); void from_json(const json& j, DistinctLimitNode& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct ConnectorDistributedProcedureHandle : public JsonEncodedSubclass {}; +void to_json( + json& j, + const std::shared_ptr& p); +void from_json( + const json& j, + std::shared_ptr& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct DistributedProcedureHandle { + ConnectorId connectorId = {}; + std::shared_ptr transactionHandle = {}; + std::shared_ptr connectorHandle = {}; +}; +void to_json(json& j, const DistributedProcedureHandle& p); +void from_json(const json& j, DistributedProcedureHandle& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct DistributionSnapshot { double maxError = {}; double count = {}; @@ -1223,41 +1292,6 @@ extern void to_json(json& j, const ExchangeNodeType& e); extern void from_json(const json& j, ExchangeNodeType& e); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -enum class ExchangeEncoding { COLUMNAR, ROW_WISE }; -extern void to_json(json& j, const ExchangeEncoding& e); -extern void from_json(const json& j, ExchangeEncoding& e); -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { -struct PartitioningHandle { - std::shared_ptr connectorId = {}; - std::shared_ptr transactionHandle = {}; - std::shared_ptr connectorHandle = {}; -}; -void to_json(json& j, const PartitioningHandle& p); -void from_json(const json& j, PartitioningHandle& p); -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { -struct Partitioning { - PartitioningHandle handle = {}; - List> arguments = {}; -}; -void to_json(json& j, const Partitioning& p); -void from_json(const json& j, Partitioning& p); -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { -struct PartitioningScheme { - Partitioning partitioning = {}; - List outputLayout = {}; - std::shared_ptr hashColumn = {}; - bool replicateNullsAndAny = {}; - bool scaleWriters = {}; - ExchangeEncoding encoding = {}; - std::shared_ptr> bucketToPartition = {}; -}; -void to_json(json& j, const PartitioningScheme& p); -void from_json(const json& j, PartitioningScheme& p); -} // namespace facebook::presto::protocol -namespace facebook::presto::protocol { struct ExchangeNode : public PlanNode { ExchangeNodeType type = {}; ExchangeNodeScope scope = {}; @@ -1273,6 +1307,17 @@ void to_json(json& j, const ExchangeNode& p); void from_json(const json& j, ExchangeNode& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct ExecuteProcedureHandle : public ExecutionWriterTarget { + DistributedProcedureHandle handle = {}; + SchemaTableName schemaTableName = {}; + QualifiedObjectName procedureName = {}; + + ExecuteProcedureHandle() noexcept; +}; +void to_json(json& j, const ExecuteProcedureHandle& p); +void from_json(const json& j, ExecuteProcedureHandle& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { enum class ErrorCause { UNKNOWN, LOW_PARTITION_COUNT, diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml index 9d930555c62ba..feecd99f33f32 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml @@ -85,6 +85,10 @@ AbstractClasses: super: JsonEncodedSubclass subclasses: + ConnectorDistributedProcedureHandle: + super: JsonEncodedSubclass + subclasses: + ConnectorTransactionHandle: super: JsonEncodedSubclass subclasses: @@ -129,6 +133,7 @@ AbstractClasses: - { name: InsertHandle, key: InsertHandle } - { name: DeleteHandle, key: DeleteHandle } - { name: UpdateHandle, key: UpdateHandle } + - { name: ExecuteProcedureHandle, key: ExecuteProcedureHandle} InputDistribution: super: JsonEncodedSubclass @@ -168,6 +173,7 @@ AbstractClasses: - { name: AssignUniqueId, key: com.facebook.presto.sql.planner.plan.AssignUniqueId } - { name: MergeJoinNode, key: .MergeJoinNode } - { name: WindowNode, key: .WindowNode } + - { name: CallDistributedProcedureNode, key: com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode } RowExpression: super: JsonEncodedSubclass @@ -303,6 +309,7 @@ JavaClasses: - presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandle.java - presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandle.java - presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandle.java + - presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandle.java - presto-main-base/src/main/java/com/facebook/presto/execution/TaskInfo.java - presto-main-base/src/main/java/com/facebook/presto/execution/TaskSource.java - presto-main-base/src/main/java/com/facebook/presto/execution/TaskState.java @@ -326,6 +333,7 @@ JavaClasses: - presto-spi/src/main/java/com/facebook/presto/spi/relation/LambdaDefinitionExpression.java - presto-spi/src/main/java/com/facebook/presto/spi/plan/SortNode.java - presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/AssignUniqueId.java + - presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/CallDistributedProcedureNode.java - presto-spi/src/main/java/com/facebook/presto/spi/SourceLocation.java - presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/BatchTaskUpdateRequest.java - presto-spi/src/main/java/com/facebook/presto/spi/plan/JoinType.java diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/CallDistributedProcedureNode.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/CallDistributedProcedureNode.cpp.inc new file mode 100644 index 0000000000000..f8d4336a78d18 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/CallDistributedProcedureNode.cpp.inc @@ -0,0 +1,144 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +namespace facebook::presto::protocol { +CallDistributedProcedureNode::CallDistributedProcedureNode() noexcept { + _type = "com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode"; +} + +void to_json(json& j, const CallDistributedProcedureNode& p) { + j = json::object(); + j["@type"] = + "com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode"; + to_json_key( + j, "id", p.id, "CallDistributedProcedureNode", "PlanNodeId", "id"); + to_json_key( + j, + "source", + p.source, + "CallDistributedProcedureNode", + "PlanNode", + "source"); + to_json_key( + j, + "rowCountVariable", + p.rowCountVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "rowCountVariable"); + to_json_key( + j, + "fragmentVariable", + p.fragmentVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "fragmentVariable"); + to_json_key( + j, + "tableCommitContextVariable", + p.tableCommitContextVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "tableCommitContextVariable"); + to_json_key( + j, + "columns", + p.columns, + "CallDistributedProcedureNode", + "List", + "columns"); + to_json_key( + j, + "columnNames", + p.columnNames, + "CallDistributedProcedureNode", + "List", + "columnNames"); + to_json_key( + j, + "notNullColumnVariables", + p.notNullColumnVariables, + "CallDistributedProcedureNode", + "List", + "notNullColumnVariables"); + to_json_key( + j, + "partitioningScheme", + p.partitioningScheme, + "CallDistributedProcedureNode", + "PartitioningScheme", + "partitioningScheme"); +} + +void from_json(const json& j, CallDistributedProcedureNode& p) { + p._type = j["@type"]; + from_json_key( + j, "id", p.id, "CallDistributedProcedureNode", "PlanNodeId", "id"); + from_json_key( + j, + "source", + p.source, + "CallDistributedProcedureNode", + "PlanNode", + "source"); + from_json_key( + j, + "rowCountVariable", + p.rowCountVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "rowCountVariable"); + from_json_key( + j, + "fragmentVariable", + p.fragmentVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "fragmentVariable"); + from_json_key( + j, + "tableCommitContextVariable", + p.tableCommitContextVariable, + "CallDistributedProcedureNode", + "VariableReferenceExpression", + "tableCommitContextVariable"); + from_json_key( + j, + "columns", + p.columns, + "CallDistributedProcedureNode", + "List", + "columns"); + from_json_key( + j, + "columnNames", + p.columnNames, + "CallDistributedProcedureNode", + "List", + "columnNames"); + from_json_key( + j, + "notNullColumnVariables", + p.notNullColumnVariables, + "CallDistributedProcedureNode", + "List", + "notNullColumnVariables"); + from_json_key( + j, + "partitioningScheme", + p.partitioningScheme, + "CallDistributedProcedureNode", + "PartitioningScheme", + "partitioningScheme"); +} +} // namespace facebook::presto::protocol diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/CallDistributedProcedureNode.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/CallDistributedProcedureNode.hpp.inc new file mode 100644 index 0000000000000..718d09f32de18 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/CallDistributedProcedureNode.hpp.inc @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace facebook::presto::protocol { +struct CallDistributedProcedureNode : public PlanNode { + std::shared_ptr source = {}; + VariableReferenceExpression rowCountVariable = {}; + VariableReferenceExpression fragmentVariable = {}; + VariableReferenceExpression tableCommitContextVariable = {}; + List columns = {}; + List columnNames = {}; + List notNullColumnVariables = {}; + std::shared_ptr partitioningScheme = {}; + + CallDistributedProcedureNode() noexcept; +}; +void to_json(json& j, const CallDistributedProcedureNode& p); +void from_json(const json& j, CallDistributedProcedureNode& p); +} // namespace facebook::presto::protocol diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDistributedProcedureHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDistributedProcedureHandle.cpp.inc new file mode 100644 index 0000000000000..2017fcff6abf9 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDistributedProcedureHandle.cpp.inc @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +namespace facebook::presto::protocol { +void to_json( + json& j, + const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + getConnectorProtocol(type).to_json(j, p); +} + +void from_json( + const json& j, + std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError( + std::string(e.what()) + + " ConnectorDistributedProcedureHandle ConnectorDistributedProcedureHandle"); + } + getConnectorProtocol(type).from_json(j, p); +} +} // namespace facebook::presto::protocol diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDistributedProcedureHandle.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDistributedProcedureHandle.hpp.inc new file mode 100644 index 0000000000000..c799fb610eefe --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDistributedProcedureHandle.hpp.inc @@ -0,0 +1,23 @@ + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +namespace facebook::presto::protocol { +struct ConnectorDistributedProcedureHandle : public JsonEncodedSubclass {}; +void to_json( + json& j, + const std::shared_ptr& p); +void from_json( + const json& j, + std::shared_ptr& p); +} // namespace facebook::presto::protocol diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index f0f297904497b..889aeca27deb9 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -68,6 +68,7 @@ import com.facebook.presto.memory.MemoryManagerConfig; import com.facebook.presto.memory.NodeMemoryConfig; import com.facebook.presto.metadata.AnalyzePropertyManager; +import com.facebook.presto.metadata.BuiltInProcedureRegistry; import com.facebook.presto.metadata.CatalogManager; import com.facebook.presto.metadata.ColumnPropertyManager; import com.facebook.presto.metadata.FunctionAndTypeManager; @@ -141,6 +142,7 @@ import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; import com.facebook.presto.spi.plan.SimplePlanFragment; import com.facebook.presto.spi.plan.SimplePlanFragmentSerde; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; @@ -383,6 +385,8 @@ protected void setup(Binder binder) binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON); binder.bind(StaticFunctionNamespaceStore.class).in(Scopes.SINGLETON); binder.bind(StaticTypeManagerStore.class).in(Scopes.SINGLETON); + binder.bind(BuiltInProcedureRegistry.class).in(Scopes.SINGLETON); + binder.bind(ProcedureRegistry.class).to(BuiltInProcedureRegistry.class).in(Scopes.SINGLETON); // type newSetBinder(binder, Type.class); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorDistributedProcedureHandle.java b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorDistributedProcedureHandle.java new file mode 100644 index 0000000000000..d1b9deca1d190 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorDistributedProcedureHandle.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi; + +@SuppressWarnings("MarkerInterface") +public interface ConnectorDistributedProcedureHandle +{ +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java index cded59a56df0b..05af1070bc21b 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java @@ -46,6 +46,11 @@ default Class getDeleteTableHandleClass() throw new UnsupportedOperationException(); } + default Class getDistributedProcedureHandleClass() + { + throw new UnsupportedOperationException(); + } + default Class getPartitioningHandleClass() { throw new UnsupportedOperationException(); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/analyzer/AnalyzerOptions.java b/presto-spi/src/main/java/com/facebook/presto/spi/analyzer/AnalyzerOptions.java index 2cb3140ca869b..eed3fe5119c59 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/analyzer/AnalyzerOptions.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/analyzer/AnalyzerOptions.java @@ -16,6 +16,8 @@ import com.facebook.presto.common.WarningHandlingLevel; import com.facebook.presto.spi.WarningCollector; +import java.util.Optional; + import static com.facebook.presto.common.WarningHandlingLevel.NORMAL; import static java.util.Objects.requireNonNull; @@ -27,17 +29,23 @@ public class AnalyzerOptions private final boolean isParseDecimalLiteralsAsDouble; private final boolean isLogFormattedQueryEnabled; private final WarningHandlingLevel warningHandlingLevel; + private final Optional sessionCatalogName; + private final Optional sessionSchemaName; private final WarningCollector warningCollector; private AnalyzerOptions( boolean isParseDecimalLiteralsAsDouble, boolean isLogFormattedQueryEnabled, WarningCollector warningCollector, + Optional sessionCatalogName, + Optional sessionSchemaName, WarningHandlingLevel warningHandlingLevel) { this.isParseDecimalLiteralsAsDouble = isParseDecimalLiteralsAsDouble; this.isLogFormattedQueryEnabled = isLogFormattedQueryEnabled; this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); + this.sessionCatalogName = requireNonNull(sessionCatalogName, "sessionCatalogName is null"); + this.sessionSchemaName = requireNonNull(sessionSchemaName, "sessionSchemaName is null"); this.warningHandlingLevel = requireNonNull(warningHandlingLevel, "warningHandlingLevel is null"); } @@ -56,6 +64,16 @@ public WarningCollector getWarningCollector() return warningCollector; } + public Optional getSessionCatalogName() + { + return sessionCatalogName; + } + + public Optional getSessionSchemaName() + { + return sessionSchemaName; + } + public WarningHandlingLevel getWarningHandlingLevel() { return warningHandlingLevel; @@ -72,6 +90,8 @@ public static class Builder private boolean isLogFormattedQueryEnabled; private WarningCollector warningCollector = WarningCollector.NOOP; private WarningHandlingLevel warningHandlingLevel = NORMAL; + private Optional sessionCatalogName = Optional.empty(); + private Optional sessionSchemaName = Optional.empty(); private Builder() {} @@ -99,9 +119,21 @@ public Builder setWarningHandlingLevel(WarningHandlingLevel warningHandlingLevel return this; } + public Builder setSessionCatalogName(Optional sessionCatalogName) + { + this.sessionCatalogName = sessionCatalogName; + return this; + } + + public Builder setSessionSchemaName(Optional sessionSchemaName) + { + this.sessionSchemaName = sessionSchemaName; + return this; + } + public AnalyzerOptions build() { - return new AnalyzerOptions(isParseDecimalLiteralsAsDouble, isLogFormattedQueryEnabled, warningCollector, warningHandlingLevel); + return new AnalyzerOptions(isParseDecimalLiteralsAsDouble, isLogFormattedQueryEnabled, warningCollector, sessionCatalogName, sessionSchemaName, warningHandlingLevel); } } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorContext.java index a954291e9c822..a97d157aef4ca 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorContext.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorContext.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; public interface ConnectorContext @@ -36,6 +37,11 @@ default TypeManager getTypeManager() throw new UnsupportedOperationException(); } + default ProcedureRegistry getProcedureRegistry() + { + throw new UnsupportedOperationException(); + } + default FunctionMetadataManager getFunctionMetadataManager() { throw new UnsupportedOperationException(); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java index e8a888bbc8ab9..241ea67ced191 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java @@ -14,11 +14,13 @@ package com.facebook.presto.spi.connector; import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorNewTableLayout; import com.facebook.presto.spi.ConnectorOutputTableHandle; @@ -561,6 +563,28 @@ default Optional getUpdateRowIdColumn(ConnectorSession session, Co return Optional.ofNullable(getUpdateRowIdColumnHandle(session, tableHandle, updatedColumns)); } + /** + * Begin call distributed procedure + */ + default ConnectorDistributedProcedureHandle beginCallDistributedProcedure( + ConnectorSession session, + QualifiedObjectName procedureName, + ConnectorTableLayoutHandle tableLayoutHandle, + Object[] arguments) + { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support distributed procedure"); + } + + /** + * Finish call distributed procedure + * + * @param fragments all fragments returned by {@link com.facebook.presto.spi.UpdatablePageSource#finish()} + */ + default void finishCallDistributedProcedure(ConnectorSession session, ConnectorDistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments) + { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support distributed procedure"); + } + /** * Begin delete query */ diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorPageSinkProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorPageSinkProvider.java index 5b8665d0f14fc..506ecb755dc04 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorPageSinkProvider.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorPageSinkProvider.java @@ -14,11 +14,15 @@ package com.facebook.presto.spi.connector; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorPageSink; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PageSinkContext; +import com.facebook.presto.spi.PrestoException; + +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; public interface ConnectorPageSinkProvider { @@ -30,4 +34,9 @@ default ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionH { throw new UnsupportedOperationException("ConnectorPageSinkProvider does not support connectorDeleteTableHandle"); } + + default ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorDistributedProcedureHandle procedureHandle, PageSinkContext pageSinkContext) + { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support distributed procedure"); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorProcedureContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorProcedureContext.java new file mode 100644 index 0000000000000..a35038529ff9a --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorProcedureContext.java @@ -0,0 +1,18 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.connector; + +public interface ConnectorProcedureContext +{ +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java index 4d8d3b8bbe3d5..e8f969cfabdee 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java @@ -14,11 +14,13 @@ package com.facebook.presto.spi.connector.classloader; import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorNewTableLayout; import com.facebook.presto.spi.ConnectorOutputTableHandle; @@ -612,6 +614,26 @@ public Optional getUpdateRowIdColumn(ConnectorSession session, Con } } + @Override + public ConnectorDistributedProcedureHandle beginCallDistributedProcedure( + ConnectorSession session, + QualifiedObjectName procedureName, + ConnectorTableLayoutHandle tableLayoutHandle, + Object[] arguments) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.beginCallDistributedProcedure(session, procedureName, tableLayoutHandle, arguments); + } + } + + @Override + public void finishCallDistributedProcedure(ConnectorSession session, ConnectorDistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.finishCallDistributedProcedure(session, procedureHandle, procedureName, fragments); + } + } + @Override public ConnectorDeleteTableHandle beginDelete(ConnectorSession session, ConnectorTableHandle tableHandle) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorPageSinkProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorPageSinkProvider.java index 761b6537e2a14..099ea51b04567 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorPageSinkProvider.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorPageSinkProvider.java @@ -14,6 +14,7 @@ package com.facebook.presto.spi.connector.classloader; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorPageSink; @@ -60,4 +61,12 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa return new ClassLoaderSafeConnectorPageSink(delegate.createPageSink(transactionHandle, session, deleteTableHandle, pageSinkContext), classLoader); } } + + @Override + public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorDistributedProcedureHandle procedureHandle, PageSinkContext pageSinkContext) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return new ClassLoaderSafeConnectorPageSink(delegate.createPageSink(transactionHandle, session, procedureHandle, pageSinkContext), classLoader); + } + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableWriterNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableWriterNode.java index c13fdcf5fda4a..1354da343998d 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableWriterNode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableWriterNode.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.spi.plan; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableMetadata; @@ -527,4 +528,63 @@ public String toString() return handle.toString(); } } + + public static class CallDistributedProcedureTarget + extends WriterTarget + { + private final QualifiedObjectName procedureName; + private final Object[] procedureArguments; + private final Optional sourceHandle; + private final SchemaTableName schemaTableName; + + public CallDistributedProcedureTarget( + QualifiedObjectName procedureName, + Object[] procedureArguments, + Optional sourceHandle, + SchemaTableName schemaTableName) + { + this.procedureName = requireNonNull(procedureName, "procedureName is null"); + this.procedureArguments = requireNonNull(procedureArguments, "procedureArguments is null"); + this.sourceHandle = requireNonNull(sourceHandle, "sourceHandle is null"); + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + } + + public QualifiedObjectName getProcedureName() + { + return procedureName; + } + + public Object[] getProcedureArguments() + { + return procedureArguments; + } + + public Optional getSourceHandle() + { + return sourceHandle; + } + + public SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @Override + public Optional> getOutputColumns() + { + return Optional.empty(); + } + + @Override + public ConnectorId getConnectorId() + { + return sourceHandle.map(handle -> handle.getConnectorId()).orElse(null); + } + + @Override + public String toString() + { + return procedureName.toString(); + } + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/procedure/DistributedProcedure.java b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/DistributedProcedure.java new file mode 100644 index 0000000000000..13c603defcde2 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/DistributedProcedure.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.procedure; + +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.StandardErrorCode; +import com.facebook.presto.spi.connector.ConnectorProcedureContext; +import io.airlift.slice.Slice; + +import java.util.Collection; +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public abstract class DistributedProcedure + extends Procedure +{ + private final DistributedProcedureType type; + + protected DistributedProcedure(DistributedProcedureType type, String schema, String name, List arguments) + { + super(schema, name, arguments); + this.type = requireNonNull(type, "distributed procedure type is null"); + } + + public DistributedProcedureType getType() + { + return type; + } + + public abstract ConnectorDistributedProcedureHandle begin(ConnectorSession session, ConnectorProcedureContext procedureContext, ConnectorTableLayoutHandle tableLayoutHandle, Object[] arguments); + + public abstract void finish(ConnectorProcedureContext procedureContext, ConnectorDistributedProcedureHandle procedureHandle, Collection fragments); + + public ConnectorProcedureContext createContext() + { + throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "createContext not supported"); + } + + public enum DistributedProcedureType + { + TABLE_DATA_REWRITE + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/procedure/LocalProcedure.java b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/LocalProcedure.java new file mode 100644 index 0000000000000..a2b4237d92f0d --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/LocalProcedure.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.procedure; + +import com.facebook.presto.spi.ConnectorSession; + +import java.lang.invoke.MethodHandle; +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class LocalProcedure + extends Procedure +{ + private final MethodHandle methodHandle; + + public LocalProcedure(String schema, String name, List arguments) + { + super(schema, name, arguments); + this.methodHandle = null; + } + + public LocalProcedure(String schema, String name, List arguments, MethodHandle methodHandle) + { + super(schema, name, arguments); + this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); + + checkArgument(!methodHandle.isVarargsCollector(), "Method must have fixed arity"); + checkArgument(methodHandle.type().returnType() == void.class, "Method must return void"); + + long parameterCount = methodHandle.type().parameterList().stream() + .filter(type -> !ConnectorSession.class.isAssignableFrom(type)) + .count(); + checkArgument(parameterCount == arguments.size(), "Method parameter count must match arguments"); + } + + public MethodHandle getMethodHandle() + { + return methodHandle; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/procedure/Procedure.java b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/Procedure.java index 557336ae95a4c..a223f836ca3dc 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/procedure/Procedure.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/Procedure.java @@ -14,10 +14,8 @@ package com.facebook.presto.spi.procedure; import com.facebook.presto.common.type.TypeSignature; -import com.facebook.presto.spi.ConnectorSession; import jakarta.annotation.Nullable; -import java.lang.invoke.MethodHandle; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -30,19 +28,17 @@ import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; -public class Procedure +public abstract class Procedure { private final String schema; private final String name; private final List arguments; - private final MethodHandle methodHandle; - public Procedure(String schema, String name, List arguments, MethodHandle methodHandle) + public Procedure(String schema, String name, List arguments) { this.schema = checkNotNullOrEmpty(schema, "schema").toLowerCase(ENGLISH); this.name = checkNotNullOrEmpty(name, "name").toLowerCase(ENGLISH); this.arguments = unmodifiableList(new ArrayList<>(arguments)); - this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); Set names = new HashSet<>(); for (Argument argument : arguments) { @@ -54,14 +50,6 @@ public Procedure(String schema, String name, List arguments, MethodHan throw new IllegalArgumentException("Optional arguments should follow required ones"); } } - - checkArgument(!methodHandle.isVarargsCollector(), "Method must have fixed arity"); - checkArgument(methodHandle.type().returnType() == void.class, "Method must return void"); - - long parameterCount = methodHandle.type().parameterList().stream() - .filter(type -> !ConnectorSession.class.isAssignableFrom(type)) - .count(); - checkArgument(parameterCount == arguments.size(), "Method parameter count must match arguments"); } public String getSchema() @@ -79,11 +67,6 @@ public List getArguments() return arguments; } - public MethodHandle getMethodHandle() - { - return methodHandle; - } - @Override public String toString() { @@ -164,7 +147,7 @@ private static String checkNotNullOrEmpty(String value, String name) return value; } - private static void checkArgument(boolean assertion, String message) + protected static void checkArgument(boolean assertion, String message) { if (!assertion) { throw new IllegalArgumentException(message); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/procedure/ProcedureRegistry.java b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/ProcedureRegistry.java new file mode 100644 index 0000000000000..fec73771d1475 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/ProcedureRegistry.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.procedure; + +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.SchemaTableName; + +import java.util.Collection; + +public interface ProcedureRegistry +{ + void addProcedures(ConnectorId connectorId, Collection procedures); + + void removeProcedures(ConnectorId connectorId); + + Procedure resolve(ConnectorId connectorId, SchemaTableName name); + + DistributedProcedure resolveDistributed(ConnectorId connectorId, SchemaTableName name); + + boolean isDistributedProcedure(ConnectorId connectorId, SchemaTableName name); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/procedure/TableDataRewriteDistributedProcedure.java b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/TableDataRewriteDistributedProcedure.java new file mode 100644 index 0000000000000..1102780c38d73 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/TableDataRewriteDistributedProcedure.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.procedure; + +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorProcedureContext; +import io.airlift.slice.Slice; + +import java.util.Collection; +import java.util.List; +import java.util.OptionalInt; +import java.util.function.Supplier; + +import static com.facebook.presto.spi.procedure.DistributedProcedure.DistributedProcedureType.TABLE_DATA_REWRITE; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class TableDataRewriteDistributedProcedure + extends DistributedProcedure +{ + public static final String SCHEMA = "schema"; + public static final String TABLE_NAME = "table_name"; + public static final String FILTER = "filter"; + + private final BeginCallDistributedProcedure beginCallDistributedProcedure; + private final FinishCallDistributedProcedure finishCallDistributedProcedure; + private Supplier contextSupplier; + private int schemaIndex = -1; + private int tableNameIndex = -1; + private OptionalInt filterIndex = OptionalInt.empty(); + + public TableDataRewriteDistributedProcedure(String schema, String name, + List arguments, + BeginCallDistributedProcedure beginCallDistributedProcedure, + FinishCallDistributedProcedure finishCallDistributedProcedure, + Supplier contextSupplier) + { + super(TABLE_DATA_REWRITE, schema, name, arguments); + this.beginCallDistributedProcedure = requireNonNull(beginCallDistributedProcedure, "beginCallDistributedProcedure is null"); + this.finishCallDistributedProcedure = requireNonNull(finishCallDistributedProcedure, "finishCallDistributedProcedure is null"); + this.contextSupplier = requireNonNull(contextSupplier, "contextSupplier is null"); + for (int i = 0; i < getArguments().size(); i++) { + if (getArguments().get(i).getName().equals(SCHEMA)) { + checkArgument(getArguments().get(i).getType().toString().equalsIgnoreCase("varchar"), + format("Argument `%s` must be string type", SCHEMA)); + schemaIndex = i; + } + else if (getArguments().get(i).getName().equals(TABLE_NAME)) { + checkArgument(getArguments().get(i).getType().toString().equalsIgnoreCase("varchar"), + format("Argument `%s` must be string type", TABLE_NAME)); + tableNameIndex = i; + } + else if (getArguments().get(i).getName().equals(FILTER)) { + filterIndex = OptionalInt.of(i); + } + } + checkArgument(schemaIndex >= 0 && tableNameIndex >= 0, + format("A distributed procedure need at least 2 arguments: `%s` and `%s` for the target table", SCHEMA, TABLE_NAME)); + } + + @Override + public ConnectorDistributedProcedureHandle begin(ConnectorSession session, ConnectorProcedureContext procedureContext, ConnectorTableLayoutHandle tableLayoutHandle, Object[] arguments) + { + return this.beginCallDistributedProcedure.begin(session, procedureContext, tableLayoutHandle, arguments); + } + + @Override + public void finish(ConnectorProcedureContext procedureContext, ConnectorDistributedProcedureHandle procedureHandle, Collection fragments) + { + this.finishCallDistributedProcedure.finish(procedureContext, procedureHandle, fragments); + } + + public ConnectorProcedureContext createContext() + { + return contextSupplier.get(); + } + + public String getSchema(Object[] parameters) + { + return (String) parameters[schemaIndex]; + } + + public String getTableName(Object[] parameters) + { + return (String) parameters[tableNameIndex]; + } + + public String getFilter(Object[] parameters) + { + if (filterIndex.isPresent()) { + return (String) parameters[filterIndex.getAsInt()]; + } + else { + return "TRUE"; + } + } + + @FunctionalInterface + public interface BeginCallDistributedProcedure + { + ConnectorDistributedProcedureHandle begin(ConnectorSession session, ConnectorProcedureContext procedureContext, ConnectorTableLayoutHandle tableLayoutHandle, Object[] arguments); + } + + @FunctionalInterface + public interface FinishCallDistributedProcedure + { + void finish(ConnectorProcedureContext procedureContext, ConnectorDistributedProcedureHandle procedureHandle, Collection fragments); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestingProcedures.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestingProcedures.java index 26561d70a75e3..20851c66ab85d 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/TestingProcedures.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestingProcedures.java @@ -16,6 +16,7 @@ import com.facebook.presto.annotation.UsedByGeneratedCode; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.procedure.LocalProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.facebook.presto.testing.ProcedureTester; @@ -169,7 +170,7 @@ public List getProcedures(String schema) private Procedure procedure(String schema, String name, String methodName, List arguments) { - return new Procedure(schema, name, arguments, handle(methodName)); + return new LocalProcedure(schema, name, arguments, handle(methodName)); } private MethodHandle handle(String name) diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCall.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCall.java index 2ba2039b27de4..9e4147be54bca 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCall.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCall.java @@ -14,9 +14,9 @@ package com.facebook.presto.tests; import com.facebook.presto.Session; -import com.facebook.presto.metadata.ProcedureRegistry; import com.facebook.presto.server.testing.TestingPrestoServer; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.testing.ProcedureTester; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.tpch.TpchQueryRunnerBuilder; diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCreation.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCreation.java index dc30164935f37..708b3fc2a48c9 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCreation.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCreation.java @@ -15,14 +15,20 @@ package com.facebook.presto.tests; import com.facebook.presto.spi.ConnectorSession; -import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.LocalProcedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; +import com.facebook.presto.testing.TestProcedureRegistry; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; import java.util.List; import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle; +import static com.facebook.presto.common.type.StandardTypes.INTEGER; +import static com.facebook.presto.common.type.StandardTypes.TIMESTAMP; import static com.facebook.presto.common.type.StandardTypes.VARCHAR; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @Test(singleThreaded = true) @@ -32,17 +38,17 @@ public class TestProcedureCreation public void shouldThrowExceptionWhenOptionalArgumentIsNotLast() { assertThatThrownBy(() -> createTestProcedure(ImmutableList.of( - new Procedure.Argument("name", VARCHAR, false, null), - new Procedure.Argument("name2", VARCHAR, true, null)))) + new Argument("name", VARCHAR, false, null), + new Argument("name2", VARCHAR, true, null)))) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Optional arguments should follow required ones"); assertThatThrownBy(() -> createTestProcedure(ImmutableList.of( - new Procedure.Argument("name", VARCHAR, true, null), - new Procedure.Argument("name2", VARCHAR, true, null), - new Procedure.Argument("name3", VARCHAR, true, null), - new Procedure.Argument("name4", VARCHAR, false, null), - new Procedure.Argument("name5", VARCHAR, true, null)))) + new Argument("name", VARCHAR, true, null), + new Argument("name2", VARCHAR, true, null), + new Argument("name3", VARCHAR, true, null), + new Argument("name4", VARCHAR, false, null), + new Argument("name5", VARCHAR, true, null)))) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Optional arguments should follow required ones"); } @@ -51,8 +57,8 @@ public void shouldThrowExceptionWhenOptionalArgumentIsNotLast() public void shouldThrowExceptionWhenArgumentNameRepeats() { assertThatThrownBy(() -> createTestProcedure(ImmutableList.of( - new Procedure.Argument("name", VARCHAR, false, null), - new Procedure.Argument("name", VARCHAR, true, null)))) + new Argument("name", VARCHAR, false, null), + new Argument("name", VARCHAR, true, null)))) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Duplicate argument name: 'name'"); } @@ -60,7 +66,7 @@ public void shouldThrowExceptionWhenArgumentNameRepeats() @Test public void shouldThrowExceptionWhenProcedureIsNonVoid() { - assertThatThrownBy(() -> new Procedure( + assertThatThrownBy(() -> new LocalProcedure( "schema", "name", ImmutableList.of(), @@ -72,7 +78,7 @@ public void shouldThrowExceptionWhenProcedureIsNonVoid() @Test public void shouldThrowExceptionWhenMethodHandleIsNull() { - assertThatThrownBy(() -> new Procedure( + assertThatThrownBy(() -> new LocalProcedure( "schema", "name", ImmutableList.of(), @@ -84,7 +90,7 @@ public void shouldThrowExceptionWhenMethodHandleIsNull() @Test public void shouldThrowExceptionWhenMethodHandleHasVarargs() { - assertThatThrownBy(() -> new Procedure( + assertThatThrownBy(() -> new LocalProcedure( "schema", "name", ImmutableList.of(), @@ -96,19 +102,90 @@ public void shouldThrowExceptionWhenMethodHandleHasVarargs() @Test public void shouldThrowExceptionWhenArgumentCountDoesntMatch() { - assertThatThrownBy(() -> new Procedure( + assertThatThrownBy(() -> new LocalProcedure( "schema", "name", ImmutableList.of( - new Procedure.Argument("name", VARCHAR, true, null), - new Procedure.Argument("name2", VARCHAR, true, null), - new Procedure.Argument("name3", VARCHAR, true, null)), + new Argument("name", VARCHAR, true, null), + new Argument("name2", VARCHAR, true, null), + new Argument("name3", VARCHAR, true, null)), methodHandle(Procedures.class, "fun1", ConnectorSession.class, Object.class))) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Method parameter count must match arguments"); } - private static Procedure createTestProcedure(List arguments) + @Test + public void showCreateDistributedProcedure() + { + assertThat(new TableDataRewriteDistributedProcedure( + "schema", + "name", + ImmutableList.of( + new Argument("name", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("schema", VARCHAR, false, null)), + (session, transactionContext, tableLayoutHandle, arguments) -> null, + (transactionContext, procedureHandle, fragments) -> {}, + TestProcedureRegistry.TestProcedureContext::new)).isNotNull(); + } + + @Test + public void shouldThrowExceptionForDistributedProcedureWithWrongArgument() + { + assertThatThrownBy(() -> new TableDataRewriteDistributedProcedure( + "schema", + "name", + ImmutableList.of( + new Argument("name", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("name3", VARCHAR, false, null)), + (session, transactionContext, tableLayoutHandle, arguments) -> null, + (transactionContext, procedureHandle, fragments) -> {}, + TestProcedureRegistry.TestProcedureContext::new)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("A distributed procedure need at least 2 arguments: `schema` and `table_name` for the target table"); + + assertThatThrownBy(() -> new TableDataRewriteDistributedProcedure( + "schema", + "name", + ImmutableList.of( + new Argument("name", VARCHAR), + new Argument("name2", VARCHAR), + new Argument("schema", VARCHAR, false, null)), + (session, transactionContext, tableLayoutHandle, arguments) -> null, + (transactionContext, procedureHandle, fragments) -> {}, + TestProcedureRegistry.TestProcedureContext::new)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("A distributed procedure need at least 2 arguments: `schema` and `table_name` for the target table"); + + assertThatThrownBy(() -> new TableDataRewriteDistributedProcedure( + "schema", + "name", + ImmutableList.of( + new Argument("name", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("schema", INTEGER, false, 123)), + (session, transactionContext, tableLayoutHandle, arguments) -> null, + (transactionContext, procedureHandle, fragments) -> {}, + TestProcedureRegistry.TestProcedureContext::new)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Argument `schema` must be string type"); + + assertThatThrownBy(() -> new TableDataRewriteDistributedProcedure( + "schema", + "name", + ImmutableList.of( + new Argument("name", VARCHAR), + new Argument("table_name", TIMESTAMP), + new Argument("schema", VARCHAR, false, null)), + (session, transactionContext, tableLayoutHandle, arguments) -> null, + (transactionContext, procedureHandle, fragments) -> {}, + TestProcedureRegistry.TestProcedureContext::new)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Argument `table_name` must be string type"); + } + + private static LocalProcedure createTestProcedure(List arguments) { int argumentsCount = arguments.size(); String functionName = "fun" + argumentsCount; @@ -120,7 +197,7 @@ private static Procedure createTestProcedure(List arguments) clazzes[i + 1] = Object.class; } - return new Procedure( + return new LocalProcedure( "schema", "name", arguments,