Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ jobs:
# Use different Maven options to install.
MAVEN_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError"
run: |
for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-execution' -am && s=0 && break || s=$? && sleep 10; done; (exit $s)
for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-sidecar-plugin' -am && s=0 && break || s=$? && sleep 10; done; (exit $s)
- name: Run presto-native sidecar tests
if: |
Expand Down
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@
<module>presto-router-example-plugin-scheduler</module>
<module>presto-plan-checker-router-plugin</module>
<module>presto-sql-invoked-functions-plugin</module>
<module>presto-native-sql-invoked-functions-plugin</module>
</modules>

<dependencyManagement>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.Varchars;
Expand Down Expand Up @@ -55,7 +54,6 @@
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
import static com.facebook.presto.metadata.CastType.CAST;
import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND;
import static com.facebook.presto.spi.StandardWarningCode.SAMPLED_FIELDS;
Expand Down Expand Up @@ -150,7 +148,7 @@ private PlanNode addSamplingFilter(PlanNode tableScanNode, Optional<VariableRefe
try {
sampledArg = call(
functionAndTypeManager,
QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, getKeyBasedSamplingFunction(session)),
getKeyBasedSamplingFunction(session),
DOUBLE,
ImmutableList.of(arg));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package com.facebook.presto.sql.relational;

import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.CastType;
Expand Down Expand Up @@ -154,12 +153,6 @@ public static CallExpression call(FunctionAndTypeManager functionAndTypeManager,
return call(name, functionHandle, returnType, arguments);
}

public static CallExpression call(FunctionAndTypeManager functionAndTypeManager, QualifiedObjectName qualifiedObjectName, Type returnType, List<RowExpression> arguments)
{
FunctionHandle functionHandle = functionAndTypeManager.lookupFunction(qualifiedObjectName, fromTypes(arguments.stream().map(RowExpression::getType).collect(toImmutableList())));
return call(String.valueOf(qualifiedObjectName), functionHandle, returnType, arguments);
}

public static CallExpression call(FunctionAndTypeResolver functionAndTypeResolver, String name, Type returnType, RowExpression... arguments)
{
FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(name, fromTypes(Arrays.stream(arguments).map(RowExpression::getType).collect(toImmutableList())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private NativeQueryRunnerUtils() {}
public static Map<String, String> getNativeWorkerHiveProperties()
{
return ImmutableMap.of("hive.parquet.pushdown-filter-enabled", "true",
"hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF");
"hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF");
}

public static Map<String, String> getNativeWorkerIcebergProperties()
Expand Down Expand Up @@ -59,6 +59,8 @@ public static Map<String, String> getNativeSidecarProperties()
.put("coordinator-sidecar-enabled", "true")
.put("exclude-invalid-worker-session-properties", "true")
.put("presto.default-namespace", "native.default")
// inline-sql-functions is overridden to be true in sidecar enabled native clusters.
.put("inline-sql-functions", "true")
.build();
}

Expand Down
16 changes: 16 additions & 0 deletions presto-native-sidecar-plugin/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,25 @@
</exclusion>
</exclusions>
</dependency>

<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-built-in-worker-function-tools</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-native-sql-invoked-functions-plugin</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-sql-invoked-functions-plugin</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.presto.sidecar;

import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory;
import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory;
import com.facebook.presto.sidecar.typemanager.NativeTypeManagerFactory;
Expand All @@ -37,5 +38,6 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
"function-implementation-type", "CPP"));
queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME);
queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of());
queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import com.facebook.airlift.units.DataSize;
import com.facebook.presto.Session;
import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils;
import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin;
import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin;
import com.facebook.presto.sidecar.functionNamespace.FunctionDefinitionProvider;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionDefinitionProvider;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManager;
Expand Down Expand Up @@ -45,9 +47,12 @@
import java.util.stream.Collectors;

import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE;
import static com.facebook.presto.SystemSessionProperties.INLINE_SQL_FUNCTIONS;
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED;
import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST;
import static com.facebook.presto.common.Utils.checkArgument;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createCustomer;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders;
Expand All @@ -65,6 +70,7 @@ public class TestNativeSidecarPlugin
private static final String REGEX_FUNCTION_NAMESPACE = "native.default.*";
private static final String REGEX_SESSION_NAMESPACE = "Native Execution only.*";
private static final long SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB = 128;
private static final int INLINED_SQL_FUNCTIONS_COUNT = 7;

