diff --git a/.github/workflows/prestocpp-linux-build-and-unit-test.yml b/.github/workflows/prestocpp-linux-build-and-unit-test.yml index b28f2292252ea..f3fc4843c3321 100644 --- a/.github/workflows/prestocpp-linux-build-and-unit-test.yml +++ b/.github/workflows/prestocpp-linux-build-and-unit-test.yml @@ -370,7 +370,7 @@ jobs: # Use different Maven options to install. MAVEN_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" run: | - for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-execution' -am && s=0 && break || s=$? && sleep 10; done; (exit $s) + for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-sidecar-plugin' -am && s=0 && break || s=$? && sleep 10; done; (exit $s) - name: Run presto-native sidecar tests if: | diff --git a/CODEOWNERS b/CODEOWNERS index 0b2831ba7ee31..e4a0e3bb56aed 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -123,6 +123,7 @@ CODEOWNERS @prestodb/team-tsc ##################################################################### # Presto on Spark module /presto-spark* @shrinidhijoshi @prestodb/committers +/presto-native-execution/*/com/facebook/presto/spark/* @shrinidhijoshi @prestodb/committers ##################################################################### # Presto connectors and plugins @@ -166,4 +167,3 @@ CODEOWNERS @prestodb/team-tsc # Presto CI and builds /.github @czentgr @unidevel @prestodb/committers /docker @czentgr @unidevel @prestodb/committers - diff --git a/pom.xml b/pom.xml index 15a6f0ec95a20..f73719dacd94a 100644 --- a/pom.xml +++ b/pom.xml @@ -90,8 +90,7 @@ 6.0.0 17.0.0 3.5.4 - - 2 + 2.0.2-6 + org.codehaus.plexus:plexus-utils + com.google.guava:guava + com.fasterxml.jackson.core:jackson-annotations + com.fasterxml.jackson.core:jackson-core + com.fasterxml.jackson.core:jackson-databind + + + + + + + + org.basepom.maven + duplicate-finder-maven-plugin + + true + + com.github.benmanes.caffeine.* + + META-INF.versions.9.module-info + + META-INF.versions.11.module-info + + META-INF.versions.9.org.apache.lucene.* + + + + + + + + + + spark3 @@ -3173,19 +3238,12 @@ - 3 + 3.4.1-1 - - - - com.facebook.presto.spark - spark-core - 3.4.1-1 - provided - - - + + presto-spark-classloader-spark3 + diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 40c493fc887c6..d37d658e69204 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -67,6 +67,7 @@ import com.facebook.presto.operator.aggregation.DoubleCovarianceAggregation; import com.facebook.presto.operator.aggregation.DoubleHistogramAggregation; import com.facebook.presto.operator.aggregation.DoubleRegressionAggregation; +import com.facebook.presto.operator.aggregation.DoubleRegressionExtendedAggregation; import com.facebook.presto.operator.aggregation.DoubleSumAggregation; import com.facebook.presto.operator.aggregation.EntropyAggregation; import com.facebook.presto.operator.aggregation.GeometricMeanAggregations; @@ -84,6 +85,7 @@ import com.facebook.presto.operator.aggregation.RealGeometricMeanAggregations; import com.facebook.presto.operator.aggregation.RealHistogramAggregation; import com.facebook.presto.operator.aggregation.RealRegressionAggregation; +import com.facebook.presto.operator.aggregation.RealRegressionExtendedAggregation; import com.facebook.presto.operator.aggregation.RealSumAggregation; import com.facebook.presto.operator.aggregation.ReduceAggregationFunction; import com.facebook.presto.operator.aggregation.SumDataSizeForStats; @@ -744,7 +746,9 @@ private List getBuiltInFunctions(FunctionsConfig function .aggregates(DoubleCovarianceAggregation.class) .aggregates(RealCovarianceAggregation.class) .aggregates(DoubleRegressionAggregation.class) + .aggregates(DoubleRegressionExtendedAggregation.class) .aggregates(RealRegressionAggregation.class) + .aggregates(RealRegressionExtendedAggregation.class) .aggregates(DoubleCorrelationAggregation.class) .aggregates(RealCorrelationAggregation.class) .aggregates(BitwiseOrAggregation.class) diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java index 578e782bc8d91..c78186ebb2417 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java @@ -22,6 +22,7 @@ import com.facebook.presto.operator.aggregation.state.CentralMomentsState; import com.facebook.presto.operator.aggregation.state.CorrelationState; import com.facebook.presto.operator.aggregation.state.CovarianceState; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; import com.facebook.presto.operator.aggregation.state.RegressionState; import com.facebook.presto.operator.aggregation.state.VarianceState; import com.facebook.presto.spi.function.AggregationFunctionImplementation; @@ -145,9 +146,14 @@ public static double getCorrelation(CorrelationState state) public static void updateRegressionState(RegressionState state, double x, double y) { double oldMeanX = state.getMeanX(); - double oldMeanY = state.getMeanY(); updateCovarianceState(state, x, y); state.setM2X(state.getM2X() + (x - oldMeanX) * (x - state.getMeanX())); + } + + public static void updateExtendedRegressionState(ExtendedRegressionState state, double x, double y) + { + double oldMeanY = state.getMeanY(); + updateRegressionState(state, x, y); state.setM2Y(state.getM2Y() + (y - oldMeanY) * (y - state.getMeanY())); } @@ -189,12 +195,12 @@ public static double getRegressionSxy(RegressionState state) return state.getC2(); } - public static double getRegressionSyy(RegressionState state) + public static double getRegressionSyy(ExtendedRegressionState state) { return state.getM2Y(); } - public static double getRegressionR2(RegressionState state) + public static double getRegressionR2(ExtendedRegressionState state) { if (state.getM2X() != 0 && state.getM2Y() == 0) { return 1.0; @@ -311,10 +317,21 @@ public static void mergeRegressionState(RegressionState state, RegressionState o long na = state.getCount(); long nb = otherState.getCount(); state.setM2X(state.getM2X() + otherState.getM2X() + na * nb * Math.pow(state.getMeanX() - otherState.getMeanX(), 2) / (double) (na + nb)); - state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); updateCovarianceState(state, otherState); } + public static void mergeExtendedRegressionState(ExtendedRegressionState state, ExtendedRegressionState otherState) + { + if (otherState.getCount() == 0) { + return; + } + + long na = state.getCount(); + long nb = otherState.getCount(); + state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); + mergeRegressionState(state, otherState); + } + public static String generateAggregationName(String baseName, TypeSignature outputType, List inputTypes) { StringBuilder sb = new StringBuilder(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java index 24d1c6e61fcf5..db3ad26ec5d6d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java @@ -24,15 +24,8 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.DoubleType.DOUBLE; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeRegressionState; import static com.facebook.presto.operator.aggregation.AggregationUtils.updateRegressionState; @@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } - - @AggregationFunction("regr_sxy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_sxx") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_syy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSyy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_r2") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionR2(state); - if (Double.isFinite(result)) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_count") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionCount(state); - if (Double.isFinite(result) && result > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgx") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java new file mode 100644 index 0000000000000..3550cd0936949 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; + +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeExtendedRegressionState; +import static com.facebook.presto.operator.aggregation.AggregationUtils.updateExtendedRegressionState; + +@AggregationFunction +public class DoubleRegressionExtendedAggregation +{ + private DoubleRegressionExtendedAggregation() {} + + @InputFunction + public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.DOUBLE) double dependentValue, @SqlType(StandardTypes.DOUBLE) double independentValue) + { + updateExtendedRegressionState(state, independentValue, dependentValue); + } + + @CombineFunction + public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState) + { + mergeExtendedRegressionState(state, otherState); + } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionR2(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result) && result > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java index 1fe5d006da1a9..a75222bfa93c4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java @@ -24,15 +24,8 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.RealType.REAL; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static java.lang.Float.floatToRawIntBits; import static java.lang.Float.intBitsToFloat; @@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } - - @AggregationFunction("regr_sxy") - @OutputFunction(StandardTypes.REAL) - public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_sxx") - @OutputFunction(StandardTypes.REAL) - public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_syy") - @OutputFunction(StandardTypes.REAL) - public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSyy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_r2") - @OutputFunction(StandardTypes.REAL) - public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionR2(state); - if (Double.isFinite(result)) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_count") - @OutputFunction(StandardTypes.REAL) - public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionCount(state); - if (Double.isFinite(result) && result > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgy") - @OutputFunction(StandardTypes.REAL) - public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgx") - @OutputFunction(StandardTypes.REAL) - public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java new file mode 100644 index 0000000000000..2d0335ae9aca6 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; + +import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.Float.intBitsToFloat; + +@AggregationFunction +public class RealRegressionExtendedAggregation +{ + private RealRegressionExtendedAggregation() {} + + @InputFunction + public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.REAL) long dependentValue, @SqlType(StandardTypes.REAL) long independentValue) + { + DoubleRegressionExtendedAggregation.input(state, intBitsToFloat((int) dependentValue), intBitsToFloat((int) independentValue)); + } + + @CombineFunction + public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState) + { + DoubleRegressionExtendedAggregation.combine(state, otherState); + } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.REAL) + public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.REAL) + public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.REAL) + public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.REAL) + public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionR2(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.REAL) + public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result) && result > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java new file mode 100644 index 0000000000000..64a9883174158 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.state; + +public interface ExtendedRegressionState + extends RegressionState +{ + double getM2Y(); + + void setM2Y(double value); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java index 79837f90c0c11..ae3af6f46dc43 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java @@ -19,8 +19,4 @@ public interface RegressionState double getM2X(); void setM2X(double value); - - double getM2Y(); - - void setM2Y(double value); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java index f34359aeba6af..1c7856a62caea 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.Varchars; @@ -55,7 +54,6 @@ import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.metadata.CastType.CAST; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.spi.StandardWarningCode.SAMPLED_FIELDS; @@ -150,7 +148,7 @@ private PlanNode addSamplingFilter(PlanNode tableScanNode, Optional arguments) - { - FunctionHandle functionHandle = functionAndTypeManager.lookupFunction(qualifiedObjectName, fromTypes(arguments.stream().map(RowExpression::getType).collect(toImmutableList()))); - return call(String.valueOf(qualifiedObjectName), functionHandle, returnType, arguments); - } - public static CallExpression call(FunctionAndTypeResolver functionAndTypeResolver, String name, Type returnType, RowExpression... arguments) { FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(name, fromTypes(Arrays.stream(arguments).map(RowExpression::getType).collect(toImmutableList()))); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java index a3fe25d7b65fa..337582257bad4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java @@ -281,7 +281,7 @@ public boolean isEqualsFunction(FunctionHandle functionHandle) @Override public FunctionHandle subscriptFunction(Type baseType, Type indexType) { - return functionAndTypeResolver.lookupFunction(SUBSCRIPT.getFunctionName().getObjectName(), fromTypes(baseType, indexType)); + return functionAndTypeResolver.resolveOperator(SUBSCRIPT, fromTypes(baseType, indexType)); } @Override diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java index 3650f8cb36531..59a60b0024842 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java @@ -29,7 +29,7 @@ private NativeQueryRunnerUtils() {} public static Map getNativeWorkerHiveProperties() { return ImmutableMap.of("hive.parquet.pushdown-filter-enabled", "true", - "hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF"); + "hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF"); } public static Map getNativeWorkerIcebergProperties() @@ -59,6 +59,8 @@ public static Map getNativeSidecarProperties() .put("coordinator-sidecar-enabled", "true") .put("exclude-invalid-worker-session-properties", "true") .put("presto.default-namespace", "native.default") + // inline-sql-functions is overridden to be true in sidecar enabled native clusters. + .put("inline-sql-functions", "true") .build(); } diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java index 14113b22d4fb3..21ef570381795 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java @@ -79,7 +79,6 @@ public static Map getNativeExecutionSparkConfigs() .put("catalog.config-dir", "/") .put("task.info-update-interval", "100ms") .put("spark.initial-partition-count", "1") - .put("register-test-functions", "true") .put("native-execution-program-arguments", "--logtostderr=1 --minloglevel=3") .put("spark.partition-count-auto-tune-enabled", "false"); diff --git a/presto-native-sidecar-plugin/pom.xml b/presto-native-sidecar-plugin/pom.xml index b2bc40e8f2bd4..d04424440469a 100644 --- a/presto-native-sidecar-plugin/pom.xml +++ b/presto-native-sidecar-plugin/pom.xml @@ -260,9 +260,25 @@ + com.facebook.presto presto-built-in-worker-function-tools + ${project.version} + + + + com.facebook.presto + presto-native-sql-invoked-functions-plugin + ${project.version} + test + + + + com.facebook.presto + presto-sql-invoked-functions-plugin + ${project.version} + test diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java index 776d4920e2f16..c8c7e1123f974 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sidecar; +import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory; import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory; import com.facebook.presto.sidecar.typemanager.NativeTypeManagerFactory; @@ -37,5 +38,6 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner) "function-implementation-type", "CPP")); queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME); queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of()); + queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin()); } } diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java index 1c5d05d55b274..fe0b18c24b2fb 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java @@ -14,7 +14,10 @@ package com.facebook.presto.sidecar; import com.facebook.airlift.units.DataSize; +import com.facebook.presto.Session; import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; +import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.sidecar.functionNamespace.FunctionDefinitionProvider; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionDefinitionProvider; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManager; @@ -44,8 +47,12 @@ import java.util.stream.Collectors; import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.presto.SystemSessionProperties.INLINE_SQL_FUNCTIONS; +import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED; +import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST; import static com.facebook.presto.common.Utils.checkArgument; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createCustomer; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders; @@ -63,6 +70,7 @@ public class TestNativeSidecarPlugin private static final String REGEX_FUNCTION_NAMESPACE = "native.default.*"; private static final String REGEX_SESSION_NAMESPACE = "Native Execution only.*"; private static final long SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB = 128; + private static final int INLINED_SQL_FUNCTIONS_COUNT = 7; @Override protected void createTables() @@ -73,6 +81,7 @@ protected void createTables() createOrders(queryRunner); createOrdersEx(queryRunner); createRegion(queryRunner); + createCustomer(queryRunner); } @Override @@ -91,9 +100,11 @@ protected QueryRunner createQueryRunner() protected QueryRunner createExpectedQueryRunner() throws Exception { - return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder() + QueryRunner queryRunner = PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder() .setAddStorageFormatToPath(true) .build(); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } public static void setupNativeSidecarPlugin(QueryRunner queryRunner) @@ -111,6 +122,7 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner) "sidecar.http-client.max-content-length", SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB + "MB")); queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME); queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of()); + queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin()); } @Test @@ -161,6 +173,7 @@ public void testSetNativeWorkerSessionProperty() @Test public void testShowFunctions() { + int inlinedSQLFunctionsCount = 0; @Language("SQL") String sql = "SHOW FUNCTIONS"; MaterializedResult actualResult = computeActual(sql); List actualRows = actualResult.getMaterializedRows(); @@ -174,11 +187,17 @@ public void testShowFunctions() // function namespace should be present. String fullFunctionName = row.get(5).toString(); - if (Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) { - continue; + if (!Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) { + // If no namespace match found, check if it's an inlined SQL Invoked function. + String language = row.get(9).toString(); + if (language.equalsIgnoreCase("SQL")) { + inlinedSQLFunctionsCount++; + continue; + } + fail(format("No namespace match found for row: %s", row)); } - fail(format("No namespace match found for row: %s", row)); } + assertEquals(inlinedSQLFunctionsCount, INLINED_SQL_FUNCTIONS_COUNT); } @Test @@ -319,7 +338,7 @@ public void testApproxPercentile() public void testInformationSchemaTables() { assertQuery("select lower(table_name) from information_schema.tables " - + "where table_name = 'lineitem' or table_name = 'LINEITEM' "); + + "where table_name = 'lineitem' or table_name = 'LINEITEM' "); } @Test @@ -405,6 +424,121 @@ public void testGeometryQueries() "Error from native plan checker: .SpatialJoinNode no abstract type PlanNode "); } + @Test + public void testRemoveMapCast() + { + Session enableOptimization = Session.builder(getSession()) + .setSystemProperty(REMOVE_MAP_CAST, "true") + .build(); + assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", + "values 0.5, 0.1"); + assertQuery(enableOptimization, "select element_at(feature, key) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", + "values 0.5, 0.1"); + assertQuery(enableOptimization, "select element_at(feature, key) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", + "values 0.5, null"); + assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), cast('2' as varchar)), (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), '4')) t(feature, key)", + "values 0.5, 0.1"); + } + + @Test + public void testOverriddenInlinedSqlInvokedFunctions() + { + // String functions + assertQuery("SELECT trail(comment, cast(nationkey as integer)) FROM nation"); + assertQuery("SELECT name, comment, replace_first(comment, 'iron', 'gold') from nation"); + + // Array functions + assertQuery("SELECT array_intersect(ARRAY['apple', 'banana', 'cherry'], ARRAY['apple', 'mango', 'fig'])"); + assertQuery("SELECT array_frequency(split(comment, '')) from nation"); + assertQuery("SELECT array_duplicates(ARRAY[regionkey]), array_duplicates(ARRAY[comment]) from nation"); + assertQuery("SELECT array_has_duplicates(ARRAY[custkey]) from orders"); + assertQuery("SELECT array_max_by(ARRAY[comment], x -> length(x)) from orders"); + assertQuery("SELECT array_min_by(ARRAY[ROW('USA', 1), ROW('INDIA', 2), ROW('UK', 3)], x -> x[2])"); + assertQuery("SELECT array_sort_desc(map_keys(map_union(quantity_by_linenumber))) FROM orders_ex"); + assertQuery("SELECT remove_nulls(ARRAY[CAST(regionkey AS VARCHAR), comment, NULL]) from nation"); + assertQuery("SELECT array_top_n(ARRAY[CAST(nationkey AS VARCHAR)], 3) from nation"); + assertQuerySucceeds("SELECT array_sort_desc(quantities, x -> abs(x)) FROM orders_ex"); + + // Map functions + assertQuery("SELECT map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 4, 5]))"); + assertQuery("SELECT map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 0, -1]))"); + assertQuery("SELECT name, map_normalize(MAP(ARRAY['regionkey', 'length'], ARRAY[regionkey, length(comment)])) from nation"); + assertQuery("SELECT name, map_remove_null_values(map(ARRAY['region', 'comment', 'nullable'], " + + "ARRAY[CAST(regionkey AS VARCHAR), comment, NULL])) from nation"); + assertQuery("SELECT name, map_key_exists(map(ARRAY['nation', 'comment'], ARRAY[CAST(nationkey AS VARCHAR), comment]), 'comment') from nation"); + assertQuery("SELECT map_keys_by_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 2) from orders"); + assertQuery("SELECT map_top_n(MAP(ARRAY[CAST(nationkey AS VARCHAR)], ARRAY[comment]), 3) from nation"); + assertQuery("SELECT map_top_n_keys(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders"); + assertQuery("SELECT map_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders"); + assertQuery("SELECT all_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> length(k) > 5) from orders"); + assertQuery("SELECT any_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> starts_with(k, 'abc')) from orders"); + assertQuery("SELECT any_values_match(MAP(ARRAY[orderkey], ARRAY[totalprice]), k -> abs(k) > 20) from orders"); + assertQuery("SELECT no_values_match(MAP(ARRAY[orderkey], ARRAY[comment]), k -> length(k) > 2) from orders"); + assertQuery("SELECT no_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> ends_with(k, 'a')) from orders"); + } + + @Test + public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigEnabled() + { + // Array functions + assertQuery("SELECT array_split_into_chunks(split(comment, ''), 2) from nation"); + assertQuery("SELECT array_least_frequent(quantities) from orders_ex"); + assertQuery("SELECT array_least_frequent(split(comment, ''), 5) from nation"); + assertQuerySucceeds("SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from orders"); + + // Map functions + assertQuerySucceeds("SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation"); + assertQuerySucceeds("SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation"); + + Session sessionWithKeyBasedSampling = Session.builder(getSession()) + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") + .build(); + + @Language("SQL") String query = "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey"; + + assertQuery(query, "select cast(60175 as bigint)"); + assertQuery(sessionWithKeyBasedSampling, query, "select cast(16185 as bigint)"); + } + + @Test + public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigDisabled() + { + // When inline_sql_functions is set to false, the below queries should fail as the implementations don't exist on the native worker + Session session = Session.builder(getSession()) + .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true") + .setSystemProperty(INLINE_SQL_FUNCTIONS, "false") + .build(); + + // Array functions + assertQueryFails(session, + "SELECT array_split_into_chunks(split(comment, ''), 2) from nation", + ".*Scalar function name not registered: native.default.array_split_into_chunks.*"); + assertQueryFails(session, + "SELECT array_least_frequent(quantities) from orders_ex", + ".*Scalar function name not registered: native.default.array_least_frequent.*"); + assertQueryFails(session, + "SELECT array_least_frequent(split(comment, ''), 2) from nation", + ".*Scalar function name not registered: native.default.array_least_frequent.*"); + assertQueryFails(session, + "SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from orders", + " Scalar function native\\.default\\.array_top_n not registered with arguments.*", + true); + + // Map functions + assertQueryFails(session, + "SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation", + ".*Scalar function native\\.default\\.map_top_n_values not registered with arguments.*", + true); + assertQueryFails(session, + "SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation", + ".*Scalar function native\\.default\\.map_top_n_keys not registered with arguments.*", + true); + + assertQueryFails(session, + "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey", + ".*Scalar function name not registered: native.default.key_sampling_percent.*"); + } + private String generateRandomTableName() { String tableName = "tmp_presto_" + UUID.randomUUID().toString().replace("-", ""); diff --git a/presto-native-sql-invoked-functions-plugin/pom.xml b/presto-native-sql-invoked-functions-plugin/pom.xml new file mode 100644 index 0000000000000..7d837a28a8bfb --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/pom.xml @@ -0,0 +1,29 @@ + + 4.0.0 + + com.facebook.presto + presto-root + 0.295-SNAPSHOT + + + presto-native-sql-invoked-functions-plugin + Presto Native - Sql invoked functions plugin + presto-plugin + + + ${project.parent.basedir} + + + + + com.facebook.presto + presto-spi + provided + + + com.google.guava + guava + + + diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java new file mode 100644 index 0000000000000..841883d99ae8f --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlParameters; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + +public class NativeArraySqlFunctions +{ + private NativeArraySqlFunctions() {} + + @SqlInvokedScalarFunction(value = "array_split_into_chunks", deterministic = true, calledOnNullInput = false) + @Description("Returns an array of arrays splitting input array into chunks of given length. " + + "If array is not evenly divisible it will split into as many possible chunks and " + + "return the left over elements for the last array. Returns null for null inputs, but not elements.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "sz", type = "int")}) + @SqlType("array(array(T))") + public static String arraySplitIntoChunks() + { + return "RETURN IF(sz <= 0, " + + "fail('Invalid slice size: ' || cast(sz as varchar) || '. Size must be greater than zero.'), " + + "IF(cardinality(input) / sz > 10000, " + + "fail('Cannot split array of size: ' || cast(cardinality(input) as varchar) || ' into more than 10000 parts.'), " + + "transform(" + + "sequence(1, cardinality(input), sz), " + + "x -> slice(input, x, sz))))"; + } + + @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true) + @Description("Determines the least frequent element in the array. If there are multiple elements, the function returns the smallest element") + @TypeParameter("T") + @SqlParameter(name = "input", type = "array(T)") + @SqlType("array") + public static String arrayLeastFrequent() + { + return "RETURN IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, 1), x -> x[2]))"; + } + + @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true) + @Description("Determines the n least frequent element in the array in the ascending order of the elements.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "bigint")}) + @SqlType("array") + public static String arrayNLeastFrequent() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, n), x -> x[2])))"; + } + + @SqlInvokedScalarFunction(value = "array_top_n", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N values of the given map sorted using the provided lambda comparator.") + @TypeParameter("T") + @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "int"), @SqlParameter(name = "f", type = "function(T, T, bigint)")}) + @SqlType("array") + public static String arrayTopNComparator() + { + return "RETURN IF(n < 0, fail('Parameter n: ' || cast(n as varchar) || ' to ARRAY_TOP_N is negative'), SLICE(REVERSE(ARRAY_SORT(input, f)), 1, n))"; + } +} diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java new file mode 100644 index 0000000000000..9eccc84d6d8c8 --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlParameters; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + +public class NativeMapSqlFunctions +{ + private NativeMapSqlFunctions() {} + + @SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N keys of the given map sorting its keys using the provided lambda comparator.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(K, K, bigint)")}) + @SqlType("array") + public static String mapTopNKeysComparator() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(map_keys(input), f)), 1, n))"; + } + + @SqlInvokedScalarFunction(value = "map_top_n_values", deterministic = true, calledOnNullInput = true) + @Description("Returns the top N values of the given map sorted using the provided lambda comparator.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(V, V, bigint)")}) + @SqlType("array") + public static String mapTopNValuesComparator() + { + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(remove_nulls(map_values(input)), f)) || filter(map_values(input), x -> x is null), 1, n))"; + } +} diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java new file mode 100644 index 0000000000000..a710391760714 --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlType; + +public class NativeSimpleSamplingPercent +{ + private NativeSimpleSamplingPercent() {} + + @SqlInvokedScalarFunction(value = "key_sampling_percent", deterministic = true, calledOnNullInput = false) + @Description("Returns a value between 0.0 and 1.0 using the hash of the given input string") + @SqlParameter(name = "input", type = "varchar") + @SqlType("double") + public static String keySamplingPercent() + { + return "return (abs(from_ieee754_64(xxhash64(cast(input as varbinary)))) % 100) / 100. "; + } +} diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java new file mode 100644 index 0000000000000..69d7ff1e78522 --- /dev/null +++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.scalar.sql; + +import com.facebook.presto.spi.Plugin; +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +public class NativeSqlInvokedFunctionsPlugin + implements Plugin +{ + @Override + public Set> getSqlInvokedFunctions() + { + return ImmutableSet.>builder() + .add(NativeArraySqlFunctions.class) + .add(NativeMapSqlFunctions.class) + .add(NativeSimpleSamplingPercent.class) + .build(); + } +} diff --git a/presto-native-tests/pom.xml b/presto-native-tests/pom.xml index e90785cbb6940..7e23670117b1d 100644 --- a/presto-native-tests/pom.xml +++ b/presto-native-tests/pom.xml @@ -192,6 +192,13 @@ units test + + + com.facebook.presto + presto-native-sql-invoked-functions-plugin + ${project.version} + test + diff --git a/presto-plan-checker-router-plugin/pom.xml b/presto-plan-checker-router-plugin/pom.xml index 60e34e4147d19..db1c4b0f1736c 100644 --- a/presto-plan-checker-router-plugin/pom.xml +++ b/presto-plan-checker-router-plugin/pom.xml @@ -223,6 +223,13 @@ presto-hive-metastore test + + + com.facebook.presto + presto-native-sql-invoked-functions-plugin + ${project.version} + test + diff --git a/presto-product-tests/conf/docker/common/compose-commons.sh b/presto-product-tests/conf/docker/common/compose-commons.sh index eae9f18ce9583..5c20783716b60 100644 --- a/presto-product-tests/conf/docker/common/compose-commons.sh +++ b/presto-product-tests/conf/docker/common/compose-commons.sh @@ -39,6 +39,16 @@ if [[ -z "${PRESTO_SERVER_DIR:-}" ]]; then source "${PRODUCT_TESTS_ROOT}/target/classes/presto.env" PRESTO_SERVER_DIR="${PROJECT_ROOT}/presto-server/target/presto-server-${PRESTO_VERSION}/" fi + +# The following plugin results in a function signature conflict when loaded in Java/ sidecar disabled native clusters. +# This plugin is only meant for sidecar enabled native clusters, hence exclude it. +PLUGIN_TO_EXCLUDE="native-sql-invoked-functions-plugin" + +if [[ -d "${PRESTO_SERVER_DIR}/plugin/${PLUGIN_TO_EXCLUDE}" ]]; then + echo "Excluding plugin: $PLUGIN_TO_EXCLUDE" + rm -rf "${PRESTO_SERVER_DIR}/plugin/${PLUGIN_TO_EXCLUDE}" +fi + export_canonical_path PRESTO_SERVER_DIR if [[ -z "${PRESTO_CLI_JAR:-}" ]]; then diff --git a/presto-server/src/main/provisio/presto.xml b/presto-server/src/main/provisio/presto.xml index b14b36a768e69..d15a041c7d1f5 100644 --- a/presto-server/src/main/provisio/presto.xml +++ b/presto-server/src/main/provisio/presto.xml @@ -292,4 +292,10 @@ + + + + + + diff --git a/presto-spark-base/pom.xml b/presto-spark-base/pom.xml index 37dc1d4fb541b..7767219a7ec65 100644 --- a/presto-spark-base/pom.xml +++ b/presto-spark-base/pom.xml @@ -16,7 +16,6 @@ 9.4.55.v20240627 4.12.0 3.9.1 - 2 @@ -55,12 +54,6 @@ provided - - com.facebook.presto - presto-spark-classloader-spark${dep.pos.classloader.module-name.suffix} - provided - - com.facebook.presto presto-client @@ -538,6 +531,25 @@ + + spark2 + + + true + + !spark-version + + + + + + com.facebook.presto + presto-spark-classloader-spark2 + ${project.version} + + + + spark3 @@ -548,11 +560,13 @@ - - 3 - - + + com.facebook.presto + presto-spark-classloader-spark3 + ${project.version} + + com.facebook.presto.spark spark-core diff --git a/presto-spark-classloader-interface/pom.xml b/presto-spark-classloader-interface/pom.xml index 8f93be12818a4..a95e45827c5f1 100644 --- a/presto-spark-classloader-interface/pom.xml +++ b/presto-spark-classloader-interface/pom.xml @@ -13,7 +13,6 @@ ${project.parent.basedir} true - 2 @@ -23,11 +22,6 @@ provided - - com.facebook.presto - presto-spark-classloader-spark${dep.pos.classloader.module-name.suffix} - - com.google.guava guava @@ -40,6 +34,25 @@ + + spark2 + + + true + + !spark-version + + + + + + com.facebook.presto + presto-spark-classloader-spark2 + ${project.version} + + + + spark3 @@ -50,22 +63,12 @@ - - 3 - - - - - - com.facebook.presto.spark - spark-core - 3.4.1-1 - compile - - - - + + com.facebook.presto + presto-spark-classloader-spark3 + ${project.version} + org.scala-lang scala-library