optimizers = ImmutableList.of(
@@ -262,9 +272,14 @@ protected void assertMinimallyOptimizedPlanDoesNotMatch(@Language("SQL") String
}
protected void assertPlanWithSession(@Language("SQL") String sql, Session session, boolean noExchange, PlanMatchPattern pattern)
+ {
+ assertPlanWithSession(sql, session, noExchange, false, pattern);
+ }
+
+ protected void assertPlanWithSession(@Language("SQL") String sql, Session session, boolean noExchange, boolean nativeExecutionEnabled, PlanMatchPattern pattern)
{
queryRunner.inTransaction(session, transactionSession -> {
- Plan actualPlan = queryRunner.createPlan(transactionSession, sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, noExchange, WarningCollector.NOOP);
+ Plan actualPlan = queryRunner.createPlan(transactionSession, sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, noExchange, nativeExecutionEnabled, WarningCollector.NOOP);
PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getStatsCalculator(), actualPlan, pattern);
return null;
});
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java
new file mode 100644
index 0000000000000..577779a780ad4
--- /dev/null
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java
@@ -0,0 +1,636 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.sql.planner.optimizations;
+
+import com.facebook.presto.common.QualifiedObjectName;
+import com.facebook.presto.common.type.StandardTypes;
+import com.facebook.presto.common.type.TypeSignature;
+import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig;
+import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor;
+import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors;
+import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager;
+import com.facebook.presto.metadata.SqlScalarFunction;
+import com.facebook.presto.operator.scalar.CombineHashFunction;
+import com.facebook.presto.spi.function.FunctionImplementationType;
+import com.facebook.presto.spi.function.Parameter;
+import com.facebook.presto.spi.function.RoutineCharacteristics;
+import com.facebook.presto.spi.function.SqlInvokedFunction;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.facebook.presto.sql.analyzer.FunctionsConfig;
+import com.facebook.presto.sql.planner.assertions.BasePlanTest;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.facebook.presto.tpch.TpchConnectorFactory;
+import com.facebook.presto.type.BigintOperators;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import java.util.stream.Collectors;
+
+import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
+import static com.facebook.presto.operator.scalar.annotations.ScalarFromAnnotationsParser.parseFunctionDefinitions;
+import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
+import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC;
+import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP;
+import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.JAVA;
+import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT;
+import static com.facebook.presto.spi.plan.JoinType.INNER;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER;
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+
+/**
+ * These are plan tests similar to what we have for other optimizers (e.g. {@link com.facebook.presto.sql.planner.TestPredicatePushdown})
+ * They test that the plan for a query after the optimizer runs is as expected.
+ * These are separate from {@link TestAddExchanges} because those are unit tests for
+ * how layouts get chosen.
+ *
+ * Key behavior tested: When CPP functions are used with system tables, the filter containing
+ * the CPP function is preserved above the exchange (not pushed down) to ensure the filter
+ * executes in a different fragment from the system table scan. This validates the fragment
+ * boundary between CPP function evaluation and system table access.
+ */
+public class TestAddExchangesPlansWithFunctions
+ extends BasePlanTest
+{
+ public TestAddExchangesPlansWithFunctions()
+ {
+ super(TestAddExchangesPlansWithFunctions::createTestQueryRunner);
+ }
+
+ private static final SqlInvokedFunction CPP_FOO = new SqlInvokedFunction(
+ new QualifiedObjectName("dummy", "unittest", "cpp_foo"),
+ ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))),
+ parseTypeSignature(StandardTypes.BIGINT),
+ "cpp_foo(x)",
+ RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
+ "",
+ notVersioned());
+
+ private static final SqlInvokedFunction CPP_BAZ = new SqlInvokedFunction(
+ new QualifiedObjectName("dummy", "unittest", "cpp_baz"),
+ ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))),
+ parseTypeSignature(StandardTypes.BIGINT),
+ "cpp_baz(x)",
+ RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
+ "",
+ notVersioned());
+
+ private static final SqlInvokedFunction JAVA_BAR = new SqlInvokedFunction(
+ new QualifiedObjectName("dummy", "unittest", "java_bar"),
+ ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))),
+ parseTypeSignature(StandardTypes.BIGINT),
+ "java_bar(x)",
+ RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
+ "",
+ notVersioned());
+
+ private static final SqlInvokedFunction JAVA_FEE = new SqlInvokedFunction(
+ new QualifiedObjectName("dummy", "unittest", "java_fee"),
+ ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))),
+ parseTypeSignature(StandardTypes.BIGINT),
+ "java_fee(x)",
+ RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
+ "",
+ notVersioned());
+
+ private static final SqlInvokedFunction NOT = new SqlInvokedFunction(
+ new QualifiedObjectName("dummy", "unittest", "not"),
+ ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BOOLEAN))),
+ parseTypeSignature(StandardTypes.BOOLEAN),
+ "not(x)",
+ RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
+ "",
+ notVersioned());
+
+ private static final SqlInvokedFunction CPP_ARRAY_CONSTRUCTOR = new SqlInvokedFunction(
+ new QualifiedObjectName("dummy", "unittest", "array_constructor"),
+ ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT)), new Parameter("y", parseTypeSignature(StandardTypes.BIGINT))),
+ parseTypeSignature("array(bigint)"),
+ "array_constructor(x, y)",
+ RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
+ "",
+ notVersioned());
+
+ private static LocalQueryRunner createTestQueryRunner()
+ {
+ LocalQueryRunner queryRunner = new LocalQueryRunner(testSessionBuilder()
+ .setCatalog("tpch")
+ .setSchema("tiny")
+ .build(),
+ new FeaturesConfig(),
+ new FunctionsConfig().setDefaultNamespacePrefix("dummy.unittest"));
+ queryRunner.createCatalog("tpch", new TpchConnectorFactory(), ImmutableMap.of());
+ queryRunner.getMetadata().getFunctionAndTypeManager().addFunctionNamespace(
+ "dummy",
+ new InMemoryFunctionNamespaceManager(
+ "dummy",
+ new SqlFunctionExecutors(
+ ImmutableMap.of(
+ CPP, FunctionImplementationType.CPP,
+ JAVA, FunctionImplementationType.JAVA),
+ new NoopSqlFunctionExecutor()),
+ new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("cpp")));
+ queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_FOO, true);
+ queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_BAZ, true);
+ queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(JAVA_BAR, true);
+ queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(JAVA_FEE, true);
+ queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(NOT, true);
+ queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_ARRAY_CONSTRUCTOR, true);
+ parseFunctionDefinitions(BigintOperators.class).stream()
+ .map(TestAddExchangesPlansWithFunctions::convertToSqlInvokedFunction)
+ .forEach(function -> queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(function, true));
+ parseFunctionDefinitions(CombineHashFunction.class).stream()
+ .map(TestAddExchangesPlansWithFunctions::convertToSqlInvokedFunction)
+ .forEach(function -> queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(function, true));
+ return queryRunner;
+ }
+
+ public static SqlInvokedFunction convertToSqlInvokedFunction(SqlScalarFunction scalarFunction)
+ {
+ QualifiedObjectName functionName = new QualifiedObjectName("dummy", "unittest", scalarFunction.getSignature().getName().getObjectName());
+ TypeSignature returnType = scalarFunction.getSignature().getReturnType();
+ RoutineCharacteristics characteristics = RoutineCharacteristics.builder()
+ .setLanguage(RoutineCharacteristics.Language.JAVA) // Assuming JAVA as the language
+ .setDeterminism(RoutineCharacteristics.Determinism.DETERMINISTIC)
+ .setNullCallClause(RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT)
+ .build();
+
+ // Convert scalar function arguments to SqlInvokedFunction parameters
+ ImmutableList parameters = scalarFunction.getSignature().getArgumentTypes().stream()
+ .map(type -> new Parameter(type.toString(), TypeSignature.parseTypeSignature(type.toString())))
+ .collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
+
+ // Create the SqlInvokedFunction
+ return new SqlInvokedFunction(
+ functionName,
+ parameters,
+ returnType,
+ scalarFunction.getSignature().getName().toString(), // Using the function name as the body for simplicity
+ characteristics,
+ "", // Empty description
+ notVersioned());
+ }
+
+ @Test
+ public void testFilterWithCppFunctionDoesNotGetPushedIntoSystemTableScan()
+ {
+ // java_fee and java_bar are java functions, they are both pushed down past the exchange
+ assertNativeDistributedPlan("SELECT java_fee(ordinal_position) FROM information_schema.columns WHERE java_bar(ordinal_position) = 1",
+ anyTree(
+ exchange(REMOTE_STREAMING, GATHER,
+ project(ImmutableMap.of("java_fee", expression("java_fee(ordinal_position)")),
+ filter("java_bar(ordinal_position) = BIGINT'1'",
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))));
+ // cpp_foo is a CPP function, it is not pushed down past the exchange because the source is a system table scan
+ // The filter is preserved above the exchange to prove that the filter is not in the same fragment as the system table scan
+ assertNativeDistributedPlan("SELECT cpp_baz(ordinal_position) FROM information_schema.columns WHERE cpp_foo(ordinal_position) = 1",
+ anyTree(
+ project(ImmutableMap.of("cpp_baz", expression("cpp_baz(ordinal_position)")),
+ filter("cpp_foo(ordinal_position) = BIGINT'1'",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))));
+ }
+
+ @Test
+ public void testJoinWithCppFunctionDoesNotGetPushedIntoSystemTableScan()
+ {
+ // java_bar is a java function, it is pushed down past the exchange
+ assertNativeDistributedPlan(
+ "SELECT c1.table_name FROM information_schema.columns c1 JOIN information_schema.columns c2 ON c1.ordinal_position = c2.ordinal_position WHERE java_bar(c1.ordinal_position) = 1",
+ anyTree(
+ exchange(
+ join(INNER, ImmutableList.of(equiJoinClause("ordinal_position", "ordinal_position_4")),
+ anyTree(
+ exchange(REMOTE_STREAMING, GATHER,
+ project(
+ filter("java_bar(ordinal_position) = BIGINT'1'",
+ tableScan("columns", ImmutableMap.of(
+ "ordinal_position", "ordinal_position",
+ "table_name", "table_name")))))),
+ anyTree(
+ exchange(REMOTE_STREAMING, GATHER,
+ project(
+ filter("java_bar(ordinal_position_4) = BIGINT'1'",
+ tableScan("columns", ImmutableMap.of(
+ "ordinal_position_4", "ordinal_position"))))))))));
+
+ // cpp_foo is a CPP function, it is not pushed down past the exchange because the source is a system table scan
+ assertNativeDistributedPlan(
+ "SELECT cpp_baz(c1.ordinal_position) FROM information_schema.columns c1 JOIN information_schema.columns c2 ON c1.ordinal_position = c2.ordinal_position WHERE cpp_foo(c1.ordinal_position) = 1",
+ output(
+ exchange(
+ project(ImmutableMap.of("cpp_baz", expression("cpp_baz(ordinal_position)")),
+ join(INNER, ImmutableList.of(equiJoinClause("ordinal_position", "ordinal_position_4")),
+ anyTree(
+ filter("cpp_foo(ordinal_position) = BIGINT'1'",
+ exchange(REMOTE_STREAMING, GATHER,
+ project(
+ tableScan("columns", ImmutableMap.of(
+ "ordinal_position", "ordinal_position")))))),
+ anyTree(
+ filter("cpp_foo(ordinal_position_4) = BIGINT'1'",
+ exchange(REMOTE_STREAMING, GATHER,
+ project(
+ tableScan("columns", ImmutableMap.of(
+ "ordinal_position_4", "ordinal_position")))))))))));
+ }
+
+ @Test
+ public void testMixedFunctionTypesInComplexPredicates()
+ {
+ // Test AND condition with mixed Java and CPP functions
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE java_bar(ordinal_position) = 1 AND cpp_foo(ordinal_position) > 0",
+ anyTree(
+ filter("java_bar(ordinal_position) = BIGINT'1' AND cpp_foo(ordinal_position) > BIGINT'0'",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+
+ // Test OR condition with mixed functions - entire predicate should be evaluated after exchange
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE java_bar(ordinal_position) = 1 OR cpp_foo(ordinal_position) = 2",
+ anyTree(
+ filter("java_bar(ordinal_position) = BIGINT'1' OR cpp_foo(ordinal_position) = BIGINT'2'",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testNestedFunctionCalls()
+ {
+ // CPP function nested inside Java function - should not push down
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE java_bar(cpp_foo(ordinal_position)) = 1",
+ anyTree(
+ filter("java_bar(cpp_foo(ordinal_position)) = BIGINT'1'",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+
+ // Java function nested inside CPP function - should not push down
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE cpp_foo(java_bar(ordinal_position)) = 1",
+ anyTree(
+ filter("cpp_foo(java_bar(ordinal_position)) = BIGINT'1'",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+
+ // Multiple levels of nesting
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE cpp_foo(java_bar(cpp_foo(ordinal_position))) = 1",
+ anyTree(
+ filter("cpp_foo(java_bar(cpp_foo(ordinal_position))) = BIGINT'1'",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testMixedSystemAndRegularTables()
+ {
+ // System table with CPP function joined with regular table
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns c JOIN nation n ON c.ordinal_position = n.nationkey WHERE cpp_foo(c.ordinal_position) = 1",
+ output(
+ join(INNER, ImmutableList.of(equiJoinClause("ordinal_position", "nationkey")),
+ filter("cpp_foo(ordinal_position) = BIGINT'1'",
+ exchange(REMOTE_STREAMING, GATHER,
+ project(ImmutableMap.of("ordinal_position", expression("ordinal_position")),
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))),
+ anyTree(
+ project(ImmutableMap.of("nationkey", expression("nationkey")),
+ filter(
+ tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))))))));
+
+ // Regular table with CPP function (should work normally without extra exchange)
+ assertNativeDistributedPlan(
+ "SELECT * FROM nation WHERE cpp_foo(nationkey) = 1",
+ anyTree(
+ exchange(REMOTE_STREAMING, GATHER,
+ filter("cpp_foo(nationkey) = BIGINT'1'",
+ tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))))));
+ }
+
+ @Test
+ public void testAggregationsWithMixedFunctions()
+ {
+ // Aggregation with CPP function in GROUP BY
+ assertNativeDistributedPlan(
+ "SELECT DISTINCT cpp_foo(ordinal_position) FROM information_schema.columns",
+ anyTree(
+ project(ImmutableMap.of("cpp_foo", expression("cpp_foo(ordinal_position)")),
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+
+ // Aggregation with Java function in GROUP BY - can be pushed down
+ assertNativeDistributedPlan(
+ "SELECT DISTINCT java_bar(ordinal_position) FROM information_schema.columns",
+ anyTree(
+ exchange(REMOTE_STREAMING, GATHER,
+ project(ImmutableMap.of("java_bar", expression("java_bar(ordinal_position)")),
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testComplexPredicateWithMultipleFunctions()
+ {
+ // Complex predicate with multiple CPP and Java functions
+ // Since the predicate contains CPP functions (cpp_foo, baz), the exchange is inserted before the system table scan
+ // The RemoveRedundantExchanges rule removes the inner exchange that was added by ExtractIneligiblePredicatesFromSystemTableScans
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE (cpp_foo(ordinal_position) > 0 AND java_bar(ordinal_position) < 100) OR cpp_baz(ordinal_position) = 50",
+ anyTree(
+ filter(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testProjectionWithMixedFunctions()
+ {
+ // Projection with both Java and CPP functions
+ assertNativeDistributedPlan(
+ "SELECT java_bar(ordinal_position) as java_result, cpp_foo(ordinal_position) as cpp_result FROM information_schema.columns",
+ anyTree(
+ project(ImmutableMap.of(
+ "java_result", expression("java_bar(ordinal_position)"),
+ "cpp_result", expression("cpp_foo(ordinal_position)")),
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testCaseStatementsWithCppFunctions()
+ {
+ // CASE statement with CPP function in condition
+ // The RemoveRedundantExchanges optimizer removes the redundant exchange
+ assertNativeDistributedPlan(
+ "SELECT CASE WHEN cpp_foo(ordinal_position) > 0 THEN 'positive' ELSE 'negative' END FROM information_schema.columns",
+ anyTree(
+ project(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+
+ // CASE statement with CPP function in result
+ // The RemoveRedundantExchanges optimizer removes the redundant exchange
+ assertNativeDistributedPlan(
+ "SELECT CASE WHEN ordinal_position > 0 THEN cpp_foo(ordinal_position) ELSE 0 END FROM information_schema.columns",
+ anyTree(
+ project(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testBuiltinFunctionWithExplicitNamespace()
+ {
+ // Test that built-in functions with explicit namespace are handled correctly
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE presto.default.length(table_name) > 10",
+ anyTree(
+ exchange(REMOTE_STREAMING, GATHER,
+ filter("length(table_name) > BIGINT'10'",
+ tableScan("columns", ImmutableMap.of("table_name", "table_name"))))));
+ }
+
+ @Test(enabled = false) // TODO: Window functions are resolved with namespace which causes issues in tests
+ public void testWindowFunctionsWithCppFunctions()
+ {
+ // Window function with CPP function in partition by
+ assertNativeDistributedPlan(
+ "SELECT row_number() OVER (PARTITION BY cpp_foo(ordinal_position)) FROM information_schema.columns",
+ anyTree(
+ exchange(
+ project(
+ anyTree(
+ project(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))))));
+
+ // Window function with CPP function in order by
+ assertNativeDistributedPlan(
+ "SELECT row_number() OVER (ORDER BY cpp_foo(ordinal_position)) FROM information_schema.columns",
+ anyTree(
+ exchange(
+ project(
+ anyTree(
+ project(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))))));
+ }
+
+ @Test
+ public void testMultipleSystemTableJoins()
+ {
+ // Multiple system tables with CPP functions
+ // This test verifies that when joining two system tables with a CPP function comparison,
+ // an exchange is added between the table scan and the join to ensure CPP functions
+ // execute in a separate fragment from system table access
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns c1 " +
+ "JOIN information_schema.columns c2 ON cpp_foo(c1.ordinal_position) = cpp_foo(c2.ordinal_position)",
+ anyTree(
+ exchange(
+ join(INNER, ImmutableList.of(equiJoinClause("cpp_foo", "foo_4")),
+ exchange(
+ project(ImmutableMap.of("cpp_foo", expression("cpp_foo")),
+ project(ImmutableMap.of("cpp_foo", expression("cpp_foo(ordinal_position)")),
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))),
+ anyTree(
+ exchange(
+ project(ImmutableMap.of("foo_4", expression("foo_4")),
+ project(ImmutableMap.of("foo_4", expression("cpp_foo(ordinal_position_4)")),
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position_4", "ordinal_position")))))))))));
+ }
+
+ @Test
+ public void testInPredicateWithCppFunction()
+ {
+ // IN predicate with CPP function
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) IN (1, 2, 3)",
+ anyTree(
+ filter("cpp_foo(ordinal_position) IN (BIGINT'1', BIGINT'2', BIGINT'3')",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testBetweenPredicateWithCppFunction()
+ {
+ // BETWEEN predicate with CPP function
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) BETWEEN 1 AND 10",
+ anyTree(
+ filter("cpp_foo(ordinal_position) BETWEEN BIGINT'1' AND BIGINT'10'",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testNullHandlingWithCppFunctions()
+ {
+ // IS NULL check with CPP function
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) IS NULL",
+ anyTree(
+ filter("cpp_foo(ordinal_position) IS NULL",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+
+ // COALESCE with CPP function
+ // The RemoveRedundantExchanges optimizer removes the redundant exchange
+ assertNativeDistributedPlan(
+ "SELECT COALESCE(cpp_foo(ordinal_position), 0) FROM information_schema.columns",
+ anyTree(
+ project(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testUnionWithCppFunctions()
+ {
+ // UNION ALL with CPP functions from system tables
+ assertNativeDistributedPlan(
+ "SELECT cpp_foo(ordinal_position) FROM information_schema.columns " +
+ "UNION ALL SELECT cpp_foo(nationkey) FROM nation",
+ output(
+ exchange(
+ anyTree(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))),
+ anyTree(
+ tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))))));
+ }
+
+ @Test
+ public void testExistsSubqueryWithCppFunction()
+ {
+ // EXISTS subquery with CPP function
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns c WHERE EXISTS (SELECT 1 FROM nation n WHERE cpp_foo(c.ordinal_position) = n.nationkey)",
+ anyTree(
+ join(
+ anyTree(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))),
+ anyTree(
+ tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))))));
+ }
+
+ @Test
+ public void testLimitWithCppFunction()
+ {
+ // LIMIT with CPP function in ORDER BY
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns ORDER BY cpp_foo(ordinal_position) LIMIT 10",
+ output(
+ project(
+ anyTree(
+ project(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))))));
+ }
+
+ @Test
+ public void testCastOperationsWithCppFunctions()
+ {
+ // CAST operations with CPP functions
+ // The RemoveRedundantExchanges optimizer removes the redundant exchange
+ assertNativeDistributedPlan(
+ "SELECT CAST(cpp_foo(ordinal_position) AS VARCHAR) FROM information_schema.columns",
+ anyTree(
+ project(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testArrayConstructorWithCppFunction()
+ {
+ // Array constructor with CPP function
+ assertNativeDistributedPlan(
+ "SELECT ARRAY[cpp_foo(ordinal_position), cpp_baz(ordinal_position)] FROM information_schema.columns",
+ anyTree(
+ project(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testRowConstructorWithCppFunction()
+ {
+ // ROW constructor with CPP function
+ // The RemoveRedundantExchanges optimizer removes the redundant exchange
+ assertNativeDistributedPlan(
+ "SELECT ROW(cpp_foo(ordinal_position), table_name) FROM information_schema.columns",
+ anyTree(
+ project(
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of(
+ "ordinal_position", "ordinal_position",
+ "table_name", "table_name"))))));
+ }
+
+ @Test
+ public void testIsNotNullWithCppFunction()
+ {
+ // IS NOT NULL check with CPP function
+ assertNativeDistributedPlan(
+ "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) IS NOT NULL",
+ anyTree(
+ filter("cpp_foo(ordinal_position) IS NOT NULL",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))));
+ }
+
+ @Test
+ public void testComplexJoinWithMultipleCppFunctions()
+ {
+ // Complex join with multiple CPP functions in different positions
+ // The filters are pushed into FilterProject nodes and the join happens on the expression cpp_foo(c1.ordinal_position)
+ assertNativeDistributedPlan(
+ "SELECT c1.table_name, n.name FROM information_schema.columns c1 " +
+ "JOIN nation n ON cpp_foo(c1.ordinal_position) = n.nationkey " +
+ "WHERE cpp_baz(c1.ordinal_position) > 0 AND cpp_foo(n.nationkey) < 100",
+ anyTree(
+ join(INNER, ImmutableList.of(equiJoinClause("cpp_foo", "nationkey")),
+ project(ImmutableMap.of("table_name", expression("table_name"), "cpp_foo", expression("cpp_foo")),
+ project(ImmutableMap.of("cpp_foo", expression("cpp_foo(ordinal_position)")),
+ filter("cpp_baz(ordinal_position) > BIGINT'0' AND cpp_foo(cpp_foo(ordinal_position)) < BIGINT'100'",
+ exchange(REMOTE_STREAMING, GATHER,
+ tableScan("columns", ImmutableMap.of(
+ "ordinal_position", "ordinal_position",
+ "table_name", "table_name")))))),
+ anyTree(
+ project(ImmutableMap.of("nationkey", expression("nationkey"),
+ "name", expression("name")),
+ filter("cpp_foo(nationkey) < BIGINT'100'",
+ tableScan("nation", ImmutableMap.of(
+ "nationkey", "nationkey",
+ "name", "name"))))))));
+ }
+}
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/sanity/TestCheckNoIneligibleFunctionsInCoordinatorFragments.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/sanity/TestCheckNoIneligibleFunctionsInCoordinatorFragments.java
new file mode 100644
index 0000000000000..106b526a6b75e
--- /dev/null
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/sanity/TestCheckNoIneligibleFunctionsInCoordinatorFragments.java
@@ -0,0 +1,474 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.sanity;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.common.QualifiedObjectName;
+import com.facebook.presto.common.type.StandardTypes;
+import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig;
+import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor;
+import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors;
+import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.spi.ConnectorId;
+import com.facebook.presto.spi.TableHandle;
+import com.facebook.presto.spi.TestingColumnHandle;
+import com.facebook.presto.spi.WarningCollector;
+import com.facebook.presto.spi.function.FunctionImplementationType;
+import com.facebook.presto.spi.function.Parameter;
+import com.facebook.presto.spi.function.RoutineCharacteristics;
+import com.facebook.presto.spi.function.SqlInvokedFunction;
+import com.facebook.presto.spi.plan.JoinType;
+import com.facebook.presto.spi.plan.Partitioning;
+import com.facebook.presto.spi.plan.PartitioningScheme;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.facebook.presto.sql.analyzer.FunctionsConfig;
+import com.facebook.presto.sql.planner.assertions.BasePlanTest;
+import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.facebook.presto.testing.TestingMetadata.TestingTableHandle;
+import com.facebook.presto.testing.TestingTransactionHandle;
+import com.facebook.presto.tpch.TpchConnectorFactory;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import java.util.Optional;
+import java.util.function.Function;
+
+import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
+import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC;
+import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP;
+import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.JAVA;
+import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT;
+import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
+import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+
+public class TestCheckNoIneligibleFunctionsInCoordinatorFragments
+ extends BasePlanTest
+{
+ // CPP function for testing (similar to TestAddExchangesPlansWithFunctions)
+ private static final SqlInvokedFunction CPP_FUNC = new SqlInvokedFunction(
+ new QualifiedObjectName("dummy", "unittest", "cpp_func"),
+ ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.VARCHAR))),
+ parseTypeSignature(StandardTypes.VARCHAR),
+ "cpp_func(x)",
+ RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
+ "",
+ notVersioned());
+
+ // JAVA function for testing
+ private static final SqlInvokedFunction JAVA_FUNC = new SqlInvokedFunction(
+ new QualifiedObjectName("dummy", "unittest", "java_func"),
+ ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.VARCHAR))),
+ parseTypeSignature(StandardTypes.VARCHAR),
+ "java_func(x)",
+ RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(),
+ "",
+ notVersioned());
+
+ public TestCheckNoIneligibleFunctionsInCoordinatorFragments()
+ {
+ super(TestCheckNoIneligibleFunctionsInCoordinatorFragments::createTestQueryRunner);
+ }
+
+ private static LocalQueryRunner createTestQueryRunner()
+ {
+ LocalQueryRunner queryRunner = new LocalQueryRunner(
+ testSessionBuilder()
+ .setCatalog("local")
+ .setSchema("tiny")
+ .build(),
+ new FeaturesConfig().setNativeExecutionEnabled(true),
+ new FunctionsConfig().setDefaultNamespacePrefix("dummy.unittest"));
+
+ queryRunner.createCatalog("local", new TpchConnectorFactory(), ImmutableMap.of());
+
+ // Add function namespace with both CPP and JAVA functions
+ queryRunner.getMetadata().getFunctionAndTypeManager().addFunctionNamespace(
+ "dummy",
+ new InMemoryFunctionNamespaceManager(
+ "dummy",
+ new SqlFunctionExecutors(
+ ImmutableMap.of(
+ CPP, FunctionImplementationType.CPP,
+ JAVA, FunctionImplementationType.JAVA),
+ new NoopSqlFunctionExecutor()),
+ new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("CPP,JAVA")));
+
+ // Register the functions
+ queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_FUNC, false);
+ queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(JAVA_FUNC, false);
+
+ return queryRunner;
+ }
+
+ @Test
+ public void testSystemTableScanWithJavaFunctionPasses()
+ {
+ // System table scan with Java function in same fragment should pass
+ validatePlan(
+ p -> {
+ VariableReferenceExpression col = p.variable("col", VARCHAR);
+ VariableReferenceExpression result = p.variable("result", VARCHAR);
+
+ // Create a system table scan - using proper system connector ID
+ TableHandle systemTableHandle = new TableHandle(
+ ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")),
+ new TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.empty());
+
+ PlanNode tableScan = p.tableScan(
+ systemTableHandle,
+ ImmutableList.of(col),
+ ImmutableMap.of(col, new TestingColumnHandle("col")));
+
+ // Java function (using our registered java_func)
+ return p.project(
+ assignment(result, p.rowExpression("java_func(col)")),
+ tableScan);
+ });
+ }
+
+ @Test(expectedExceptions = IllegalStateException.class,
+ expectedExceptionsMessageRegExp = "Fragment contains both system table scan and non-Java functions.*")
+ public void testSystemTableScanWithCppFunctionInProjectFails()
+ {
+ // System table scan with C++ function in same fragment should fail
+ validatePlan(
+ p -> {
+ VariableReferenceExpression col = p.variable("col", VARCHAR);
+ VariableReferenceExpression result = p.variable("result", VARCHAR);
+
+ // System table scan
+ TableHandle systemTableHandle = new TableHandle(
+ ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")),
+ new TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.empty());
+
+ PlanNode systemScan = p.tableScan(
+ systemTableHandle,
+ ImmutableList.of(col),
+ ImmutableMap.of(col, new TestingColumnHandle("col")));
+
+ // C++ function (using our registered cpp_func)
+ return p.project(
+ assignment(result, p.rowExpression("cpp_func(col)")),
+ systemScan);
+ });
+ }
+
+ @Test(expectedExceptions = IllegalStateException.class,
+ expectedExceptionsMessageRegExp = "Fragment contains both system table scan and non-Java functions.*")
+ public void testSystemTableScanWithCppFunctionInFilterFails()
+ {
+ // System table scan with C++ function in filter should fail
+ validatePlan(
+ p -> {
+ VariableReferenceExpression col = p.variable("col", VARCHAR);
+
+ // System table scan
+ TableHandle systemTableHandle = new TableHandle(
+ ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")),
+ new TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.empty());
+
+ PlanNode systemScan = p.tableScan(
+ systemTableHandle,
+ ImmutableList.of(col),
+ ImmutableMap.of(col, new TestingColumnHandle("col")));
+
+ // Filter with C++ function
+ return p.filter(
+ p.rowExpression("cpp_func(col) = 'test'"),
+ systemScan);
+ });
+ }
+
+ @Test
+ public void testSystemTableScanWithCppFunctionSeparatedByExchangePasses()
+ {
+ // System table scan and C++ function separated by exchange should pass
+ validatePlan(
+ p -> {
+ VariableReferenceExpression col = p.variable("col", VARCHAR);
+ VariableReferenceExpression result = p.variable("result", VARCHAR);
+
+ // System table scan
+ TableHandle systemTableHandle = new TableHandle(
+ ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")),
+ new TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.empty());
+
+ PlanNode systemScan = p.tableScan(
+ systemTableHandle,
+ ImmutableList.of(col),
+ ImmutableMap.of(col, new TestingColumnHandle("col")));
+
+ // Exchange creates fragment boundary
+ PartitioningScheme partitioningScheme = new PartitioningScheme(
+ Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
+ ImmutableList.of(col));
+
+ ExchangeNode exchange = new ExchangeNode(
+ Optional.empty(),
+ p.getIdAllocator().getNextId(),
+ ExchangeNode.Type.GATHER,
+ ExchangeNode.Scope.LOCAL,
+ partitioningScheme,
+ ImmutableList.of(systemScan),
+ ImmutableList.of(ImmutableList.of(col)),
+ false,
+ Optional.empty());
+
+ // C++ function in different fragment
+ return p.project(
+ assignment(result, p.rowExpression("cpp_func(col)")),
+ exchange);
+ });
+ }
+
+ @Test
+ public void testRegularTableScanWithCppFunctionPasses()
+ {
+ // Regular table scan with C++ function should pass (no system table)
+ validatePlan(
+ p -> {
+ VariableReferenceExpression col = p.variable("col", VARCHAR);
+ VariableReferenceExpression result = p.variable("result", VARCHAR);
+
+ // Regular table scan (not system)
+ TableHandle regularTableHandle = new TableHandle(
+ new ConnectorId("local"),
+ new TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.empty());
+
+ PlanNode regularScan = p.tableScan(
+ regularTableHandle,
+ ImmutableList.of(col),
+ ImmutableMap.of(col, new TestingColumnHandle("col")));
+
+ // C++ function
+ return p.project(
+ assignment(result, p.rowExpression("cpp_func(col)")),
+ regularScan);
+ });
+ }
+
+ @Test(expectedExceptions = IllegalStateException.class,
+ expectedExceptionsMessageRegExp = "Fragment contains both system table scan and non-Java functions.*")
+ public void testMultipleFragmentsWithCppFunctionInSystemFragment()
+ {
+ // Complex plan where CPP function is in same fragment as system table scan (should fail)
+ validatePlan(
+ p -> {
+ VariableReferenceExpression col1 = p.variable("col1", VARCHAR);
+ VariableReferenceExpression col2 = p.variable("col2", BIGINT);
+ VariableReferenceExpression col3 = p.variable("col3", BIGINT);
+
+ // Fragment 1: System table scan
+ TableHandle systemTableHandle = new TableHandle(
+ ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")),
+ new TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.empty());
+
+ PlanNode systemScan = p.tableScan(
+ systemTableHandle,
+ ImmutableList.of(col1),
+ ImmutableMap.of(col1, new TestingColumnHandle("col1")));
+
+ // Convert to numeric for join (using CPP function - this should fail)
+ PlanNode project1 = p.project(
+ assignment(col2, p.rowExpression("cast(cpp_func(col1) as bigint)")),
+ systemScan);
+
+ // Fragment 2: Regular values with computation
+ PlanNode values = p.values(col3);
+
+ // Exchange to separate fragments
+ PartitioningScheme partitioningScheme1 = new PartitioningScheme(
+ Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
+ ImmutableList.of(col2));
+
+ ExchangeNode exchange1 = new ExchangeNode(
+ Optional.empty(),
+ p.getIdAllocator().getNextId(),
+ ExchangeNode.Type.GATHER,
+ ExchangeNode.Scope.LOCAL,
+ partitioningScheme1,
+ ImmutableList.of(project1),
+ ImmutableList.of(ImmutableList.of(col2)),
+ false,
+ Optional.empty());
+
+ PartitioningScheme partitioningScheme2 = new PartitioningScheme(
+ Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
+ ImmutableList.of(col3));
+
+ ExchangeNode exchange2 = new ExchangeNode(
+ Optional.empty(),
+ p.getIdAllocator().getNextId(),
+ ExchangeNode.Type.GATHER,
+ ExchangeNode.Scope.LOCAL,
+ partitioningScheme2,
+ ImmutableList.of(values),
+ ImmutableList.of(ImmutableList.of(col3)),
+ false,
+ Optional.empty());
+
+ // Join the results
+ return p.join(
+ JoinType.INNER,
+ exchange1,
+ exchange2,
+ p.rowExpression("col2 = col3"));
+ });
+ }
+
+ @Test
+ public void testMultipleFragmentsWithExchange()
+ {
+ // Complex plan with multiple fragments properly separated (Java function - should pass)
+ validatePlan(
+ p -> {
+ VariableReferenceExpression col1 = p.variable("col1", VARCHAR);
+ VariableReferenceExpression col2 = p.variable("col2", BIGINT);
+ VariableReferenceExpression col3 = p.variable("col3", BIGINT);
+
+ // Fragment 1: System table scan
+ TableHandle systemTableHandle = new TableHandle(
+ ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")),
+ new TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.empty());
+
+ PlanNode systemScan = p.tableScan(
+ systemTableHandle,
+ ImmutableList.of(col1),
+ ImmutableMap.of(col1, new TestingColumnHandle("col1")));
+
+ // Convert to numeric for join (using Java function)
+ PlanNode project1 = p.project(
+ assignment(col2, p.rowExpression("cast(java_func(col1) as bigint)")),
+ systemScan);
+
+ // Fragment 2: Regular values with computation
+ PlanNode values = p.values(col3);
+
+ // Exchange to separate fragments
+ PartitioningScheme partitioningScheme1 = new PartitioningScheme(
+ Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
+ ImmutableList.of(col2));
+
+ ExchangeNode exchange1 = new ExchangeNode(
+ Optional.empty(),
+ p.getIdAllocator().getNextId(),
+ ExchangeNode.Type.GATHER,
+ ExchangeNode.Scope.LOCAL,
+ partitioningScheme1,
+ ImmutableList.of(project1),
+ ImmutableList.of(ImmutableList.of(col2)),
+ false,
+ Optional.empty());
+
+ PartitioningScheme partitioningScheme2 = new PartitioningScheme(
+ Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
+ ImmutableList.of(col3));
+
+ ExchangeNode exchange2 = new ExchangeNode(
+ Optional.empty(),
+ p.getIdAllocator().getNextId(),
+ ExchangeNode.Type.GATHER,
+ ExchangeNode.Scope.LOCAL,
+ partitioningScheme2,
+ ImmutableList.of(values),
+ ImmutableList.of(ImmutableList.of(col3)),
+ false,
+ Optional.empty());
+
+ // Join the results
+ return p.join(
+ JoinType.INNER,
+ exchange1,
+ exchange2,
+ p.rowExpression("col2 = col3"));
+ });
+ }
+
+ @Test
+ public void testFilterAndProjectWithSystemTable()
+ {
+ // Test filter and project both with Java functions on system table
+ validatePlan(
+ p -> {
+ VariableReferenceExpression col = p.variable("col", VARCHAR);
+ VariableReferenceExpression len = p.variable("len", BIGINT);
+
+ // System table scan
+ TableHandle systemTableHandle = new TableHandle(
+ ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")),
+ new TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.empty());
+
+ PlanNode systemScan = p.tableScan(
+ systemTableHandle,
+ ImmutableList.of(col),
+ ImmutableMap.of(col, new TestingColumnHandle("col")));
+
+ // Filter with Java function
+ PlanNode filtered = p.filter(
+ p.rowExpression("java_func(col) = 'test'"),
+ systemScan);
+
+ // Project with Java function
+ return p.project(
+ assignment(len, p.rowExpression("cast(java_func(col) as bigint)")),
+ filtered);
+ });
+ }
+
+ private void validatePlan(Function planProvider)
+ {
+ Session session = testSessionBuilder()
+ .setCatalog("local")
+ .setSchema("tiny")
+ .build();
+
+ PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
+ Metadata metadata = getQueryRunner().getMetadata();
+ PlanBuilder builder = new PlanBuilder(TEST_SESSION, idAllocator, metadata);
+ PlanNode planNode = planProvider.apply(builder);
+
+ getQueryRunner().inTransaction(session, transactionSession -> {
+ new CheckNoIneligibleFunctionsInCoordinatorFragments().validate(planNode, transactionSession, metadata, WarningCollector.NOOP);
+ return null;
+ });
+ }
+}
diff --git a/presto-main/pom.xml b/presto-main/pom.xml
index 277a3126aab40..07e0f4828ceb0 100644
--- a/presto-main/pom.xml
+++ b/presto-main/pom.xml
@@ -267,13 +267,11 @@
io.projectreactor.netty
reactor-netty-core
- 1.1.29
io.projectreactor.netty
reactor-netty-http
- 1.1.29
diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp
index f67df7f111f23..e48a1aa0acc8d 100644
--- a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp
+++ b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp
@@ -70,7 +70,7 @@ void updateFromSystemConfigs(
const auto& systemConfigName = configNameEntry.second;
if (queryConfigs.count(veloxConfigName) == 0) {
const auto propertyOpt = systemConfig->optionalProperty(systemConfigName);
- if (propertyOpt.hasValue()) {
+ if (propertyOpt.has_value()) {
queryConfigs[veloxConfigName] = propertyOpt.value();
}
}
diff --git a/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp b/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp
index 027193a2f045a..b10d9962c0bb9 100644
--- a/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp
+++ b/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp
@@ -87,7 +87,7 @@ std::string requiredProperty(
const velox::config::ConfigBase& properties,
const std::string& name) {
auto value = properties.get(name);
- if (!value.hasValue()) {
+ if (!value.has_value()) {
VELOX_USER_FAIL("Missing configuration property {}", name);
}
return value.value();
@@ -120,7 +120,7 @@ std::string getOptionalProperty(
const std::string& name,
const std::string& defaultValue) {
auto value = properties.get(name);
- if (!value.hasValue()) {
+ if (!value.has_value()) {
return defaultValue;
}
return value.value();
diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp
index 893e21f224d0c..91a47abccb551 100644
--- a/presto-native-execution/presto_cpp/main/common/Configs.cpp
+++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp
@@ -110,7 +110,7 @@ folly::Optional ConfigBase::setValue(
propertyName);
auto oldValue = config_->get(propertyName);
config_->set(propertyName, value);
- if (oldValue.hasValue()) {
+ if (oldValue.has_value()) {
return oldValue;
}
return registeredProps_[propertyName];
@@ -372,7 +372,7 @@ SystemConfig::remoteFunctionServerLocation() const {
// First check if there is a UDS path registered. If there's one, use it.
auto remoteServerUdsPath =
optionalProperty(kRemoteFunctionServerThriftUdsPath);
- if (remoteServerUdsPath.hasValue()) {
+ if (remoteServerUdsPath.has_value()) {
return folly::SocketAddress::makeFromPath(remoteServerUdsPath.value());
}
@@ -382,13 +382,13 @@ SystemConfig::remoteFunctionServerLocation() const {
auto remoteServerPort =
optionalProperty(kRemoteFunctionServerThriftPort);
- if (remoteServerPort.hasValue()) {
+ if (remoteServerPort.has_value()) {
// Fallback to localhost if address is not specified.
- return remoteServerAddress.hasValue()
+ return remoteServerAddress.has_value()
? folly::
SocketAddress{remoteServerAddress.value(), remoteServerPort.value()}
: folly::SocketAddress{"::1", remoteServerPort.value()};
- } else if (remoteServerAddress.hasValue()) {
+ } else if (remoteServerAddress.has_value()) {
VELOX_FAIL(
"Remote function server port not provided using '{}'.",
kRemoteFunctionServerThriftPort);
@@ -959,7 +959,7 @@ int NodeConfig::prometheusExecutorThreads() const {
static constexpr int
kNodePrometheusExecutorThreadsDefault = 2;
auto resultOpt = optionalProperty(kNodePrometheusExecutorThreads);
- if (resultOpt.hasValue()) {
+ if (resultOpt.has_value()) {
return resultOpt.value();
}
return kNodePrometheusExecutorThreadsDefault;
@@ -967,7 +967,7 @@ int NodeConfig::prometheusExecutorThreads() const {
std::string NodeConfig::nodeId() const {
auto resultOpt = optionalProperty(kNodeId);
- if (resultOpt.hasValue()) {
+ if (resultOpt.has_value()) {
return resultOpt.value();
}
// Generate the nodeId which must be a UUID. nodeId must be a singleton.
@@ -985,7 +985,7 @@ std::string NodeConfig::nodeInternalAddress(
auto resultOpt = optionalProperty(kNodeInternalAddress);
/// node.ip(kNodeIp) is legacy config replaced with node.internal-address, but
/// still valid config in Presto, so handling both.
- if (!resultOpt.hasValue()) {
+ if (!resultOpt.has_value()) {
resultOpt = optionalProperty(kNodeIp);
}
if (resultOpt.has_value()) {
diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h
index 6512e9e47046f..148469e0c351b 100644
--- a/presto-native-execution/presto_cpp/main/common/Configs.h
+++ b/presto-native-execution/presto_cpp/main/common/Configs.h
@@ -95,7 +95,7 @@ class ConfigBase {
template
folly::Optional optionalProperty(const std::string& propertyName) const {
auto valOpt = config_->get(propertyName);
- if (valOpt.hasValue()) {
+ if (valOpt.has_value()) {
return valOpt.value();
}
const auto it = registeredProps_.find(propertyName);
@@ -115,7 +115,7 @@ class ConfigBase {
folly::Optional optionalProperty(
const std::string& propertyName) const {
auto val = config_->get(propertyName);
- if (val.hasValue()) {
+ if (val.has_value()) {
return val;
}
const auto it = registeredProps_.find(propertyName);
diff --git a/presto-native-execution/scripts/dockerfiles/centos-dependency.dockerfile b/presto-native-execution/scripts/dockerfiles/centos-dependency.dockerfile
index 207270d69f970..e67770e78a0b8 100644
--- a/presto-native-execution/scripts/dockerfiles/centos-dependency.dockerfile
+++ b/presto-native-execution/scripts/dockerfiles/centos-dependency.dockerfile
@@ -25,9 +25,10 @@ COPY velox/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch /vel
ENV VELOX_ARROW_CMAKE_PATCH=/velox/cmake-compatibility.patch
RUN bash -c "mkdir build && \
(cd build && ../scripts/setup-centos.sh && \
- ../velox/scripts/setup-centos9.sh install_adapters && \
../scripts/setup-adapters.sh && \
source ../velox/scripts/setup-centos9.sh && \
+ source ../velox/scripts/setup-centos-adapters.sh && \
+ install_adapters && \
install_clang15 && \
install_cuda 12.8) && \
rm -rf build"
diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java
index 2e25e413e3b21..31bd66b36092c 100644
--- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java
+++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java
@@ -318,9 +318,8 @@ public void testApproxPercentile()
@Test
public void testInformationSchemaTables()
{
- assertQueryFails("select lower(table_name) from information_schema.tables "
- + "where table_name = 'lineitem' or table_name = 'LINEITEM' ",
- "Compiler failed");
+ assertQuery("select lower(table_name) from information_schema.tables "
+ + "where table_name = 'lineitem' or table_name = 'LINEITEM' ");
}
@Test
diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/SystemConnectorTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/SystemConnectorTests.java
index fa48afa418524..3317d07e90f8c 100644
--- a/presto-product-tests/src/main/java/com/facebook/presto/tests/SystemConnectorTests.java
+++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/SystemConnectorTests.java
@@ -13,14 +13,20 @@
*/
package com.facebook.presto.tests;
+import com.google.common.collect.ImmutableList;
import io.prestodb.tempto.ProductTest;
+import io.prestodb.tempto.assertions.QueryAssert;
import org.testng.annotations.Test;
import java.sql.JDBCType;
+import java.util.List;
+import java.util.stream.Collectors;
import static com.facebook.presto.tests.TestGroups.JDBC;
import static com.facebook.presto.tests.TestGroups.SYSTEM_CONNECTOR;
import static com.facebook.presto.tests.utils.JdbcDriverUtils.usingTeradataJdbcDriver;
+import static com.facebook.presto.tests.utils.QueryExecutors.onPresto;
+import static io.prestodb.tempto.assertions.QueryAssert.Row.row;
import static io.prestodb.tempto.assertions.QueryAssert.assertThat;
import static io.prestodb.tempto.query.QueryExecutor.defaultQueryExecutor;
import static io.prestodb.tempto.query.QueryExecutor.query;
@@ -107,4 +113,41 @@ public void selectMetadataCatalogs()
.hasColumns(VARCHAR, VARCHAR)
.hasAnyRows();
}
+
+ @Test(groups = SYSTEM_CONNECTOR)
+ public void selectJdbcColumns()
+ {
+ try {
+ String hiveSQL = "select table_name, column_name from system.jdbc.columns where table_cat = 'hive' AND table_schem = 'default'";
+ String icebergSQL = "select table_name, column_name from system.jdbc.columns where table_cat = 'iceberg' AND table_schem = 'default'";
+
+ List preexistingHiveColumns = onPresto().executeQuery(hiveSQL).rows().stream()
+ .map(list -> row(list.toArray()))
+ .collect(Collectors.toList());
+
+ List preexistingIcebergColumns = onPresto().executeQuery(icebergSQL).rows().stream()
+ .map(list -> row(list.toArray()))
+ .collect(Collectors.toList());
+
+ onPresto().executeQuery("CREATE TABLE hive.default.test_hive_system_jdbc_columns (_double DOUBLE)");
+ onPresto().executeQuery("CREATE TABLE iceberg.default.test_iceberg_system_jdbc_columns (_string VARCHAR, _integer INTEGER)");
+
+ assertThat(onPresto().executeQuery(hiveSQL))
+ .containsOnly(ImmutableList.builder()
+ .addAll(preexistingHiveColumns)
+ .add(row("test_hive_system_jdbc_columns", "_double"))
+ .build());
+
+ assertThat(onPresto().executeQuery(icebergSQL))
+ .containsOnly(ImmutableList.builder()
+ .addAll(preexistingIcebergColumns)
+ .add(row("test_iceberg_system_jdbc_columns", "_string"))
+ .add(row("test_iceberg_system_jdbc_columns", "_integer"))
+ .build());
+ }
+ finally {
+ onPresto().executeQuery("DROP TABLE IF EXISTS hive.default.test_hive_system_jdbc_columns");
+ onPresto().executeQuery("DROP TABLE IF EXISTS iceberg.default.test_iceberg_system_jdbc_columns");
+ }
+ }
}
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java
index 2da8ad1970eec..0164dd10c3af1 100644
--- a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java
@@ -153,4 +153,9 @@ default Iterable getClientRequestFilterFactories()
{
return emptyList();
}
+
+ default Set> getSqlInvokedFunctions()
+ {
+ return emptySet();
+ }
}
diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestDuplicateSqlInvokedFunctions.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestDuplicateSqlInvokedFunctions.java
new file mode 100644
index 0000000000000..e3df08e051039
--- /dev/null
+++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestDuplicateSqlInvokedFunctions.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.functions;
+
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
+import com.facebook.presto.spi.function.SqlParameter;
+import com.facebook.presto.spi.function.SqlType;
+import com.facebook.presto.spi.function.TypeParameter;
+
+public final class TestDuplicateSqlInvokedFunctions
+{
+ private TestDuplicateSqlInvokedFunctions() {}
+
+ @SqlInvokedScalarFunction(value = "array_intersect", deterministic = true, calledOnNullInput = false)
+ @Description("Intersects elements of all arrays in the given array")
+ @TypeParameter("T")
+ @SqlParameter(name = "input", type = "array>")
+ @SqlType("array")
+ public static String arrayIntersectArray()
+ {
+ return "RETURN reduce(input, IF((cardinality(input) = 0), ARRAY[], input[1]), (s, x) -> array_intersect(s, x), (s) -> s)";
+ }
+}
diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestFunctions.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestFunctions.java
new file mode 100644
index 0000000000000..532c12acbc65a
--- /dev/null
+++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestFunctions.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.functions;
+
+import com.facebook.presto.common.type.StandardTypes;
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.ScalarFunction;
+import com.facebook.presto.spi.function.SqlType;
+import io.airlift.slice.Slice;
+
+public final class TestFunctions
+{
+ private TestFunctions()
+ {}
+
+ @Description("Returns modulo of value by numberOfBuckets")
+ @ScalarFunction
+ @SqlType(StandardTypes.BIGINT)
+ public static long modulo(
+ @SqlType(StandardTypes.BIGINT) long value,
+ @SqlType(StandardTypes.BIGINT) long numberOfBuckets)
+ {
+ return value % numberOfBuckets;
+ }
+
+ @Description(("Return the input string"))
+ @ScalarFunction
+ @SqlType(StandardTypes.VARCHAR)
+ public static Slice identity(@SqlType(StandardTypes.VARCHAR) Slice slice)
+ {
+ return slice;
+ }
+}
diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedDuplicateSqlInvokedFunctions.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedDuplicateSqlInvokedFunctions.java
new file mode 100644
index 0000000000000..4ab7e80a82906
--- /dev/null
+++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedDuplicateSqlInvokedFunctions.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.functions;
+
+import com.facebook.presto.common.type.TimeZoneKey;
+import com.facebook.presto.server.testing.TestingPrestoServer;
+import com.facebook.presto.spi.Plugin;
+import com.facebook.presto.tests.TestingPrestoClient;
+import com.google.common.collect.ImmutableSet;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.util.Set;
+import java.util.regex.Pattern;
+
+import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+import static java.lang.String.format;
+import static org.testng.Assert.fail;
+
+public class TestPluginLoadedDuplicateSqlInvokedFunctions
+{
+ protected TestingPrestoServer server;
+ protected TestingPrestoClient client;
+
+ @BeforeClass
+ public void setup()
+ throws Exception
+ {
+ server = new TestingPrestoServer();
+ server.installPlugin(new TestDuplicateFunctionsPlugin());
+ client = new TestingPrestoClient(server, testSessionBuilder()
+ .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("America/Bahia_Banderas"))
+ .build());
+ }
+
+ public void assertInvalidFunction(String expr, String exceptionPattern)
+ {
+ try {
+ client.execute("SELECT " + expr);
+ fail("Function expected to fail but not");
+ }
+ catch (Exception e) {
+ if (!(e.getMessage().matches(exceptionPattern))) {
+ fail(format("Expected exception message '%s' to match '%s' but not",
+ e.getMessage(), exceptionPattern));
+ }
+ }
+ }
+
+ private static class TestDuplicateFunctionsPlugin
+ implements Plugin
+ {
+ @Override
+ public Set> getSqlInvokedFunctions()
+ {
+ return ImmutableSet.>builder()
+ .add(TestDuplicateSqlInvokedFunctions.class)
+ .build();
+ }
+ }
+
+ @Test
+ public void testDuplicateFunctionsLoaded()
+ {
+ assertInvalidFunction(JAVA_BUILTIN_NAMESPACE + ".modulo(10,3)",
+ Pattern.quote(format("java.lang.IllegalArgumentException: Function already registered: %s.array_intersect(array(array(T))):array(T)", JAVA_BUILTIN_NAMESPACE)));
+ }
+}
diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedSqlInvokedFunctions.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedSqlInvokedFunctions.java
new file mode 100644
index 0000000000000..233d04e737603
--- /dev/null
+++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedSqlInvokedFunctions.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.functions;
+
+import com.facebook.presto.common.type.TimeZoneKey;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.server.testing.TestingPrestoServer;
+import com.facebook.presto.spi.Plugin;
+import com.facebook.presto.testing.MaterializedResult;
+import com.facebook.presto.tests.TestingPrestoClient;
+import com.google.common.collect.ImmutableSet;
+import org.intellij.lang.annotations.Language;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.util.Set;
+
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.IntegerType.INTEGER;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+import static java.lang.String.format;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertThrows;
+import static org.testng.Assert.fail;
+
+public class TestPluginLoadedSqlInvokedFunctions
+{
+ protected TestingPrestoServer server;
+ protected TestingPrestoClient client;
+
+ private static final String CATALOG_NAME = JAVA_BUILTIN_NAMESPACE.getCatalogName();
+
+ @BeforeClass
+ public void setup()
+ throws Exception
+ {
+ server = new TestingPrestoServer();
+ server.installPlugin(new TestFunctionsPlugin());
+ client = new TestingPrestoClient(server, testSessionBuilder()
+ .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("America/Bahia_Banderas"))
+ .build());
+ }
+
+ public void assertInvalidFunction(String expr, String exceptionPattern)
+ {
+ try {
+ client.execute("SELECT " + expr);
+ fail("Function expected to fail but not");
+ }
+ catch (Exception e) {
+ if (!(e.getMessage().matches(exceptionPattern))) {
+ fail(format("Expected exception message '%s' to match '%s' but not",
+ e.getMessage(), exceptionPattern));
+ }
+ }
+ }
+
+ private static class TestFunctionsPlugin
+ implements Plugin
+ {
+ @Override
+ public Set> getSqlInvokedFunctions()
+ {
+ return ImmutableSet.>builder()
+ .add(TestSqlInvokedFunctionsPlugin.class)
+ .build();
+ }
+
+ @Override
+ public Set> getFunctions()
+ {
+ return ImmutableSet.>builder()
+ .add(TestFunctions.class)
+ .build();
+ }
+ }
+
+ public void check(@Language("SQL") String query, Type expectedType, Object expectedValue)
+ {
+ MaterializedResult result = client.execute(query).getResult();
+ assertEquals(result.getRowCount(), 1);
+ assertEquals(result.getTypes().get(0), expectedType);
+ Object actual = result.getMaterializedRows().get(0).getField(0);
+ assertEquals(actual, expectedValue);
+ }
+
+ @Test
+ public void testNewFunctionNamespaceFunction()
+ {
+ check("SELECT " + JAVA_BUILTIN_NAMESPACE + ".modulo(10,3)", BIGINT, 1L);
+ check("SELECT " + JAVA_BUILTIN_NAMESPACE + ".identity('test-functions')", VARCHAR, "test-functions");
+ check("SELECT " + JAVA_BUILTIN_NAMESPACE + ".custom_square(2, 3)", INTEGER, 4);
+ check("SELECT " + JAVA_BUILTIN_NAMESPACE + ".custom_square(null, 3)", INTEGER, 9);
+ }
+
+ @Test
+ public void testInvalidFunctionAndNamespace()
+ {
+ assertInvalidFunction(CATALOG_NAME + ".namespace.modulo(10,3)", format("line 1:8: Function %s.namespace.modulo not registered", CATALOG_NAME));
+ assertInvalidFunction(CATALOG_NAME + ".system.some_func(10)", format("line 1:8: Function %s.system.some_func not registered", CATALOG_NAME));
+ }
+
+ @Test(dependsOnMethods =
+ {"testNewFunctionNamespaceFunction",
+ "testInvalidFunctionAndNamespace"})
+ public void testDuplicateFunctionsLoaded()
+ {
+ // Because we trigger the conflict check as soon as the plugins are loaded,
+ // this will throw an Exception: Function already registered: presto.default.modulo(bigint,bigint):bigint while installing the plugin itself
+ assertThrows(IllegalArgumentException.class, () -> server.installPlugin(new TestFunctionsPlugin()));
+ }
+}
diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestSqlInvokedFunctionsPlugin.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestSqlInvokedFunctionsPlugin.java
new file mode 100644
index 0000000000000..65ec6162758f8
--- /dev/null
+++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestSqlInvokedFunctionsPlugin.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.functions;
+
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
+import com.facebook.presto.spi.function.SqlParameter;
+import com.facebook.presto.spi.function.SqlParameters;
+import com.facebook.presto.spi.function.SqlType;
+
+public final class TestSqlInvokedFunctionsPlugin
+{
+ private TestSqlInvokedFunctionsPlugin()
+ {}
+
+ @SqlInvokedScalarFunction(value = "custom_square", deterministic = true, calledOnNullInput = false)
+ @Description("Custom SQL to test NULLIF in Functions")
+ @SqlParameters({@SqlParameter(name = "x", type = "integer"), @SqlParameter(name = "y", type = "integer")})
+ @SqlType("integer")
+ public static String customSquare()
+ {
+ return "RETURN IF(NULLIF(x, y) IS NOT NULL, x * x, y * y)";
+ }
+}