@Override
protected void createTables()
Expand All @@ -75,6 +81,7 @@ protected void createTables()
createOrders(queryRunner);
createOrdersEx(queryRunner);
createRegion(queryRunner);
createCustomer(queryRunner);
}

@Override
Expand All @@ -93,9 +100,11 @@ protected QueryRunner createQueryRunner()
protected QueryRunner createExpectedQueryRunner()
throws Exception
{
return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder()
QueryRunner queryRunner = PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder()
.setAddStorageFormatToPath(true)
.build();
queryRunner.installPlugin(new SqlInvokedFunctionsPlugin());
return queryRunner;
}

public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
Expand All @@ -113,6 +122,7 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
"sidecar.http-client.max-content-length", SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB + "MB"));
queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME);
queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of());
queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin());
}

@Test
Expand Down Expand Up @@ -163,6 +173,7 @@ public void testSetNativeWorkerSessionProperty()
@Test
public void testShowFunctions()
{
int inlinedSQLFunctionsCount = 0;
@Language("SQL") String sql = "SHOW FUNCTIONS";
MaterializedResult actualResult = computeActual(sql);
List<MaterializedRow> actualRows = actualResult.getMaterializedRows();
Expand All @@ -176,11 +187,17 @@ public void testShowFunctions()

// function namespace should be present.
String fullFunctionName = row.get(5).toString();
if (Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) {
continue;
if (!Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) {
// If no namespace match found, check if it's an inlined SQL Invoked function.
String language = row.get(9).toString();
if (language.equalsIgnoreCase("SQL")) {
inlinedSQLFunctionsCount++;
continue;
}
fail(format("No namespace match found for row: %s", row));
}
fail(format("No namespace match found for row: %s", row));
}
assertEquals(inlinedSQLFunctionsCount, INLINED_SQL_FUNCTIONS_COUNT);
}

@Test
Expand Down Expand Up @@ -321,7 +338,7 @@ public void testApproxPercentile()
public void testInformationSchemaTables()
{
assertQuery("select lower(table_name) from information_schema.tables "
+ "where table_name = 'lineitem' or table_name = 'LINEITEM' ");
+ "where table_name = 'lineitem' or table_name = 'LINEITEM' ");
}

@Test
Expand Down Expand Up @@ -423,6 +440,105 @@ public void testRemoveMapCast()
"values 0.5, 0.1");
}

@Test
public void testOverriddenInlinedSqlInvokedFunctions()
{
// String functions
assertQuery("SELECT trail(comment, cast(nationkey as integer)) FROM nation");
assertQuery("SELECT name, comment, replace_first(comment, 'iron', 'gold') from nation");

// Array functions
assertQuery("SELECT array_intersect(ARRAY['apple', 'banana', 'cherry'], ARRAY['apple', 'mango', 'fig'])");
assertQuery("SELECT array_frequency(split(comment, '')) from nation");
assertQuery("SELECT array_duplicates(ARRAY[regionkey]), array_duplicates(ARRAY[comment]) from nation");
assertQuery("SELECT array_has_duplicates(ARRAY[custkey]) from orders");
assertQuery("SELECT array_max_by(ARRAY[comment], x -> length(x)) from orders");
assertQuery("SELECT array_min_by(ARRAY[ROW('USA', 1), ROW('INDIA', 2), ROW('UK', 3)], x -> x[2])");
assertQuery("SELECT array_sort_desc(map_keys(map_union(quantity_by_linenumber))) FROM orders_ex");
assertQuery("SELECT remove_nulls(ARRAY[CAST(regionkey AS VARCHAR), comment, NULL]) from nation");
assertQuery("SELECT array_top_n(ARRAY[CAST(nationkey AS VARCHAR)], 3) from nation");
assertQuerySucceeds("SELECT array_sort_desc(quantities, x -> abs(x)) FROM orders_ex");

// Map functions
assertQuery("SELECT map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 4, 5]))");
assertQuery("SELECT map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 0, -1]))");
assertQuery("SELECT name, map_normalize(MAP(ARRAY['regionkey', 'length'], ARRAY[regionkey, length(comment)])) from nation");
assertQuery("SELECT name, map_remove_null_values(map(ARRAY['region', 'comment', 'nullable'], " +
"ARRAY[CAST(regionkey AS VARCHAR), comment, NULL])) from nation");
assertQuery("SELECT name, map_key_exists(map(ARRAY['nation', 'comment'], ARRAY[CAST(nationkey AS VARCHAR), comment]), 'comment') from nation");
assertQuery("SELECT map_keys_by_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 2) from orders");
assertQuery("SELECT map_top_n(MAP(ARRAY[CAST(nationkey AS VARCHAR)], ARRAY[comment]), 3) from nation");
assertQuery("SELECT map_top_n_keys(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders");
assertQuery("SELECT map_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders");
assertQuery("SELECT all_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> length(k) > 5) from orders");
assertQuery("SELECT any_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> starts_with(k, 'abc')) from orders");
assertQuery("SELECT any_values_match(MAP(ARRAY[orderkey], ARRAY[totalprice]), k -> abs(k) > 20) from orders");
assertQuery("SELECT no_values_match(MAP(ARRAY[orderkey], ARRAY[comment]), k -> length(k) > 2) from orders");
assertQuery("SELECT no_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> ends_with(k, 'a')) from orders");
}

