diff --git a/.github/workflows/prestocpp-linux-build-and-unit-test.yml b/.github/workflows/prestocpp-linux-build-and-unit-test.yml index b28f2292252ea..f3fc4843c3321 100644 --- a/.github/workflows/prestocpp-linux-build-and-unit-test.yml +++ b/.github/workflows/prestocpp-linux-build-and-unit-test.yml @@ -370,7 +370,7 @@ jobs: # Use different Maven options to install. MAVEN_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" run: | - for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-execution' -am && s=0 && break || s=$? && sleep 10; done; (exit $s) + for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-sidecar-plugin' -am && s=0 && break || s=$? && sleep 10; done; (exit $s) - name: Run presto-native sidecar tests if: | diff --git a/pom.xml b/pom.xml index d5ac80313ec19..f73719dacd94a 100644 --- a/pom.xml +++ b/pom.xml @@ -224,6 +224,7 @@ presto-router-example-plugin-scheduler presto-plan-checker-router-plugin presto-sql-invoked-functions-plugin + presto-native-sql-invoked-functions-plugin diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java index f34359aeba6af..1c7856a62caea 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.Varchars; @@ -55,7 +54,6 @@ import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.metadata.CastType.CAST; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.spi.StandardWarningCode.SAMPLED_FIELDS; @@ -150,7 +148,7 @@ private PlanNode addSamplingFilter(PlanNode tableScanNode, Optional arguments) - { - FunctionHandle functionHandle = functionAndTypeManager.lookupFunction(qualifiedObjectName, fromTypes(arguments.stream().map(RowExpression::getType).collect(toImmutableList()))); - return call(String.valueOf(qualifiedObjectName), functionHandle, returnType, arguments); - } - public static CallExpression call(FunctionAndTypeResolver functionAndTypeResolver, String name, Type returnType, RowExpression... arguments) { FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(name, fromTypes(Arrays.stream(arguments).map(RowExpression::getType).collect(toImmutableList()))); diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java index 3650f8cb36531..59a60b0024842 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java @@ -29,7 +29,7 @@ private NativeQueryRunnerUtils() {} public static Map getNativeWorkerHiveProperties() { return ImmutableMap.of("hive.parquet.pushdown-filter-enabled", "true", - "hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF"); + "hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF"); } public static Map getNativeWorkerIcebergProperties() @@ -59,6 +59,8 @@ public static Map getNativeSidecarProperties() .put("coordinator-sidecar-enabled", "true") .put("exclude-invalid-worker-session-properties", "true") .put("presto.default-namespace", "native.default") + // inline-sql-functions is overridden to be true in sidecar enabled native clusters. + .put("inline-sql-functions", "true") .build(); } diff --git a/presto-native-sidecar-plugin/pom.xml b/presto-native-sidecar-plugin/pom.xml index b2bc40e8f2bd4..d04424440469a 100644 --- a/presto-native-sidecar-plugin/pom.xml +++ b/presto-native-sidecar-plugin/pom.xml @@ -260,9 +260,25 @@ + com.facebook.presto presto-built-in-worker-function-tools + ${project.version} + + + + com.facebook.presto + presto-native-sql-invoked-functions-plugin + ${project.version} + test + + + + com.facebook.presto + presto-sql-invoked-functions-plugin + ${project.version} + test diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java index 776d4920e2f16..c8c7e1123f974 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sidecar; +import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory; import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory; import com.facebook.presto.sidecar.typemanager.NativeTypeManagerFactory; @@ -37,5 +38,6 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner) "function-implementation-type", "CPP")); queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME); queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of()); + queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin()); } } diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java index b4716f98362cf..fe0b18c24b2fb 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java @@ -16,6 +16,8 @@ import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; +import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.sidecar.functionNamespace.FunctionDefinitionProvider; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionDefinitionProvider; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManager; @@ -45,9 +47,12 @@ import java.util.stream.Collectors; import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.presto.SystemSessionProperties.INLINE_SQL_FUNCTIONS; +import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED; import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST; import static com.facebook.presto.common.Utils.checkArgument; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createCustomer; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders; @@ -65,6 +70,7 @@ public class TestNativeSidecarPlugin private static final String REGEX_FUNCTION_NAMESPACE = "native.default.*"; private static final String REGEX_SESSION_NAMESPACE = "Native Execution only.*"; private static final long SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB = 128; + private static final int INLINED_SQL_FUNCTIONS_COUNT = 7; @Override protected void createTables() @@ -75,6 +81,7 @@ protected void createTables() createOrders(queryRunner); createOrdersEx(queryRunner); createRegion(queryRunner); + createCustomer(queryRunner); } @Override @@ -93,9 +100,11 @@ protected QueryRunner createQueryRunner() protected QueryRunner createExpectedQueryRunner() throws Exception { - return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder() + QueryRunner queryRunner = PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder() .setAddStorageFormatToPath(true) .build(); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } public static void setupNativeSidecarPlugin(QueryRunner queryRunner) @@ -113,6 +122,7 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner) "sidecar.http-client.max-content-length", SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB + "MB")); queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME); queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of()); + queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin()); } @Test @@ -163,6 +173,7 @@ public void testSetNativeWorkerSessionProperty() @Test public void testShowFunctions() { + int inlinedSQLFunctionsCount = 0; @Language("SQL") String sql = "SHOW FUNCTIONS"; MaterializedResult actualResult = computeActual(sql); List actualRows = actualResult.getMaterializedRows(); @@ -176,11 +187,17 @@ public void testShowFunctions() // function namespace should be present. String fullFunctionName = row.get(5).toString(); - if (Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) { - continue; + if (!Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) { + // If no namespace match found, check if it's an inlined SQL Invoked function. + String language = row.get(9).toString(); + if (language.equalsIgnoreCase("SQL")) { + inlinedSQLFunctionsCount++; + continue; + } + fail(format("No namespace match found for row: %s", row)); } - fail(format("No namespace match found for row: %s", row)); } + assertEquals(inlinedSQLFunctionsCount, INLINED_SQL_FUNCTIONS_COUNT); } @Test @@ -321,7 +338,7 @@ public void testApproxPercentile() public void testInformationSchemaTables() { assertQuery("select lower(table_name) from information_schema.tables " - + "where table_name = 'lineitem' or table_name = 'LINEITEM' "); + + "where table_name = 'lineitem' or table_name = 'LINEITEM' "); } @Test @@ -423,6 +440,105 @@ public void testRemoveMapCast() "values 0.5, 0.1"); } + @Test + public void testOverriddenInlinedSqlInvokedFunctions() + { + // String functions + assertQuery("SELECT trail(comment, cast(nationkey as integer)) FROM nation"); + assertQuery("SELECT name, comment, replace_first(comment, 'iron', 'gold') from nation"); + + // Array functions + assertQuery("SELECT array_intersect(ARRAY['apple', 'banana', 'cherry'], ARRAY['apple', 'mango', 'fig'])"); + assertQuery("SELECT array_frequency(split(comment, '')) from nation"); + assertQuery("SELECT array_duplicates(ARRAY[regionkey]), array_duplicates(ARRAY[comment]) from nation"); + assertQuery("SELECT array_has_duplicates(ARRAY[custkey]) from orders"); + assertQuery("SELECT array_max_by(ARRAY[comment], x -> length(x)) from orders"); + assertQuery("SELECT array_min_by(ARRAY[ROW('USA', 1), ROW('INDIA', 2), ROW('UK', 3)], x -> x[2])"); + assertQuery("SELECT array_sort_desc(map_keys(map_union(quantity_by_linenumber))) FROM orders_ex"); + assertQuery("SELECT remove_nulls(ARRAY[CAST(regionkey AS VARCHAR), comment, NULL]) from nation"); + assertQuery("SELECT array_top_n(ARRAY[CAST(nationkey AS VARCHAR)], 3) from nation"); + assertQuerySucceeds("SELECT array_sort_desc(quantities, x -> abs(x)) FROM orders_ex"); + + // Map functions + assertQuery("SELECT map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 4, 5]))"); + assertQuery("SELECT map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 0, -1]))"); + assertQuery("SELECT name, map_normalize(MAP(ARRAY['regionkey', 'length'], ARRAY[regionkey, length(comment)])) from nation"); + assertQuery("SELECT name, map_remove_null_values(map(ARRAY['region', 'comment', 'nullable'], " + + "ARRAY[CAST(regionkey AS VARCHAR), comment, NULL])) from nation"); + assertQuery("SELECT name, map_key_exists(map(ARRAY['nation', 'comment'], ARRAY[CAST(nationkey AS VARCHAR), comment]), 'comment') from nation"); + assertQuery("SELECT map_keys_by_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 2) from orders"); + assertQuery("SELECT map_top_n(MAP(ARRAY[CAST(nationkey AS VARCHAR)], ARRAY[comment]), 3) from nation"); + assertQuery("SELECT map_top_n_keys(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders"); + assertQuery("SELECT map_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders"); + assertQuery("SELECT all_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> length(k) > 5) from orders"); + assertQuery("SELECT any_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> starts_with(k, 'abc')) from orders"); + assertQuery("SELECT any_values_match(MAP(ARRAY[orderkey], ARRAY[totalprice]), k -> abs(k) > 20) from orders"); + assertQuery("SELECT no_values_match(MAP(ARRAY[orderkey], ARRAY[comment]), k -> length(k) > 2) from orders"); + assertQuery("SELECT no_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> ends_with(k, 'a')) from orders"); + } + + @Test + public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigEnabled() + { + // Array functions + assertQuery("SELECT array_split_into_chunks(split(comment, ''), 2) from nation"); + assertQuery("SELECT array_least_frequent(quantities) from orders_ex"); + assertQuery("SELECT array_least_frequent(split(comment, ''), 5) from nation"); + assertQuerySucceeds("SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from orders"); + + // Map functions + assertQuerySucceeds("SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation"); + assertQuerySucceeds("SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation"); + + Session sessionWithKeyBasedSampling = Session.builder(getSession()) + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") + .build(); + + @Language("SQL") String query = "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey"; + + assertQuery(query, "select cast(60175 as bigint)"); + assertQuery(sessionWithKeyBasedSampling, query, "select cast(16185 as bigint)"); + } + + @Test + public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigDisabled() + { + // When inline_sql_functions is set to false, the below queries should fail as the implementations don't exist on the native worker + Session session = Session.builder(getSession()) + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") + .setSystemProperty(INLINE_SQL_FUNCTIONS, "false") + .build(); + + // Array functions + assertQueryFails(session, + "SELECT array_split_into_chunks(split(comment, ''), 2) from nation", + ".*Scalar function name not registered: native.default.array_split_into_chunks.*"); + assertQueryFails(session, + "SELECT array_least_frequent(quantities) from orders_ex", + ".*Scalar function name not registered: native.default.array_least_frequent.*"); + assertQueryFails(session, + "SELECT array_least_frequent(split(comment, ''), 2) from nation", + ".*Scalar function name not registered: native.default.array_least_frequent.*"); + assertQueryFails(session, + "SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from orders", + " Scalar function native\\.default\\.array_top_n not registered with arguments.*", + true); + + // Map functions + assertQueryFails(session, + "SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation", + ".*Scalar function native\\.default\\.map_top_n_values not registered with arguments.*", + true); + assertQueryFails(session, + "SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation", + ".*Scalar function native\\.default\\.map_top_n_keys not registered with arguments.*", + true); + + assertQueryFails(session, + "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey", + ".*Scalar function name not registered: native.default.key_sampling_percent.*"); + } + private String generateRandomTableName() { String tableName = "tmp_presto_" + UUID.randomUUID().toString().replace("-", ""); diff --git a/presto-native-sql-invoked-functions-plugin/pom.xml b/presto-native-sql-invoked-functions-plugin/pom.xml new file mode 100644 index 0000000000000..7d837a28a8bfb --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/pom.xml @@ -0,0 +1,29 @@ + + 4.0.0 + + com.facebook.presto + presto-root + 0.295-SNAPSHOT + + + presto-native-sql-invoked-functions-plugin + Presto Native - Sql invoked functions plugin + presto-plugin + + + ${project.parent.basedir} + + + + + com.facebook.presto + presto-spi + provided + + + com.google.guava + guava + + + diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java new file mode 100644 index 0000000000000..841883d99ae8f --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlParameters; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + +public class NativeArraySqlFunctions +{ + private NativeArraySqlFunctions() {} + + @SqlInvokedScalarFunction(value = "array_split_into_chunks", deterministic = true, calledOnNullInput = false) + @Description("Returns an array of arrays splitting input array into chunks of given length. " + + "If array is not evenly divisible it will split into as many possible chunks and " + + "return the left over elements for the last array. Returns null for null inputs, but not elements.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "sz", type = "int")}) + @SqlType("array(array(T))") + public static String arraySplitIntoChunks() + { + return "RETURN IF(sz <= 0, " + + "fail('Invalid slice size: ' || cast(sz as varchar) || '. Size must be greater than zero.'), " + + "IF(cardinality(input) / sz > 10000, " + + "fail('Cannot split array of size: ' || cast(cardinality(input) as varchar) || ' into more than 10000 parts.'), " + + "transform(" + + "sequence(1, cardinality(input), sz), " + + "x -> slice(input, x, sz))))"; + } + + @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true) + @Description("Determines the least frequent element in the array. If there are multiple elements, the function returns the smallest element") + @TypeParameter("T") + @SqlParameter(name = "input", type = "array(T)") + @SqlType("array") + public static String arrayLeastFrequent() + { + return "RETURN IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, 1), x -> x[2]))"; + } + + @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true) + @Description("Determines the n least frequent element in the array in the ascending order of the elements.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "bigint")}) + @SqlType("array") + public static String arrayNLeastFrequent() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, n), x -> x[2])))"; + } + + @SqlInvokedScalarFunction(value = "array_top_n", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N values of the given map sorted using the provided lambda comparator.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "int"), @SqlParameter(name = "f", type = "function(T, T, bigint)")}) + @SqlType("array") + public static String arrayTopNComparator() + { + return "RETURN IF(n < 0, fail('Parameter n: ' || cast(n as varchar) || ' to ARRAY_TOP_N is negative'), SLICE(REVERSE(ARRAY_SORT(input, f)), 1, n))"; + } +} diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java new file mode 100644 index 0000000000000..9eccc84d6d8c8 --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlParameters; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + +public class NativeMapSqlFunctions +{ + private NativeMapSqlFunctions() {} + + @SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N keys of the given map sorting its keys using the provided lambda comparator.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(K, K, bigint)")}) + @SqlType("array") + public static String mapTopNKeysComparator() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(map_keys(input), f)), 1, n))"; + } + + @SqlInvokedScalarFunction(value = "map_top_n_values", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N values of the given map sorted using the provided lambda comparator.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(V, V, bigint)")}) + @SqlType("array") + public static String mapTopNValuesComparator() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(remove_nulls(map_values(input)), f)) || filter(map_values(input), x -> x is null), 1, n))"; + } +} diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java new file mode 100644 index 0000000000000..a710391760714 --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlType; + +public class NativeSimpleSamplingPercent +{ + private NativeSimpleSamplingPercent() {} + + @SqlInvokedScalarFunction(value = "key_sampling_percent", deterministic = true, calledOnNullInput = false) + @Description("Returns a value between 0.0 and 1.0 using the hash of the given input string") + @SqlParameter(name = "input", type = "varchar") + @SqlType("double") + public static String keySamplingPercent() + { + return "return (abs(from_ieee754_64(xxhash64(cast(input as varbinary)))) % 100) / 100. "; + } +} diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java new file mode 100644 index 0000000000000..69d7ff1e78522 --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.Plugin; +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +public class NativeSqlInvokedFunctionsPlugin + implements Plugin +{ + @Override + public Set> getSqlInvokedFunctions() + { + return ImmutableSet.>builder() + .add(NativeArraySqlFunctions.class) + .add(NativeMapSqlFunctions.class) + .add(NativeSimpleSamplingPercent.class) + .build(); + } +} diff --git a/presto-native-tests/pom.xml b/presto-native-tests/pom.xml index e90785cbb6940..7e23670117b1d 100644 --- a/presto-native-tests/pom.xml +++ b/presto-native-tests/pom.xml @@ -192,6 +192,13 @@ units test + + + com.facebook.presto + presto-native-sql-invoked-functions-plugin + ${project.version} + test + diff --git a/presto-plan-checker-router-plugin/pom.xml b/presto-plan-checker-router-plugin/pom.xml index 60e34e4147d19..db1c4b0f1736c 100644 --- a/presto-plan-checker-router-plugin/pom.xml +++ b/presto-plan-checker-router-plugin/pom.xml @@ -223,6 +223,13 @@ presto-hive-metastore test + + + com.facebook.presto + presto-native-sql-invoked-functions-plugin + ${project.version} + test + diff --git a/presto-product-tests/conf/docker/common/compose-commons.sh b/presto-product-tests/conf/docker/common/compose-commons.sh index eae9f18ce9583..5c20783716b60 100644 --- a/presto-product-tests/conf/docker/common/compose-commons.sh +++ b/presto-product-tests/conf/docker/common/compose-commons.sh @@ -39,6 +39,16 @@ if [[ -z "${PRESTO_SERVER_DIR:-}" ]]; then source "${PRODUCT_TESTS_ROOT}/target/classes/presto.env" PRESTO_SERVER_DIR="${PROJECT_ROOT}/presto-server/target/presto-server-${PRESTO_VERSION}/" fi + +# The following plugin results in a function signature conflict when loaded in Java/ sidecar disabled native clusters. +# This plugin is only meant for sidecar enabled native clusters, hence exclude it. +PLUGIN_TO_EXCLUDE="native-sql-invoked-functions-plugin" + +if [[ -d "${PRESTO_SERVER_DIR}/plugin/${PLUGIN_TO_EXCLUDE}" ]]; then + echo "Excluding plugin: $PLUGIN_TO_EXCLUDE" + rm -rf "${PRESTO_SERVER_DIR}/plugin/${PLUGIN_TO_EXCLUDE}" +fi + export_canonical_path PRESTO_SERVER_DIR if [[ -z "${PRESTO_CLI_JAR:-}" ]]; then diff --git a/presto-server/src/main/provisio/presto.xml b/presto-server/src/main/provisio/presto.xml index b14b36a768e69..d15a041c7d1f5 100644 --- a/presto-server/src/main/provisio/presto.xml +++ b/presto-server/src/main/provisio/presto.xml @@ -292,4 +292,10 @@ + + + + + +