From 92a8ae51ae69e99904605456edd602d43a46ab35 Mon Sep 17 00:00:00 2001 From: Shrinidhi Joshi Date: Wed, 3 Sep 2025 15:18:28 -0700 Subject: [PATCH 1/6] CODEOWNERS: Expand codeownership of presto-spark owners to include presto-spark code in presto-native-execution module --- CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index 0b2831ba7ee31..e4a0e3bb56aed 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -123,6 +123,7 @@ CODEOWNERS @prestodb/team-tsc ##################################################################### # Presto on Spark module /presto-spark* @shrinidhijoshi @prestodb/committers +/presto-native-execution/**/com/facebook/presto/spark/ @shrinidhijoshi @prestodb/committers ##################################################################### # Presto connectors and plugins @@ -166,4 +167,3 @@ CODEOWNERS @prestodb/team-tsc # Presto CI and builds /.github @czentgr @unidevel @prestodb/committers /docker @czentgr @unidevel @prestodb/committers - From 9c5004f0a4ed6de3727696a431583a7294dc28ba Mon Sep 17 00:00:00 2001 From: Gary Helmling Date: Mon, 1 Sep 2025 23:09:28 -0700 Subject: [PATCH 2/6] Fix up conditional inclusion of Spark2/3 modules --- pom.xml | 111 ++++++++++++++++----- presto-spark-base/pom.xml | 36 +++++-- presto-spark-classloader-interface/pom.xml | 45 +++++---- 3 files changed, 133 insertions(+), 59 deletions(-) diff --git a/pom.xml b/pom.xml index 15a6f0ec95a20..d5ac80313ec19 100644 --- a/pom.xml +++ b/pom.xml @@ -90,8 +90,7 @@ 6.0.0 17.0.0 3.5.4 - - 2 + 2.0.2-6 + org.codehaus.plexus:plexus-utils + com.google.guava:guava + com.fasterxml.jackson.core:jackson-annotations + com.fasterxml.jackson.core:jackson-core + com.fasterxml.jackson.core:jackson-databind + + + + + + + + org.basepom.maven + duplicate-finder-maven-plugin + + true + + com.github.benmanes.caffeine.* + + META-INF.versions.9.module-info + + META-INF.versions.11.module-info + + META-INF.versions.9.org.apache.lucene.* + + + + + + + + + + 
spark3 @@ -3173,19 +3237,12 @@ - 3 + 3.4.1-1 - - - - com.facebook.presto.spark - spark-core - 3.4.1-1 - provided - - - + + presto-spark-classloader-spark3 + diff --git a/presto-spark-base/pom.xml b/presto-spark-base/pom.xml index 37dc1d4fb541b..7767219a7ec65 100644 --- a/presto-spark-base/pom.xml +++ b/presto-spark-base/pom.xml @@ -16,7 +16,6 @@ 9.4.55.v20240627 4.12.0 3.9.1 - 2 @@ -55,12 +54,6 @@ provided - - com.facebook.presto - presto-spark-classloader-spark${dep.pos.classloader.module-name.suffix} - provided - - com.facebook.presto presto-client @@ -538,6 +531,25 @@ + + spark2 + + + true + + !spark-version + + + + + + com.facebook.presto + presto-spark-classloader-spark2 + ${project.version} + + + + spark3 @@ -548,11 +560,13 @@ - - 3 - - + + com.facebook.presto + presto-spark-classloader-spark3 + ${project.version} + + com.facebook.presto.spark spark-core diff --git a/presto-spark-classloader-interface/pom.xml b/presto-spark-classloader-interface/pom.xml index 8f93be12818a4..a95e45827c5f1 100644 --- a/presto-spark-classloader-interface/pom.xml +++ b/presto-spark-classloader-interface/pom.xml @@ -13,7 +13,6 @@ ${project.parent.basedir} true - 2 @@ -23,11 +22,6 @@ provided - - com.facebook.presto - presto-spark-classloader-spark${dep.pos.classloader.module-name.suffix} - - com.google.guava guava @@ -40,6 +34,25 @@ + + spark2 + + + true + + !spark-version + + + + + + com.facebook.presto + presto-spark-classloader-spark2 + ${project.version} + + + + spark3 @@ -50,22 +63,12 @@ - - 3 - - - - - - com.facebook.presto.spark - spark-core - 3.4.1-1 - compile - - - - + + com.facebook.presto + presto-spark-classloader-spark3 + ${project.version} + org.scala-lang scala-library From 8c7f2a8487683d73f64b5c765ff073e85dee8fbf Mon Sep 17 00:00:00 2001 From: Artem Selishchev Date: Thu, 4 Sep 2025 18:28:57 -0700 Subject: [PATCH 3/6] [presto] Move out M2Y from RegressionState for regr_slope and regr_intercept functions (#25475) (#25748) Summary: ## Context Currently we don't 
enforce that intermediate/return types are the same in Coordinator and Prestissimo Worker. Velox creates vectors for intermediate/return results based on a plan that comes from Coordinator. Then Prestissimo tries to use those vectors without crashing. In practice we had a crash some time ago due to such a mismatch (D74199165). And I added validation to Velox to catch this kind of mismatch early: https://github.com/facebookincubator/velox/pull/13322 But we weren't able to enable it in prod, because the validation failed for "regr_slope" and "regr_intercept" functions. ## What's changed? In this diff I'm fixing "regr_slope" and "regr_intercept" intermediate types. Basically in Java `AggregationState` for all these functions is the same: ``` AggregationFunction("regr_slope") AggregationFunction("regr_intercept") AggregationFunction("regr_sxy") AggregationFunction("regr_sxx") AggregationFunction("regr_syy") AggregationFunction("regr_r2") AggregationFunction("regr_count") AggregationFunction("regr_avgy") AggregationFunction("regr_avgx") ``` But in Prestissimo the state storage is more optimal: ``` AggregationFunction("regr_slope") AggregationFunction("regr_intercept") ``` These 2 aggregation functions don't have the M2Y field. This is more efficient, because we don't waste memory and CPU on a field that isn't needed. So I moved M2Y to an extended class, the same way it works in Velox: https://github.com/facebookincubator/velox/blob/main/velox/functions/prestosql/aggregates/CovarianceAggregates.cpp?fbclid=IwY2xjawLRTetleHRuA2FlbQIxMQBicmlkETFiT0N3UFR0M2VKOHl6MHRhAR6KRQ1VUQdCkZXzwj14sMQrVZ-R9QBH1utuGJb5U_lyGzDwt8PwV317QRVNJg_aem_-ePxZ-fHO5MNgfUmayVJFA#L326-L337 No major changes, mostly just reorganized the code. ## Test plan I tested the `REGR_SLOPE`, `REGR_INTERCEPT` and `REGR_R2` functions since they are heavily used in prod and cover both cases: with and without the M2Y field. Here is what my test looked like: for all 3 `REGR_*` functions I found some prod queries, then: 1. 
Ran them on prev Java build 2. Ran them on new (with this PR) Java build 3. Ran them on prev Prestissimo build 4. Ran them on new (with this PR) Prestissimo build And compared the output results. They all were identical. With this manual test we covered `Coordinator -> Java Worker` and `Coordinator -> Prestissimo Worker` integrations. ## Next steps In this diff I'm trying to apply the same optimization to Java. With this fix, the signatures will become the same in Java and Prestissimo and we will be able to enable the validation Differential Revision: D77625566 == NO RELEASE NOTES == --- ...uiltInTypeAndFunctionNamespaceManager.java | 4 + .../aggregation/AggregationUtils.java | 25 ++- .../DoubleRegressionAggregation.java | 103 ------------ .../DoubleRegressionExtendedAggregation.java | 149 ++++++++++++++++++ .../RealRegressionAggregation.java | 103 ------------ .../RealRegressionExtendedAggregation.java | 149 ++++++++++++++++++ .../state/ExtendedRegressionState.java | 22 +++ .../aggregation/state/RegressionState.java | 4 - 8 files changed, 345 insertions(+), 214 deletions(-) create mode 100644 presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java create mode 100644 presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java create mode 100644 presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 40c493fc887c6..d37d658e69204 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -67,6 +67,7 @@ import 
com.facebook.presto.operator.aggregation.DoubleCovarianceAggregation; import com.facebook.presto.operator.aggregation.DoubleHistogramAggregation; import com.facebook.presto.operator.aggregation.DoubleRegressionAggregation; +import com.facebook.presto.operator.aggregation.DoubleRegressionExtendedAggregation; import com.facebook.presto.operator.aggregation.DoubleSumAggregation; import com.facebook.presto.operator.aggregation.EntropyAggregation; import com.facebook.presto.operator.aggregation.GeometricMeanAggregations; @@ -84,6 +85,7 @@ import com.facebook.presto.operator.aggregation.RealGeometricMeanAggregations; import com.facebook.presto.operator.aggregation.RealHistogramAggregation; import com.facebook.presto.operator.aggregation.RealRegressionAggregation; +import com.facebook.presto.operator.aggregation.RealRegressionExtendedAggregation; import com.facebook.presto.operator.aggregation.RealSumAggregation; import com.facebook.presto.operator.aggregation.ReduceAggregationFunction; import com.facebook.presto.operator.aggregation.SumDataSizeForStats; @@ -744,7 +746,9 @@ private List getBuiltInFunctions(FunctionsConfig function .aggregates(DoubleCovarianceAggregation.class) .aggregates(RealCovarianceAggregation.class) .aggregates(DoubleRegressionAggregation.class) + .aggregates(DoubleRegressionExtendedAggregation.class) .aggregates(RealRegressionAggregation.class) + .aggregates(RealRegressionExtendedAggregation.class) .aggregates(DoubleCorrelationAggregation.class) .aggregates(RealCorrelationAggregation.class) .aggregates(BitwiseOrAggregation.class) diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java index 578e782bc8d91..c78186ebb2417 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java @@ -22,6 +22,7 @@ import com.facebook.presto.operator.aggregation.state.CentralMomentsState; import com.facebook.presto.operator.aggregation.state.CorrelationState; import com.facebook.presto.operator.aggregation.state.CovarianceState; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; import com.facebook.presto.operator.aggregation.state.RegressionState; import com.facebook.presto.operator.aggregation.state.VarianceState; import com.facebook.presto.spi.function.AggregationFunctionImplementation; @@ -145,9 +146,14 @@ public static double getCorrelation(CorrelationState state) public static void updateRegressionState(RegressionState state, double x, double y) { double oldMeanX = state.getMeanX(); - double oldMeanY = state.getMeanY(); updateCovarianceState(state, x, y); state.setM2X(state.getM2X() + (x - oldMeanX) * (x - state.getMeanX())); + } + + public static void updateExtendedRegressionState(ExtendedRegressionState state, double x, double y) + { + double oldMeanY = state.getMeanY(); + updateRegressionState(state, x, y); state.setM2Y(state.getM2Y() + (y - oldMeanY) * (y - state.getMeanY())); } @@ -189,12 +195,12 @@ public static double getRegressionSxy(RegressionState state) return state.getC2(); } - public static double getRegressionSyy(RegressionState state) + public static double getRegressionSyy(ExtendedRegressionState state) { return state.getM2Y(); } - public static double getRegressionR2(RegressionState state) + public static double getRegressionR2(ExtendedRegressionState state) { if (state.getM2X() != 0 && state.getM2Y() == 0) { return 1.0; @@ -311,10 +317,21 @@ public static void mergeRegressionState(RegressionState state, RegressionState o long na = state.getCount(); long nb = otherState.getCount(); state.setM2X(state.getM2X() + otherState.getM2X() + na * nb * Math.pow(state.getMeanX() - otherState.getMeanX(), 2) / (double) (na + 
nb)); - state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); updateCovarianceState(state, otherState); } + public static void mergeExtendedRegressionState(ExtendedRegressionState state, ExtendedRegressionState otherState) + { + if (otherState.getCount() == 0) { + return; + } + + long na = state.getCount(); + long nb = otherState.getCount(); + state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); + mergeRegressionState(state, otherState); + } + public static String generateAggregationName(String baseName, TypeSignature outputType, List inputTypes) { StringBuilder sb = new StringBuilder(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java index 24d1c6e61fcf5..db3ad26ec5d6d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java @@ -24,15 +24,8 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.DoubleType.DOUBLE; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; -import static 
com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeRegressionState; import static com.facebook.presto.operator.aggregation.AggregationUtils.updateRegressionState; @@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } - - @AggregationFunction("regr_sxy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_sxx") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_syy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSyy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_r2") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionR2(state); - if 
(Double.isFinite(result)) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_count") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionCount(state); - if (Double.isFinite(result) && result > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgx") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java new file mode 100644 index 0000000000000..3550cd0936949 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; + +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeExtendedRegressionState; +import static com.facebook.presto.operator.aggregation.AggregationUtils.updateExtendedRegressionState; + +@AggregationFunction +public class 
DoubleRegressionExtendedAggregation +{ + private DoubleRegressionExtendedAggregation() {} + + @InputFunction + public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.DOUBLE) double dependentValue, @SqlType(StandardTypes.DOUBLE) double independentValue) + { + updateExtendedRegressionState(state, independentValue, dependentValue); + } + + @CombineFunction + public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState) + { + mergeExtendedRegressionState(state, otherState); + } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { 
+ double result = getRegressionR2(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result) && result > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java index 1fe5d006da1a9..a75222bfa93c4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java @@ -24,15 +24,8 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.RealType.REAL; -import static 
com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static java.lang.Float.floatToRawIntBits; import static java.lang.Float.intBitsToFloat; @@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } - - @AggregationFunction("regr_sxy") - @OutputFunction(StandardTypes.REAL) - public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_sxx") - @OutputFunction(StandardTypes.REAL) - public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_syy") - @OutputFunction(StandardTypes.REAL) - public static void 
regrSyy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSyy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_r2") - @OutputFunction(StandardTypes.REAL) - public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionR2(state); - if (Double.isFinite(result)) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_count") - @OutputFunction(StandardTypes.REAL) - public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionCount(state); - if (Double.isFinite(result) && result > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgy") - @OutputFunction(StandardTypes.REAL) - public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgx") - @OutputFunction(StandardTypes.REAL) - public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } } diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java new file mode 100644 index 0000000000000..2d0335ae9aca6 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; + +import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; +import static 
com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.Float.intBitsToFloat; + +@AggregationFunction +public class RealRegressionExtendedAggregation +{ + private RealRegressionExtendedAggregation() {} + + @InputFunction + public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.REAL) long dependentValue, @SqlType(StandardTypes.REAL) long independentValue) + { + DoubleRegressionExtendedAggregation.input(state, intBitsToFloat((int) dependentValue), intBitsToFloat((int) independentValue)); + } + + @CombineFunction + public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState) + { + DoubleRegressionExtendedAggregation.combine(state, otherState); + } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.REAL) + public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.REAL) + public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + 
+ @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.REAL) + public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.REAL) + public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionR2(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.REAL) + public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result) && result > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + 
out.appendNull(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java new file mode 100644 index 0000000000000..64a9883174158 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator.aggregation.state; + +public interface ExtendedRegressionState + extends RegressionState +{ + double getM2Y(); + + void setM2Y(double value); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java index 79837f90c0c11..ae3af6f46dc43 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java @@ -19,8 +19,4 @@ public interface RegressionState double getM2X(); void setM2X(double value); - - double getM2Y(); - - void setM2Y(double value); } From 0d75bc32e4d8943c626502db3d7195bf67124fe4 Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Thu, 4 Sep 2025 19:03:24 +0100 Subject: [PATCH 4/6] Remove register-test-functions from PrestoSparkNativeQueryRunnerUtils --- .../facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java | 1 - 1 file changed, 1 deletion(-) diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java index 14113b22d4fb3..21ef570381795 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java @@ -79,7 +79,6 @@ public static Map getNativeExecutionSparkConfigs() .put("catalog.config-dir", "/") .put("task.info-update-interval", "100ms") .put("spark.initial-partition-count", "1") - .put("register-test-functions", "true") .put("native-execution-program-arguments", "--logtostderr=1 --minloglevel=3") .put("spark.partition-count-auto-tune-enabled", "false"); From 
3257215cfe2ff4ab22329c7e2827a10a8f11b4e8 Mon Sep 17 00:00:00 2001 From: Pramod Satya Date: Thu, 4 Sep 2025 20:13:43 -0700 Subject: [PATCH 5/6] [native] Use subscript operator to retrieve function handle --- .../sql/relational/FunctionResolution.java | 2 +- .../sidecar/TestNativeSidecarPlugin.java | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java index a3fe25d7b65fa..337582257bad4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java @@ -281,7 +281,7 @@ public boolean isEqualsFunction(FunctionHandle functionHandle) @Override public FunctionHandle subscriptFunction(Type baseType, Type indexType) { - return functionAndTypeResolver.lookupFunction(SUBSCRIPT.getFunctionName().getObjectName(), fromTypes(baseType, indexType)); + return functionAndTypeResolver.resolveOperator(SUBSCRIPT, fromTypes(baseType, indexType)); } @Override diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java index 1c5d05d55b274..b4716f98362cf 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java @@ -14,6 +14,7 @@ package com.facebook.presto.sidecar; import com.facebook.airlift.units.DataSize; +import com.facebook.presto.Session; import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; import com.facebook.presto.sidecar.functionNamespace.FunctionDefinitionProvider; import 
com.facebook.presto.sidecar.functionNamespace.NativeFunctionDefinitionProvider; @@ -44,6 +45,7 @@ import java.util.stream.Collectors; import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST; import static com.facebook.presto.common.Utils.checkArgument; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem; @@ -405,6 +407,22 @@ public void testGeometryQueries() "Error from native plan checker: .SpatialJoinNode no abstract type PlanNode "); } + @Test + public void testRemoveMapCast() + { + Session enableOptimization = Session.builder(getSession()) + .setSystemProperty(REMOVE_MAP_CAST, "true") + .build(); + assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", + "values 0.5, 0.1"); + assertQuery(enableOptimization, "select element_at(feature, key) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", + "values 0.5, 0.1"); + assertQuery(enableOptimization, "select element_at(feature, key) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", + "values 0.5, null"); + assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), cast('2' as varchar)), (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), '4')) t(feature, key)", + "values 0.5, 0.1"); + } + private String generateRandomTableName() { String tableName = 
"tmp_presto_" + UUID.randomUUID().toString().replace("-", ""); From 0458101a8f4a099e2dcde4fafd12339e96fb250a Mon Sep 17 00:00:00 2001 From: Pratik Joseph Dabre Date: Fri, 5 Sep 2025 11:27:30 -0700 Subject: [PATCH 6/6] [native] Introduce presto-native-sql-invoked-functions-plugin for sidecar enabled clusters Adds a new plugin : presto-native-sql-invoked-functions-plugin that contains all inlined SQL functions except those with overridden native implementations. This plugin is intended to be loaded only in sidecar enabled clusters. --- .../prestocpp-linux-build-and-unit-test.yml | 2 +- pom.xml | 1 + .../optimizations/KeyBasedSampler.java | 4 +- .../presto/sql/relational/Expressions.java | 7 - .../nativeworker/NativeQueryRunnerUtils.java | 4 +- presto-native-sidecar-plugin/pom.xml | 16 +++ .../NativeSidecarPluginQueryRunnerUtils.java | 2 + .../sidecar/TestNativeSidecarPlugin.java | 126 +++++++++++++++++- .../pom.xml | 29 ++++ .../scalar/sql/NativeArraySqlFunctions.java | 74 ++++++++++ .../scalar/sql/NativeMapSqlFunctions.java | 48 +++++++ .../sql/NativeSimpleSamplingPercent.java | 33 +++++ .../sql/NativeSqlInvokedFunctionsPlugin.java | 33 +++++ presto-native-tests/pom.xml | 7 + presto-plan-checker-router-plugin/pom.xml | 7 + .../conf/docker/common/compose-commons.sh | 10 ++ presto-server/src/main/provisio/presto.xml | 6 + 17 files changed, 392 insertions(+), 17 deletions(-) create mode 100644 presto-native-sql-invoked-functions-plugin/pom.xml create mode 100644 presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java create mode 100644 presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java create mode 100644 presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java create mode 100644 
presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java diff --git a/.github/workflows/prestocpp-linux-build-and-unit-test.yml b/.github/workflows/prestocpp-linux-build-and-unit-test.yml index b28f2292252ea..f3fc4843c3321 100644 --- a/.github/workflows/prestocpp-linux-build-and-unit-test.yml +++ b/.github/workflows/prestocpp-linux-build-and-unit-test.yml @@ -370,7 +370,7 @@ jobs: # Use different Maven options to install. MAVEN_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" run: | - for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-execution' -am && s=0 && break || s=$? && sleep 10; done; (exit $s) + for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-sidecar-plugin' -am && s=0 && break || s=$? && sleep 10; done; (exit $s) - name: Run presto-native sidecar tests if: | diff --git a/pom.xml b/pom.xml index d5ac80313ec19..f73719dacd94a 100644 --- a/pom.xml +++ b/pom.xml @@ -224,6 +224,7 @@ presto-router-example-plugin-scheduler presto-plan-checker-router-plugin presto-sql-invoked-functions-plugin + presto-native-sql-invoked-functions-plugin diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java index f34359aeba6af..1c7856a62caea 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.Varchars; @@ -55,7 +54,6 @@ import static 
com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.metadata.CastType.CAST; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.spi.StandardWarningCode.SAMPLED_FIELDS; @@ -150,7 +148,7 @@ private PlanNode addSamplingFilter(PlanNode tableScanNode, Optional arguments) - { - FunctionHandle functionHandle = functionAndTypeManager.lookupFunction(qualifiedObjectName, fromTypes(arguments.stream().map(RowExpression::getType).collect(toImmutableList()))); - return call(String.valueOf(qualifiedObjectName), functionHandle, returnType, arguments); - } - public static CallExpression call(FunctionAndTypeResolver functionAndTypeResolver, String name, Type returnType, RowExpression... 
arguments) { FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(name, fromTypes(Arrays.stream(arguments).map(RowExpression::getType).collect(toImmutableList()))); diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java index 3650f8cb36531..59a60b0024842 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java @@ -29,7 +29,7 @@ private NativeQueryRunnerUtils() {} public static Map getNativeWorkerHiveProperties() { return ImmutableMap.of("hive.parquet.pushdown-filter-enabled", "true", - "hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF"); + "hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF"); } public static Map getNativeWorkerIcebergProperties() @@ -59,6 +59,8 @@ public static Map getNativeSidecarProperties() .put("coordinator-sidecar-enabled", "true") .put("exclude-invalid-worker-session-properties", "true") .put("presto.default-namespace", "native.default") + // inline-sql-functions is overridden to be true in sidecar enabled native clusters. 
+ .put("inline-sql-functions", "true") .build(); } diff --git a/presto-native-sidecar-plugin/pom.xml b/presto-native-sidecar-plugin/pom.xml index b2bc40e8f2bd4..d04424440469a 100644 --- a/presto-native-sidecar-plugin/pom.xml +++ b/presto-native-sidecar-plugin/pom.xml @@ -260,9 +260,25 @@ + com.facebook.presto presto-built-in-worker-function-tools + ${project.version} + + + + com.facebook.presto + presto-native-sql-invoked-functions-plugin + ${project.version} + test + + + + com.facebook.presto + presto-sql-invoked-functions-plugin + ${project.version} + test diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java index 776d4920e2f16..c8c7e1123f974 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sidecar; +import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory; import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory; import com.facebook.presto.sidecar.typemanager.NativeTypeManagerFactory; @@ -37,5 +38,6 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner) "function-implementation-type", "CPP")); queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME); queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of()); + queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin()); } } diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java 
b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java index b4716f98362cf..fe0b18c24b2fb 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java @@ -16,6 +16,8 @@ import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; +import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.sidecar.functionNamespace.FunctionDefinitionProvider; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionDefinitionProvider; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManager; @@ -45,9 +47,12 @@ import java.util.stream.Collectors; import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.presto.SystemSessionProperties.INLINE_SQL_FUNCTIONS; +import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED; import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST; import static com.facebook.presto.common.Utils.checkArgument; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createCustomer; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders; @@ -65,6 +70,7 @@ public class TestNativeSidecarPlugin private static final String REGEX_FUNCTION_NAMESPACE = "native.default.*"; private static final String REGEX_SESSION_NAMESPACE = "Native Execution only.*"; private static final long 
SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB = 128; + private static final int INLINED_SQL_FUNCTIONS_COUNT = 7; @Override protected void createTables() @@ -75,6 +81,7 @@ protected void createTables() createOrders(queryRunner); createOrdersEx(queryRunner); createRegion(queryRunner); + createCustomer(queryRunner); } @Override @@ -93,9 +100,11 @@ protected QueryRunner createQueryRunner() protected QueryRunner createExpectedQueryRunner() throws Exception { - return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder() + QueryRunner queryRunner = PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder() .setAddStorageFormatToPath(true) .build(); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } public static void setupNativeSidecarPlugin(QueryRunner queryRunner) @@ -113,6 +122,7 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner) "sidecar.http-client.max-content-length", SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB + "MB")); queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME); queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of()); + queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin()); } @Test @@ -163,6 +173,7 @@ public void testSetNativeWorkerSessionProperty() @Test public void testShowFunctions() { + int inlinedSQLFunctionsCount = 0; @Language("SQL") String sql = "SHOW FUNCTIONS"; MaterializedResult actualResult = computeActual(sql); List actualRows = actualResult.getMaterializedRows(); @@ -176,11 +187,17 @@ public void testShowFunctions() // function namespace should be present. String fullFunctionName = row.get(5).toString(); - if (Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) { - continue; + if (!Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) { + // If no namespace match found, check if it's an inlined SQL Invoked function. 
+ String language = row.get(9).toString(); + if (language.equalsIgnoreCase("SQL")) { + inlinedSQLFunctionsCount++; + continue; + } + fail(format("No namespace match found for row: %s", row)); } - fail(format("No namespace match found for row: %s", row)); } + assertEquals(inlinedSQLFunctionsCount, INLINED_SQL_FUNCTIONS_COUNT); } @Test @@ -321,7 +338,7 @@ public void testApproxPercentile() public void testInformationSchemaTables() { assertQuery("select lower(table_name) from information_schema.tables " - + "where table_name = 'lineitem' or table_name = 'LINEITEM' "); + + "where table_name = 'lineitem' or table_name = 'LINEITEM' "); } @Test @@ -423,6 +440,105 @@ public void testRemoveMapCast() "values 0.5, 0.1"); } + @Test + public void testOverriddenInlinedSqlInvokedFunctions() + { + // String functions + assertQuery("SELECT trail(comment, cast(nationkey as integer)) FROM nation"); + assertQuery("SELECT name, comment, replace_first(comment, 'iron', 'gold') from nation"); + + // Array functions + assertQuery("SELECT array_intersect(ARRAY['apple', 'banana', 'cherry'], ARRAY['apple', 'mango', 'fig'])"); + assertQuery("SELECT array_frequency(split(comment, '')) from nation"); + assertQuery("SELECT array_duplicates(ARRAY[regionkey]), array_duplicates(ARRAY[comment]) from nation"); + assertQuery("SELECT array_has_duplicates(ARRAY[custkey]) from orders"); + assertQuery("SELECT array_max_by(ARRAY[comment], x -> length(x)) from orders"); + assertQuery("SELECT array_min_by(ARRAY[ROW('USA', 1), ROW('INDIA', 2), ROW('UK', 3)], x -> x[2])"); + assertQuery("SELECT array_sort_desc(map_keys(map_union(quantity_by_linenumber))) FROM orders_ex"); + assertQuery("SELECT remove_nulls(ARRAY[CAST(regionkey AS VARCHAR), comment, NULL]) from nation"); + assertQuery("SELECT array_top_n(ARRAY[CAST(nationkey AS VARCHAR)], 3) from nation"); + assertQuerySucceeds("SELECT array_sort_desc(quantities, x -> abs(x)) FROM orders_ex"); + + // Map functions + assertQuery("SELECT 
map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 4, 5]))"); + assertQuery("SELECT map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 0, -1]))"); + assertQuery("SELECT name, map_normalize(MAP(ARRAY['regionkey', 'length'], ARRAY[regionkey, length(comment)])) from nation"); + assertQuery("SELECT name, map_remove_null_values(map(ARRAY['region', 'comment', 'nullable'], " + + "ARRAY[CAST(regionkey AS VARCHAR), comment, NULL])) from nation"); + assertQuery("SELECT name, map_key_exists(map(ARRAY['nation', 'comment'], ARRAY[CAST(nationkey AS VARCHAR), comment]), 'comment') from nation"); + assertQuery("SELECT map_keys_by_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 2) from orders"); + assertQuery("SELECT map_top_n(MAP(ARRAY[CAST(nationkey AS VARCHAR)], ARRAY[comment]), 3) from nation"); + assertQuery("SELECT map_top_n_keys(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders"); + assertQuery("SELECT map_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders"); + assertQuery("SELECT all_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> length(k) > 5) from orders"); + assertQuery("SELECT any_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> starts_with(k, 'abc')) from orders"); + assertQuery("SELECT any_values_match(MAP(ARRAY[orderkey], ARRAY[totalprice]), k -> abs(k) > 20) from orders"); + assertQuery("SELECT no_values_match(MAP(ARRAY[orderkey], ARRAY[comment]), k -> length(k) > 2) from orders"); + assertQuery("SELECT no_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> ends_with(k, 'a')) from orders"); + } + + @Test + public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigEnabled() + { + // Array functions + assertQuery("SELECT array_split_into_chunks(split(comment, ''), 2) from nation"); + assertQuery("SELECT array_least_frequent(quantities) from orders_ex"); + assertQuery("SELECT array_least_frequent(split(comment, ''), 5) from nation"); + assertQuerySucceeds("SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), 
if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from orders"); + + // Map functions + assertQuerySucceeds("SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation"); + assertQuerySucceeds("SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation"); + + Session sessionWithKeyBasedSampling = Session.builder(getSession()) + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") + .build(); + + @Language("SQL") String query = "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey"; + + assertQuery(query, "select cast(60175 as bigint)"); + assertQuery(sessionWithKeyBasedSampling, query, "select cast(16185 as bigint)"); + } + + @Test + public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigDisabled() + { + // When inline_sql_functions is set to false, the below queries should fail as the implementations don't exist on the native worker + Session session = Session.builder(getSession()) + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") + .setSystemProperty(INLINE_SQL_FUNCTIONS, "false") + .build(); + + // Array functions + assertQueryFails(session, + "SELECT array_split_into_chunks(split(comment, ''), 2) from nation", + ".*Scalar function name not registered: native.default.array_split_into_chunks.*"); + assertQueryFails(session, + "SELECT array_least_frequent(quantities) from orders_ex", + ".*Scalar function name not registered: native.default.array_least_frequent.*"); + assertQueryFails(session, + "SELECT array_least_frequent(split(comment, ''), 2) from nation", + ".*Scalar function name not registered: native.default.array_least_frequent.*"); + assertQueryFails(session, + "SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, 
cast(-1 as bigint), cast(0 as bigint)))) from orders", + " Scalar function native\\.default\\.array_top_n not registered with arguments.*", + true); + + // Map functions + assertQueryFails(session, + "SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation", + ".*Scalar function native\\.default\\.map_top_n_values not registered with arguments.*", + true); + assertQueryFails(session, + "SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation", + ".*Scalar function native\\.default\\.map_top_n_keys not registered with arguments.*", + true); + + assertQueryFails(session, + "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey", + ".*Scalar function name not registered: native.default.key_sampling_percent.*"); + } + private String generateRandomTableName() { String tableName = "tmp_presto_" + UUID.randomUUID().toString().replace("-", ""); diff --git a/presto-native-sql-invoked-functions-plugin/pom.xml b/presto-native-sql-invoked-functions-plugin/pom.xml new file mode 100644 index 0000000000000..7d837a28a8bfb --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/pom.xml @@ -0,0 +1,29 @@ + + 4.0.0 + + com.facebook.presto + presto-root + 0.295-SNAPSHOT + + + presto-native-sql-invoked-functions-plugin + Presto Native - Sql invoked functions plugin + presto-plugin + + + ${project.parent.basedir} + + + + + com.facebook.presto + presto-spi + provided + + + com.google.guava + guava + + + diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java new file mode 100644 index 
0000000000000..841883d99ae8f --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlParameters; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + +public class NativeArraySqlFunctions +{ + private NativeArraySqlFunctions() {} + + @SqlInvokedScalarFunction(value = "array_split_into_chunks", deterministic = true, calledOnNullInput = false) + @Description("Returns an array of arrays splitting input array into chunks of given length. " + + "If array is not evenly divisible it will split into as many possible chunks and " + + "return the left over elements for the last array. Returns null for null inputs, but not elements.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "sz", type = "int")}) + @SqlType("array(array(T))") + public static String arraySplitIntoChunks() + { + return "RETURN IF(sz <= 0, " + + "fail('Invalid slice size: ' || cast(sz as varchar) || '. 
Size must be greater than zero.'), " + + "IF(cardinality(input) / sz > 10000, " + + "fail('Cannot split array of size: ' || cast(cardinality(input) as varchar) || ' into more than 10000 parts.'), " + + "transform(" + + "sequence(1, cardinality(input), sz), " + + "x -> slice(input, x, sz))))"; + } + + @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true) + @Description("Determines the least frequent element in the array. If there are multiple elements, the function returns the smallest element") + @TypeParameter("T") + @SqlParameter(name = "input", type = "array(T)") + @SqlType("array(T)") + public static String arrayLeastFrequent() + { + return "RETURN IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, 1), x -> x[2]))"; + } + + @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true) + @Description("Determines the n least frequent element in the array in the ascending order of the elements.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "bigint")}) + @SqlType("array(T)") + public static String arrayNLeastFrequent() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, n), x -> x[2])))"; + } + + @SqlInvokedScalarFunction(value = "array_top_n", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N values of the given array sorted using the provided lambda comparator.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "int"), @SqlParameter(name = "f", type = "function(T, T, 
bigint)")}) + @SqlType("array(T)") + public static String arrayTopNComparator() + { + return "RETURN IF(n < 0, fail('Parameter n: ' || cast(n as varchar) || ' to ARRAY_TOP_N is negative'), SLICE(REVERSE(ARRAY_SORT(input, f)), 1, n))"; + } +} diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java new file mode 100644 index 0000000000000..9eccc84d6d8c8 --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlParameters; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + /* SQL-invoked scalar functions over maps; each static method returns the SQL body of one function as a string. */ +public class NativeMapSqlFunctions +{ + private NativeMapSqlFunctions() {} + /* SQL body: fails when n < 0; otherwise sorts the map's keys with comparator f, reverses the sorted array and keeps the first n keys. */ + @SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N keys of the given map sorting its keys using the provided lambda comparator.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(K, K, bigint)")}) + @SqlType("array<K>") + public static String mapTopNKeysComparator() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(map_keys(input), f)), 1, n))"; + } + /* SQL body: fails when n < 0; otherwise sorts the non-null values with comparator f, reverses them, appends the null values at the end and keeps the first n entries. */ + @SqlInvokedScalarFunction(value = "map_top_n_values", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N values of the given map sorted using the provided lambda comparator.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(V, V, bigint)")}) + @SqlType("array<V>") + public static String mapTopNValuesComparator() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(remove_nulls(map_values(input)), f)) || filter(map_values(input), x -> x is null), 1, n))"; + } +} diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java 
b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java new file mode 100644 index 0000000000000..a710391760714 --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlType; + /* Holder for the key_sampling_percent SQL-invoked scalar function; the private constructor keeps it from being instantiated. */ +public class NativeSimpleSamplingPercent +{ + private NativeSimpleSamplingPercent() {} + /* SQL body: casts input to varbinary, hashes it with xxhash64, decodes the hash with from_ieee754_64, then returns abs(value) % 100 divided by 100. */ + @SqlInvokedScalarFunction(value = "key_sampling_percent", deterministic = true, calledOnNullInput = false) + @Description("Returns a value between 0.0 and 1.0 using the hash of the given input string") + @SqlParameter(name = "input", type = "varchar") + @SqlType("double") + public static String keySamplingPercent() + { + return "return (abs(from_ieee754_64(xxhash64(cast(input as varbinary)))) % 100) / 100. 
"; + } +} diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java new file mode 100644 index 0000000000000..69d7ff1e78522 --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.Plugin; +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + /* Plugin whose getSqlInvokedFunctions() returns the three SQL-invoked function holder classes of this module: array, map and key-sampling functions. */ +public class NativeSqlInvokedFunctionsPlugin + implements Plugin +{ + @Override + public Set<Class<?>> getSqlInvokedFunctions() + { + return ImmutableSet.<Class<?>>builder() + .add(NativeArraySqlFunctions.class) + .add(NativeMapSqlFunctions.class) + .add(NativeSimpleSamplingPercent.class) + .build(); + } +} diff --git a/presto-native-tests/pom.xml b/presto-native-tests/pom.xml index e90785cbb6940..7e23670117b1d 100644 --- a/presto-native-tests/pom.xml +++ b/presto-native-tests/pom.xml @@ -192,6 +192,13 @@ units test + + + com.facebook.presto + presto-native-sql-invoked-functions-plugin + ${project.version} + test + diff --git a/presto-plan-checker-router-plugin/pom.xml b/presto-plan-checker-router-plugin/pom.xml index 60e34e4147d19..db1c4b0f1736c 100644 --- a/presto-plan-checker-router-plugin/pom.xml +++ b/presto-plan-checker-router-plugin/pom.xml @@ -223,6 +223,13 @@ presto-hive-metastore test + + + com.facebook.presto + presto-native-sql-invoked-functions-plugin + ${project.version} + test + diff --git a/presto-product-tests/conf/docker/common/compose-commons.sh b/presto-product-tests/conf/docker/common/compose-commons.sh index eae9f18ce9583..5c20783716b60 100644 --- a/presto-product-tests/conf/docker/common/compose-commons.sh +++ b/presto-product-tests/conf/docker/common/compose-commons.sh @@ -39,6 +39,16 @@ if [[ -z "${PRESTO_SERVER_DIR:-}" ]]; then source "${PRODUCT_TESTS_ROOT}/target/classes/presto.env" PRESTO_SERVER_DIR="${PROJECT_ROOT}/presto-server/target/presto-server-${PRESTO_VERSION}/" fi + +# The following plugin results in a function signature conflict when loaded in Java/ sidecar disabled native clusters. +# This plugin is only meant for sidecar enabled native clusters, hence exclude it. 
+PLUGIN_TO_EXCLUDE="native-sql-invoked-functions-plugin" +# Destructive step: when the plugin's directory exists under ${PRESTO_SERVER_DIR}/plugin, it is deleted from disk with rm -rf.
+if [[ -d "${PRESTO_SERVER_DIR}/plugin/${PLUGIN_TO_EXCLUDE}" ]]; then + echo "Excluding plugin: $PLUGIN_TO_EXCLUDE" + rm -rf "${PRESTO_SERVER_DIR}/plugin/${PLUGIN_TO_EXCLUDE}" +fi + export_canonical_path PRESTO_SERVER_DIR if [[ -z "${PRESTO_CLI_JAR:-}" ]]; then diff --git a/presto-server/src/main/provisio/presto.xml b/presto-server/src/main/provisio/presto.xml index b14b36a768e69..d15a041c7d1f5 100644 --- a/presto-server/src/main/provisio/presto.xml +++ b/presto-server/src/main/provisio/presto.xml @@ -292,4 +292,10 @@ + + + + + +