@Test
public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigEnabled()
{
// Array functions
assertQuery("SELECT array_split_into_chunks(split(comment, ''), 2) from nation");
assertQuery("SELECT array_least_frequent(quantities) from orders_ex");
assertQuery("SELECT array_least_frequent(split(comment, ''), 5) from nation");
assertQuerySucceeds("SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from orders");

// Map functions
assertQuerySucceeds("SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation");
assertQuerySucceeds("SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation");

Session sessionWithKeyBasedSampling = Session.builder(getSession())
.setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true")
.build();

@Language("SQL") String query = "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey";

assertQuery(query, "select cast(60175 as bigint)");
assertQuery(sessionWithKeyBasedSampling, query, "select cast(16185 as bigint)");
}

@Test
public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigDisabled()
{
// When inline_sql_functions is set to false, the below queries should fail as the implementations don't exist on the native worker
Session session = Session.builder(getSession())
.setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true")
.setSystemProperty(INLINE_SQL_FUNCTIONS, "false")
.build();

// Array functions
assertQueryFails(session,
"SELECT array_split_into_chunks(split(comment, ''), 2) from nation",
".*Scalar function name not registered: native.default.array_split_into_chunks.*");
assertQueryFails(session,
"SELECT array_least_frequent(quantities) from orders_ex",
".*Scalar function name not registered: native.default.array_least_frequent.*");
assertQueryFails(session,
"SELECT array_least_frequent(split(comment, ''), 2) from nation",
".*Scalar function name not registered: native.default.array_least_frequent.*");
assertQueryFails(session,
"SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from orders",
" Scalar function native\\.default\\.array_top_n not registered with arguments.*",
true);

// Map functions
assertQueryFails(session,
"SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation",
".*Scalar function native\\.default\\.map_top_n_values not registered with arguments.*",
true);
assertQueryFails(session,
"SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation",
".*Scalar function native\\.default\\.map_top_n_keys not registered with arguments.*",
true);

assertQueryFails(session,
"select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey",
".*Scalar function name not registered: native.default.key_sampling_percent.*");
}

private String generateRandomTableName()
{
String tableName = "tmp_presto_" + UUID.randomUUID().toString().replace("-", "");
Expand Down
29 changes: 29 additions & 0 deletions presto-native-sql-invoked-functions-plugin/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-root</artifactId>
<version>0.295-SNAPSHOT</version>
</parent>

<artifactId>presto-native-sql-invoked-functions-plugin</artifactId>
<description>Presto Native - Sql invoked functions plugin</description>
<packaging>presto-plugin</packaging>

<properties>
<air.main.basedir>${project.parent.basedir}</air.main.basedir>
</properties>

<dependencies>
<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-spi</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
</dependencies>
</project>
Loading
Loading