diff --git a/.github/workflows/prestocpp-linux-build-and-unit-test.yml b/.github/workflows/prestocpp-linux-build-and-unit-test.yml
index b28f2292252ea..f3fc4843c3321 100644
--- a/.github/workflows/prestocpp-linux-build-and-unit-test.yml
+++ b/.github/workflows/prestocpp-linux-build-and-unit-test.yml
@@ -370,7 +370,7 @@ jobs:
# Use different Maven options to install.
MAVEN_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError"
run: |
- for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-execution' -am && s=0 && break || s=$? && sleep 10; done; (exit $s)
+ for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-sidecar-plugin' -am && s=0 && break || s=$? && sleep 10; done; (exit $s)
- name: Run presto-native sidecar tests
if: |
diff --git a/CODEOWNERS b/CODEOWNERS
index 0b2831ba7ee31..e4a0e3bb56aed 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -123,6 +123,7 @@ CODEOWNERS @prestodb/team-tsc
#####################################################################
# Presto on Spark module
/presto-spark* @shrinidhijoshi @prestodb/committers
+/presto-native-execution/*/com/facebook/presto/spark/* @shrinidhijoshi @prestodb/committers
#####################################################################
# Presto connectors and plugins
@@ -166,4 +167,3 @@ CODEOWNERS @prestodb/team-tsc
# Presto CI and builds
/.github @czentgr @unidevel @prestodb/committers
/docker @czentgr @unidevel @prestodb/committers
-
diff --git a/pom.xml b/pom.xml
index 15a6f0ec95a20..f73719dacd94a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -90,8 +90,7 @@
6.0.0
17.0.0
3.5.4
-
- 2
+ 2.0.2-6
+ org.codehaus.plexus:plexus-utils
+ com.google.guava:guava
+ com.fasterxml.jackson.core:jackson-annotations
+ com.fasterxml.jackson.core:jackson-core
+ com.fasterxml.jackson.core:jackson-databind
+
+
+
+
+
+
+
+ org.basepom.maven
+ duplicate-finder-maven-plugin
+
+ true
+
+ com.github.benmanes.caffeine.*
+
+ META-INF.versions.9.module-info
+
+ META-INF.versions.11.module-info
+
+ META-INF.versions.9.org.apache.lucene.*
+
+
+
+
+
+
+
+
+
+
spark3
@@ -3173,19 +3238,12 @@
- 3
+ 3.4.1-1
-
-
-
- com.facebook.presto.spark
- spark-core
- 3.4.1-1
- provided
-
-
-
+
+ presto-spark-classloader-spark3
+
diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java
index 40c493fc887c6..d37d658e69204 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java
@@ -67,6 +67,7 @@
import com.facebook.presto.operator.aggregation.DoubleCovarianceAggregation;
import com.facebook.presto.operator.aggregation.DoubleHistogramAggregation;
import com.facebook.presto.operator.aggregation.DoubleRegressionAggregation;
+import com.facebook.presto.operator.aggregation.DoubleRegressionExtendedAggregation;
import com.facebook.presto.operator.aggregation.DoubleSumAggregation;
import com.facebook.presto.operator.aggregation.EntropyAggregation;
import com.facebook.presto.operator.aggregation.GeometricMeanAggregations;
@@ -84,6 +85,7 @@
import com.facebook.presto.operator.aggregation.RealGeometricMeanAggregations;
import com.facebook.presto.operator.aggregation.RealHistogramAggregation;
import com.facebook.presto.operator.aggregation.RealRegressionAggregation;
+import com.facebook.presto.operator.aggregation.RealRegressionExtendedAggregation;
import com.facebook.presto.operator.aggregation.RealSumAggregation;
import com.facebook.presto.operator.aggregation.ReduceAggregationFunction;
import com.facebook.presto.operator.aggregation.SumDataSizeForStats;
@@ -744,7 +746,9 @@ private List extends SqlFunction> getBuiltInFunctions(FunctionsConfig function
.aggregates(DoubleCovarianceAggregation.class)
.aggregates(RealCovarianceAggregation.class)
.aggregates(DoubleRegressionAggregation.class)
+ .aggregates(DoubleRegressionExtendedAggregation.class)
.aggregates(RealRegressionAggregation.class)
+ .aggregates(RealRegressionExtendedAggregation.class)
.aggregates(DoubleCorrelationAggregation.class)
.aggregates(RealCorrelationAggregation.class)
.aggregates(BitwiseOrAggregation.class)
diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java
index 578e782bc8d91..c78186ebb2417 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java
@@ -22,6 +22,7 @@
import com.facebook.presto.operator.aggregation.state.CentralMomentsState;
import com.facebook.presto.operator.aggregation.state.CorrelationState;
import com.facebook.presto.operator.aggregation.state.CovarianceState;
+import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState;
import com.facebook.presto.operator.aggregation.state.RegressionState;
import com.facebook.presto.operator.aggregation.state.VarianceState;
import com.facebook.presto.spi.function.AggregationFunctionImplementation;
@@ -145,9 +146,14 @@ public static double getCorrelation(CorrelationState state)
public static void updateRegressionState(RegressionState state, double x, double y)
{
double oldMeanX = state.getMeanX();
- double oldMeanY = state.getMeanY();
updateCovarianceState(state, x, y);
state.setM2X(state.getM2X() + (x - oldMeanX) * (x - state.getMeanX()));
+ }
+
+ public static void updateExtendedRegressionState(ExtendedRegressionState state, double x, double y)
+ {
+ double oldMeanY = state.getMeanY();
+ updateRegressionState(state, x, y);
state.setM2Y(state.getM2Y() + (y - oldMeanY) * (y - state.getMeanY()));
}
@@ -189,12 +195,12 @@ public static double getRegressionSxy(RegressionState state)
return state.getC2();
}
- public static double getRegressionSyy(RegressionState state)
+ public static double getRegressionSyy(ExtendedRegressionState state)
{
return state.getM2Y();
}
- public static double getRegressionR2(RegressionState state)
+ public static double getRegressionR2(ExtendedRegressionState state)
{
if (state.getM2X() != 0 && state.getM2Y() == 0) {
return 1.0;
@@ -311,10 +317,21 @@ public static void mergeRegressionState(RegressionState state, RegressionState o
long na = state.getCount();
long nb = otherState.getCount();
state.setM2X(state.getM2X() + otherState.getM2X() + na * nb * Math.pow(state.getMeanX() - otherState.getMeanX(), 2) / (double) (na + nb));
- state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb));
updateCovarianceState(state, otherState);
}
+ public static void mergeExtendedRegressionState(ExtendedRegressionState state, ExtendedRegressionState otherState)
+ {
+ if (otherState.getCount() == 0) {
+ return;
+ }
+
+ long na = state.getCount();
+ long nb = otherState.getCount();
+ state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb));
+ mergeRegressionState(state, otherState);
+ }
+
public static String generateAggregationName(String baseName, TypeSignature outputType, List inputTypes)
{
StringBuilder sb = new StringBuilder();
diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java
index 24d1c6e61fcf5..db3ad26ec5d6d 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java
@@ -24,15 +24,8 @@
import com.facebook.presto.spi.function.SqlType;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount;
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2;
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy;
import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeRegressionState;
import static com.facebook.presto.operator.aggregation.AggregationUtils.updateRegressionState;
@@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB
out.appendNull();
}
}
-
- @AggregationFunction("regr_sxy")
- @OutputFunction(StandardTypes.DOUBLE)
- public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionSxy(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- DOUBLE.writeDouble(out, result);
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_sxx")
- @OutputFunction(StandardTypes.DOUBLE)
- public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionSxx(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- DOUBLE.writeDouble(out, result);
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_syy")
- @OutputFunction(StandardTypes.DOUBLE)
- public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionSyy(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- DOUBLE.writeDouble(out, result);
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_r2")
- @OutputFunction(StandardTypes.DOUBLE)
- public static void regrR2(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionR2(state);
- if (Double.isFinite(result)) {
- DOUBLE.writeDouble(out, result);
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_count")
- @OutputFunction(StandardTypes.DOUBLE)
- public static void regrCount(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionCount(state);
- if (Double.isFinite(result) && result > 0) {
- DOUBLE.writeDouble(out, result);
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_avgy")
- @OutputFunction(StandardTypes.DOUBLE)
- public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionAvgy(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- DOUBLE.writeDouble(out, result);
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_avgx")
- @OutputFunction(StandardTypes.DOUBLE)
- public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionAvgx(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- DOUBLE.writeDouble(out, result);
- }
- else {
- out.appendNull();
- }
- }
}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java
new file mode 100644
index 0000000000000..3550cd0936949
--- /dev/null
+++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator.aggregation;
+
+import com.facebook.presto.common.block.BlockBuilder;
+import com.facebook.presto.common.type.StandardTypes;
+import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState;
+import com.facebook.presto.spi.function.AggregationFunction;
+import com.facebook.presto.spi.function.AggregationState;
+import com.facebook.presto.spi.function.CombineFunction;
+import com.facebook.presto.spi.function.InputFunction;
+import com.facebook.presto.spi.function.OutputFunction;
+import com.facebook.presto.spi.function.SqlType;
+
+import static com.facebook.presto.common.type.DoubleType.DOUBLE;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeExtendedRegressionState;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.updateExtendedRegressionState;
+
+@AggregationFunction
+public class DoubleRegressionExtendedAggregation
+{
+ private DoubleRegressionExtendedAggregation() {}
+
+ @InputFunction
+ public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.DOUBLE) double dependentValue, @SqlType(StandardTypes.DOUBLE) double independentValue)
+ {
+ updateExtendedRegressionState(state, independentValue, dependentValue);
+ }
+
+ @CombineFunction
+ public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState)
+ {
+ mergeExtendedRegressionState(state, otherState);
+ }
+
+ @AggregationFunction("regr_sxy")
+ @OutputFunction(StandardTypes.DOUBLE)
+ public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionSxy(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ DOUBLE.writeDouble(out, result);
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_sxx")
+ @OutputFunction(StandardTypes.DOUBLE)
+ public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionSxx(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ DOUBLE.writeDouble(out, result);
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_syy")
+ @OutputFunction(StandardTypes.DOUBLE)
+ public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionSyy(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ DOUBLE.writeDouble(out, result);
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_r2")
+ @OutputFunction(StandardTypes.DOUBLE)
+ public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionR2(state);
+ if (Double.isFinite(result)) {
+ DOUBLE.writeDouble(out, result);
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_count")
+ @OutputFunction(StandardTypes.DOUBLE)
+ public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionCount(state);
+ if (Double.isFinite(result) && result > 0) {
+ DOUBLE.writeDouble(out, result);
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_avgy")
+ @OutputFunction(StandardTypes.DOUBLE)
+ public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionAvgy(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ DOUBLE.writeDouble(out, result);
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_avgx")
+ @OutputFunction(StandardTypes.DOUBLE)
+ public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionAvgx(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ DOUBLE.writeDouble(out, result);
+ }
+ else {
+ out.appendNull();
+ }
+ }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java
index 1fe5d006da1a9..a75222bfa93c4 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java
@@ -24,15 +24,8 @@
import com.facebook.presto.spi.function.SqlType;
import static com.facebook.presto.common.type.RealType.REAL;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount;
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2;
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy;
-import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy;
import static java.lang.Float.floatToRawIntBits;
import static java.lang.Float.intBitsToFloat;
@@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB
out.appendNull();
}
}
-
- @AggregationFunction("regr_sxy")
- @OutputFunction(StandardTypes.REAL)
- public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionSxy(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- REAL.writeLong(out, floatToRawIntBits((float) result));
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_sxx")
- @OutputFunction(StandardTypes.REAL)
- public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionSxx(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- REAL.writeLong(out, floatToRawIntBits((float) result));
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_syy")
- @OutputFunction(StandardTypes.REAL)
- public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionSyy(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- REAL.writeLong(out, floatToRawIntBits((float) result));
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_r2")
- @OutputFunction(StandardTypes.REAL)
- public static void regrR2(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionR2(state);
- if (Double.isFinite(result)) {
- REAL.writeLong(out, floatToRawIntBits((float) result));
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_count")
- @OutputFunction(StandardTypes.REAL)
- public static void regrCount(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionCount(state);
- if (Double.isFinite(result) && result > 0) {
- REAL.writeLong(out, floatToRawIntBits((float) result));
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_avgy")
- @OutputFunction(StandardTypes.REAL)
- public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionAvgy(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- REAL.writeLong(out, floatToRawIntBits((float) result));
- }
- else {
- out.appendNull();
- }
- }
-
- @AggregationFunction("regr_avgx")
- @OutputFunction(StandardTypes.REAL)
- public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out)
- {
- double result = getRegressionAvgx(state);
- double count = getRegressionCount(state);
- if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
- REAL.writeLong(out, floatToRawIntBits((float) result));
- }
- else {
- out.appendNull();
- }
- }
}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java
new file mode 100644
index 0000000000000..2d0335ae9aca6
--- /dev/null
+++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator.aggregation;
+
+import com.facebook.presto.common.block.BlockBuilder;
+import com.facebook.presto.common.type.StandardTypes;
+import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState;
+import com.facebook.presto.spi.function.AggregationFunction;
+import com.facebook.presto.spi.function.AggregationState;
+import com.facebook.presto.spi.function.CombineFunction;
+import com.facebook.presto.spi.function.InputFunction;
+import com.facebook.presto.spi.function.OutputFunction;
+import com.facebook.presto.spi.function.SqlType;
+
+import static com.facebook.presto.common.type.RealType.REAL;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy;
+import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy;
+import static java.lang.Float.floatToRawIntBits;
+import static java.lang.Float.intBitsToFloat;
+
+@AggregationFunction
+public class RealRegressionExtendedAggregation
+{
+ private RealRegressionExtendedAggregation() {}
+
+ @InputFunction
+ public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.REAL) long dependentValue, @SqlType(StandardTypes.REAL) long independentValue)
+ {
+ DoubleRegressionExtendedAggregation.input(state, intBitsToFloat((int) dependentValue), intBitsToFloat((int) independentValue));
+ }
+
+ @CombineFunction
+ public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState)
+ {
+ DoubleRegressionExtendedAggregation.combine(state, otherState);
+ }
+
+ @AggregationFunction("regr_sxy")
+ @OutputFunction(StandardTypes.REAL)
+ public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionSxy(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ REAL.writeLong(out, floatToRawIntBits((float) result));
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_sxx")
+ @OutputFunction(StandardTypes.REAL)
+ public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionSxx(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ REAL.writeLong(out, floatToRawIntBits((float) result));
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_syy")
+ @OutputFunction(StandardTypes.REAL)
+ public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionSyy(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ REAL.writeLong(out, floatToRawIntBits((float) result));
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_r2")
+ @OutputFunction(StandardTypes.REAL)
+ public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionR2(state);
+ if (Double.isFinite(result)) {
+ REAL.writeLong(out, floatToRawIntBits((float) result));
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_count")
+ @OutputFunction(StandardTypes.REAL)
+ public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionCount(state);
+ if (Double.isFinite(result) && result > 0) {
+ REAL.writeLong(out, floatToRawIntBits((float) result));
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_avgy")
+ @OutputFunction(StandardTypes.REAL)
+ public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionAvgy(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ REAL.writeLong(out, floatToRawIntBits((float) result));
+ }
+ else {
+ out.appendNull();
+ }
+ }
+
+ @AggregationFunction("regr_avgx")
+ @OutputFunction(StandardTypes.REAL)
+ public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out)
+ {
+ double result = getRegressionAvgx(state);
+ double count = getRegressionCount(state);
+ if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
+ REAL.writeLong(out, floatToRawIntBits((float) result));
+ }
+ else {
+ out.appendNull();
+ }
+ }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java
new file mode 100644
index 0000000000000..64a9883174158
--- /dev/null
+++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator.aggregation.state;
+
+public interface ExtendedRegressionState
+ extends RegressionState
+{
+ double getM2Y();
+
+ void setM2Y(double value);
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java
index 79837f90c0c11..ae3af6f46dc43 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java
@@ -19,8 +19,4 @@ public interface RegressionState
double getM2X();
void setM2X(double value);
-
- double getM2Y();
-
- void setM2Y(double value);
}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java
index f34359aeba6af..1c7856a62caea 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java
@@ -14,7 +14,6 @@
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
-import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.Varchars;
@@ -55,7 +54,6 @@
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
-import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
import static com.facebook.presto.metadata.CastType.CAST;
import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND;
import static com.facebook.presto.spi.StandardWarningCode.SAMPLED_FIELDS;
@@ -150,7 +148,7 @@ private PlanNode addSamplingFilter(PlanNode tableScanNode, Optional arguments)
- {
- FunctionHandle functionHandle = functionAndTypeManager.lookupFunction(qualifiedObjectName, fromTypes(arguments.stream().map(RowExpression::getType).collect(toImmutableList())));
- return call(String.valueOf(qualifiedObjectName), functionHandle, returnType, arguments);
- }
-
public static CallExpression call(FunctionAndTypeResolver functionAndTypeResolver, String name, Type returnType, RowExpression... arguments)
{
FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(name, fromTypes(Arrays.stream(arguments).map(RowExpression::getType).collect(toImmutableList())));
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java
index a3fe25d7b65fa..337582257bad4 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java
@@ -281,7 +281,7 @@ public boolean isEqualsFunction(FunctionHandle functionHandle)
@Override
public FunctionHandle subscriptFunction(Type baseType, Type indexType)
{
- return functionAndTypeResolver.lookupFunction(SUBSCRIPT.getFunctionName().getObjectName(), fromTypes(baseType, indexType));
+ return functionAndTypeResolver.resolveOperator(SUBSCRIPT, fromTypes(baseType, indexType));
}
@Override
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java
index 3650f8cb36531..59a60b0024842 100644
--- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java
+++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java
@@ -29,7 +29,7 @@ private NativeQueryRunnerUtils() {}
public static Map getNativeWorkerHiveProperties()
{
return ImmutableMap.of("hive.parquet.pushdown-filter-enabled", "true",
- "hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF");
+ "hive.orc-compression-codec", "ZSTD", "hive.storage-format", "DWRF");
}
public static Map getNativeWorkerIcebergProperties()
@@ -59,6 +59,8 @@ public static Map getNativeSidecarProperties()
.put("coordinator-sidecar-enabled", "true")
.put("exclude-invalid-worker-session-properties", "true")
.put("presto.default-namespace", "native.default")
+ // inline-sql-functions is overridden to be true in sidecar enabled native clusters.
+ .put("inline-sql-functions", "true")
.build();
}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java
index 14113b22d4fb3..21ef570381795 100644
--- a/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java
@@ -79,7 +79,6 @@ public static Map getNativeExecutionSparkConfigs()
.put("catalog.config-dir", "/")
.put("task.info-update-interval", "100ms")
.put("spark.initial-partition-count", "1")
- .put("register-test-functions", "true")
.put("native-execution-program-arguments", "--logtostderr=1 --minloglevel=3")
.put("spark.partition-count-auto-tune-enabled", "false");
diff --git a/presto-native-sidecar-plugin/pom.xml b/presto-native-sidecar-plugin/pom.xml
index b2bc40e8f2bd4..d04424440469a 100644
--- a/presto-native-sidecar-plugin/pom.xml
+++ b/presto-native-sidecar-plugin/pom.xml
@@ -260,9 +260,25 @@
+
com.facebook.presto
presto-built-in-worker-function-tools
+ ${project.version}
+
+
+
+ com.facebook.presto
+ presto-native-sql-invoked-functions-plugin
+ ${project.version}
+ test
+
+
+
+ com.facebook.presto
+ presto-sql-invoked-functions-plugin
+ ${project.version}
+ test
diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java
index 776d4920e2f16..c8c7e1123f974 100644
--- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java
+++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java
@@ -13,6 +13,7 @@
*/
package com.facebook.presto.sidecar;
+import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory;
import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory;
import com.facebook.presto.sidecar.typemanager.NativeTypeManagerFactory;
@@ -37,5 +38,6 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
"function-implementation-type", "CPP"));
queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME);
queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of());
+ queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin());
}
}
diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java
index 1c5d05d55b274..fe0b18c24b2fb 100644
--- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java
+++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java
@@ -14,7 +14,10 @@
package com.facebook.presto.sidecar;
import com.facebook.airlift.units.DataSize;
+import com.facebook.presto.Session;
import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils;
+import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin;
+import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin;
import com.facebook.presto.sidecar.functionNamespace.FunctionDefinitionProvider;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionDefinitionProvider;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManager;
@@ -44,8 +47,12 @@
import java.util.stream.Collectors;
import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE;
+import static com.facebook.presto.SystemSessionProperties.INLINE_SQL_FUNCTIONS;
+import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED;
+import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST;
import static com.facebook.presto.common.Utils.checkArgument;
import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createCustomer;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders;
@@ -63,6 +70,7 @@ public class TestNativeSidecarPlugin
private static final String REGEX_FUNCTION_NAMESPACE = "native.default.*";
private static final String REGEX_SESSION_NAMESPACE = "Native Execution only.*";
private static final long SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB = 128;
+ private static final int INLINED_SQL_FUNCTIONS_COUNT = 7;
@Override
protected void createTables()
@@ -73,6 +81,7 @@ protected void createTables()
createOrders(queryRunner);
createOrdersEx(queryRunner);
createRegion(queryRunner);
+ createCustomer(queryRunner);
}
@Override
@@ -91,9 +100,11 @@ protected QueryRunner createQueryRunner()
protected QueryRunner createExpectedQueryRunner()
throws Exception
{
- return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder()
+ QueryRunner queryRunner = PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder()
.setAddStorageFormatToPath(true)
.build();
+ queryRunner.installPlugin(new SqlInvokedFunctionsPlugin());
+ return queryRunner;
}
public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
@@ -111,6 +122,7 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
"sidecar.http-client.max-content-length", SIDECAR_HTTP_CLIENT_MAX_CONTENT_SIZE_MB + "MB"));
queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME);
queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of());
+ queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin());
}
@Test
@@ -161,6 +173,7 @@ public void testSetNativeWorkerSessionProperty()
@Test
public void testShowFunctions()
{
+ int inlinedSQLFunctionsCount = 0;
@Language("SQL") String sql = "SHOW FUNCTIONS";
MaterializedResult actualResult = computeActual(sql);
List actualRows = actualResult.getMaterializedRows();
@@ -174,11 +187,17 @@ public void testShowFunctions()
// function namespace should be present.
String fullFunctionName = row.get(5).toString();
- if (Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) {
- continue;
+ if (!Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) {
+ // If no namespace match found, check if it's an inlined SQL Invoked function.
+ String language = row.get(9).toString();
+ if (language.equalsIgnoreCase("SQL")) {
+ inlinedSQLFunctionsCount++;
+ continue;
+ }
+ fail(format("No namespace match found for row: %s", row));
}
- fail(format("No namespace match found for row: %s", row));
}
+ assertEquals(inlinedSQLFunctionsCount, INLINED_SQL_FUNCTIONS_COUNT);
}
@Test
@@ -319,7 +338,7 @@ public void testApproxPercentile()
public void testInformationSchemaTables()
{
assertQuery("select lower(table_name) from information_schema.tables "
- + "where table_name = 'lineitem' or table_name = 'LINEITEM' ");
+ + "where table_name = 'lineitem' or table_name = 'LINEITEM' ");
}
@Test
@@ -405,6 +424,121 @@ public void testGeometryQueries()
"Error from native plan checker: .SpatialJoinNode no abstract type PlanNode ");
}
+ @Test
+ public void testRemoveMapCast()
+ {
+ Session enableOptimization = Session.builder(getSession())
+ .setSystemProperty(REMOVE_MAP_CAST, "true")
+ .build();
+ assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)",
+ "values 0.5, 0.1");
+ assertQuery(enableOptimization, "select element_at(feature, key) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)",
+ "values 0.5, 0.1");
+ assertQuery(enableOptimization, "select element_at(feature, key) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)",
+ "values 0.5, null");
+ assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), cast('2' as varchar)), (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), '4')) t(feature, key)",
+ "values 0.5, 0.1");
+ }
+
+ @Test
+ public void testOverriddenInlinedSqlInvokedFunctions()
+ {
+ // String functions
+ assertQuery("SELECT trail(comment, cast(nationkey as integer)) FROM nation");
+ assertQuery("SELECT name, comment, replace_first(comment, 'iron', 'gold') from nation");
+
+ // Array functions
+ assertQuery("SELECT array_intersect(ARRAY['apple', 'banana', 'cherry'], ARRAY['apple', 'mango', 'fig'])");
+ assertQuery("SELECT array_frequency(split(comment, '')) from nation");
+ assertQuery("SELECT array_duplicates(ARRAY[regionkey]), array_duplicates(ARRAY[comment]) from nation");
+ assertQuery("SELECT array_has_duplicates(ARRAY[custkey]) from orders");
+ assertQuery("SELECT array_max_by(ARRAY[comment], x -> length(x)) from orders");
+ assertQuery("SELECT array_min_by(ARRAY[ROW('USA', 1), ROW('INDIA', 2), ROW('UK', 3)], x -> x[2])");
+ assertQuery("SELECT array_sort_desc(map_keys(map_union(quantity_by_linenumber))) FROM orders_ex");
+ assertQuery("SELECT remove_nulls(ARRAY[CAST(regionkey AS VARCHAR), comment, NULL]) from nation");
+ assertQuery("SELECT array_top_n(ARRAY[CAST(nationkey AS VARCHAR)], 3) from nation");
+ assertQuerySucceeds("SELECT array_sort_desc(quantities, x -> abs(x)) FROM orders_ex");
+
+ // Map functions
+ assertQuery("SELECT map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 4, 5]))");
+ assertQuery("SELECT map_normalize(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 0, -1]))");
+ assertQuery("SELECT name, map_normalize(MAP(ARRAY['regionkey', 'length'], ARRAY[regionkey, length(comment)])) from nation");
+ assertQuery("SELECT name, map_remove_null_values(map(ARRAY['region', 'comment', 'nullable'], " +
+ "ARRAY[CAST(regionkey AS VARCHAR), comment, NULL])) from nation");
+ assertQuery("SELECT name, map_key_exists(map(ARRAY['nation', 'comment'], ARRAY[CAST(nationkey AS VARCHAR), comment]), 'comment') from nation");
+ assertQuery("SELECT map_keys_by_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 2) from orders");
+ assertQuery("SELECT map_top_n(MAP(ARRAY[CAST(nationkey AS VARCHAR)], ARRAY[comment]), 3) from nation");
+ assertQuery("SELECT map_top_n_keys(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders");
+ assertQuery("SELECT map_top_n_values(MAP(ARRAY[orderkey], ARRAY[custkey]), 3) from orders");
+ assertQuery("SELECT all_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> length(k) > 5) from orders");
+ assertQuery("SELECT any_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> starts_with(k, 'abc')) from orders");
+ assertQuery("SELECT any_values_match(MAP(ARRAY[orderkey], ARRAY[totalprice]), k -> abs(k) > 20) from orders");
+ assertQuery("SELECT no_values_match(MAP(ARRAY[orderkey], ARRAY[comment]), k -> length(k) > 2) from orders");
+ assertQuery("SELECT no_keys_match(MAP(ARRAY[comment], ARRAY[custkey]), k -> ends_with(k, 'a')) from orders");
+ }
+
+ @Test
+ public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigEnabled()
+ {
+ // Array functions
+ assertQuery("SELECT array_split_into_chunks(split(comment, ''), 2) from nation");
+ assertQuery("SELECT array_least_frequent(quantities) from orders_ex");
+ assertQuery("SELECT array_least_frequent(split(comment, ''), 5) from nation");
+ assertQuerySucceeds("SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from orders");
+
+ // Map functions
+ assertQuerySucceeds("SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation");
+ assertQuerySucceeds("SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation");
+
+ Session sessionWithKeyBasedSampling = Session.builder(getSession())
+ .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true")
+ .build();
+
+ @Language("SQL") String query = "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey";
+
+ assertQuery(query, "select cast(60175 as bigint)");
+ assertQuery(sessionWithKeyBasedSampling, query, "select cast(16185 as bigint)");
+ }
+
+ @Test
+ public void testNonOverriddenInlinedSqlInvokedFunctionsWhenConfigDisabled()
+ {
+ // When inline_sql_functions is set to false, the below queries should fail as the implementations don't exist on the native worker
+ Session session = Session.builder(getSession())
+ .setSystemProperty(KEY_BASED_SAMPLING_ENABLED, "true")
+ .setSystemProperty(INLINE_SQL_FUNCTIONS, "false")
+ .build();
+
+ // Array functions
+ assertQueryFails(session,
+ "SELECT array_split_into_chunks(split(comment, ''), 2) from nation",
+ ".*Scalar function name not registered: native.default.array_split_into_chunks.*");
+ assertQueryFails(session,
+ "SELECT array_least_frequent(quantities) from orders_ex",
+ ".*Scalar function name not registered: native.default.array_least_frequent.*");
+ assertQueryFails(session,
+ "SELECT array_least_frequent(split(comment, ''), 2) from nation",
+ ".*Scalar function name not registered: native.default.array_least_frequent.*");
+ assertQueryFails(session,
+ "SELECT array_top_n(ARRAY[orderkey], 25, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from orders",
+ " Scalar function native\\.default\\.array_top_n not registered with arguments.*",
+ true);
+
+ // Map functions
+ assertQueryFails(session,
+ "SELECT map_top_n_values(MAP(ARRAY[comment], ARRAY[nationkey]), 2, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation",
+ ".*Scalar function native\\.default\\.map_top_n_values not registered with arguments.*",
+ true);
+ assertQueryFails(session,
+ "SELECT map_top_n_keys(MAP(ARRAY[regionkey], ARRAY[nationkey]), 5, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) from nation",
+ ".*Scalar function native\\.default\\.map_top_n_keys not registered with arguments.*",
+ true);
+
+ assertQueryFails(session,
+ "select count(1) FROM lineitem l left JOIN orders o ON l.orderkey = o.orderkey JOIN customer c ON o.custkey = c.custkey",
+ ".*Scalar function name not registered: native.default.key_sampling_percent.*");
+ }
+
private String generateRandomTableName()
{
String tableName = "tmp_presto_" + UUID.randomUUID().toString().replace("-", "");
diff --git a/presto-native-sql-invoked-functions-plugin/pom.xml b/presto-native-sql-invoked-functions-plugin/pom.xml
new file mode 100644
index 0000000000000..7d837a28a8bfb
--- /dev/null
+++ b/presto-native-sql-invoked-functions-plugin/pom.xml
@@ -0,0 +1,29 @@
+
+ 4.0.0
+
+ com.facebook.presto
+ presto-root
+ 0.295-SNAPSHOT
+
+
+ presto-native-sql-invoked-functions-plugin
+ Presto Native - Sql invoked functions plugin
+ presto-plugin
+
+
+ ${project.parent.basedir}
+
+
+
+
+ com.facebook.presto
+ presto-spi
+ provided
+
+
+ com.google.guava
+ guava
+
+
+
diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java
new file mode 100644
index 0000000000000..841883d99ae8f
--- /dev/null
+++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeArraySqlFunctions.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.scalar.sql;
+
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
+import com.facebook.presto.spi.function.SqlParameter;
+import com.facebook.presto.spi.function.SqlParameters;
+import com.facebook.presto.spi.function.SqlType;
+import com.facebook.presto.spi.function.TypeParameter;
+
+public class NativeArraySqlFunctions
+{
+ private NativeArraySqlFunctions() {}
+
+ @SqlInvokedScalarFunction(value = "array_split_into_chunks", deterministic = true, calledOnNullInput = false)
+ @Description("Returns an array of arrays splitting input array into chunks of given length. " +
+ "If array is not evenly divisible it will split into as many possible chunks and " +
+ "return the left over elements for the last array. Returns null for null inputs, but not elements.")
+ @TypeParameter("T")
+ @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "sz", type = "int")})
+ @SqlType("array(array(T))")
+ public static String arraySplitIntoChunks()
+ {
+ return "RETURN IF(sz <= 0, " +
+ "fail('Invalid slice size: ' || cast(sz as varchar) || '. Size must be greater than zero.'), " +
+ "IF(cardinality(input) / sz > 10000, " +
+ "fail('Cannot split array of size: ' || cast(cardinality(input) as varchar) || ' into more than 10000 parts.'), " +
+ "transform(" +
+ "sequence(1, cardinality(input), sz), " +
+ "x -> slice(input, x, sz))))";
+ }
+
+ @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true)
+ @Description("Determines the least frequent element in the array. If there are multiple elements, the function returns the smallest element")
+ @TypeParameter("T")
+ @SqlParameter(name = "input", type = "array(T)")
+ @SqlType("array")
+ public static String arrayLeastFrequent()
+ {
+ return "RETURN IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, 1), x -> x[2]))";
+ }
+
+ @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true)
+ @Description("Determines the n least frequent element in the array in the ascending order of the elements.")
+ @TypeParameter("T")
+ @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "bigint")})
+ @SqlType("array")
+ public static String arrayNLeastFrequent()
+ {
+ return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, n), x -> x[2])))";
+ }
+
+ @SqlInvokedScalarFunction(value = "array_top_n", deterministic = true, calledOnNullInput = true)
+ @Description("Returns the top N values of the given map sorted using the provided lambda comparator.")
+ @TypeParameter("T")
+ @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "int"), @SqlParameter(name = "f", type = "function(T, T, bigint)")})
+ @SqlType("array")
+ public static String arrayTopNComparator()
+ {
+ return "RETURN IF(n < 0, fail('Parameter n: ' || cast(n as varchar) || ' to ARRAY_TOP_N is negative'), SLICE(REVERSE(ARRAY_SORT(input, f)), 1, n))";
+ }
+}
diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java
new file mode 100644
index 0000000000000..9eccc84d6d8c8
--- /dev/null
+++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeMapSqlFunctions.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.scalar.sql;
+
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
+import com.facebook.presto.spi.function.SqlParameter;
+import com.facebook.presto.spi.function.SqlParameters;
+import com.facebook.presto.spi.function.SqlType;
+import com.facebook.presto.spi.function.TypeParameter;
+
+public class NativeMapSqlFunctions
+{
+ private NativeMapSqlFunctions() {}
+
+ @SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = true)
+ @Description("Returns the top N keys of the given map sorting its keys using the provided lambda comparator.")
+ @TypeParameter("K")
+ @TypeParameter("V")
+ @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(K, K, bigint)")})
+ @SqlType("array")
+ public static String mapTopNKeysComparator()
+ {
+ return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(map_keys(input), f)), 1, n))";
+ }
+
+ @SqlInvokedScalarFunction(value = "map_top_n_values", deterministic = true, calledOnNullInput = true)
+ @Description("Returns the top N values of the given map sorted using the provided lambda comparator.")
+ @TypeParameter("K")
+ @TypeParameter("V")
+ @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(V, V, bigint)")})
+ @SqlType("array")
+ public static String mapTopNValuesComparator()
+ {
+ return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(remove_nulls(map_values(input)), f)) || filter(map_values(input), x -> x is null), 1, n))";
+ }
+}
diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java
new file mode 100644
index 0000000000000..a710391760714
--- /dev/null
+++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSimpleSamplingPercent.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.scalar.sql;
+
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
+import com.facebook.presto.spi.function.SqlParameter;
+import com.facebook.presto.spi.function.SqlType;
+
+public class NativeSimpleSamplingPercent
+{
+ private NativeSimpleSamplingPercent() {}
+
+ @SqlInvokedScalarFunction(value = "key_sampling_percent", deterministic = true, calledOnNullInput = false)
+ @Description("Returns a value between 0.0 and 1.0 using the hash of the given input string")
+ @SqlParameter(name = "input", type = "varchar")
+ @SqlType("double")
+ public static String keySamplingPercent()
+ {
+ return "return (abs(from_ieee754_64(xxhash64(cast(input as varbinary)))) % 100) / 100. ";
+ }
+}
diff --git a/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java
new file mode 100644
index 0000000000000..69d7ff1e78522
--- /dev/null
+++ b/presto-native-sql-invoked-functions-plugin/src/main/java/com/facebook/presto/scalar/sql/NativeSqlInvokedFunctionsPlugin.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.scalar.sql;
+
+import com.facebook.presto.spi.Plugin;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.Set;
+
+public class NativeSqlInvokedFunctionsPlugin
+ implements Plugin
+{
+ @Override
+ public Set> getSqlInvokedFunctions()
+ {
+ return ImmutableSet.>builder()
+ .add(NativeArraySqlFunctions.class)
+ .add(NativeMapSqlFunctions.class)
+ .add(NativeSimpleSamplingPercent.class)
+ .build();
+ }
+}
diff --git a/presto-native-tests/pom.xml b/presto-native-tests/pom.xml
index e90785cbb6940..7e23670117b1d 100644
--- a/presto-native-tests/pom.xml
+++ b/presto-native-tests/pom.xml
@@ -192,6 +192,13 @@
units
test
+
+
+ com.facebook.presto
+ presto-native-sql-invoked-functions-plugin
+ ${project.version}
+ test
+
diff --git a/presto-plan-checker-router-plugin/pom.xml b/presto-plan-checker-router-plugin/pom.xml
index 60e34e4147d19..db1c4b0f1736c 100644
--- a/presto-plan-checker-router-plugin/pom.xml
+++ b/presto-plan-checker-router-plugin/pom.xml
@@ -223,6 +223,13 @@
presto-hive-metastore
test
+
+
+ com.facebook.presto
+ presto-native-sql-invoked-functions-plugin
+ ${project.version}
+ test
+
diff --git a/presto-product-tests/conf/docker/common/compose-commons.sh b/presto-product-tests/conf/docker/common/compose-commons.sh
index eae9f18ce9583..5c20783716b60 100644
--- a/presto-product-tests/conf/docker/common/compose-commons.sh
+++ b/presto-product-tests/conf/docker/common/compose-commons.sh
@@ -39,6 +39,16 @@ if [[ -z "${PRESTO_SERVER_DIR:-}" ]]; then
source "${PRODUCT_TESTS_ROOT}/target/classes/presto.env"
PRESTO_SERVER_DIR="${PROJECT_ROOT}/presto-server/target/presto-server-${PRESTO_VERSION}/"
fi
+
+# The following plugin results in a function signature conflict when loaded in Java/ sidecar disabled native clusters.
+# This plugin is only meant for sidecar enabled native clusters, hence exclude it.
+PLUGIN_TO_EXCLUDE="native-sql-invoked-functions-plugin"
+
+if [[ -d "${PRESTO_SERVER_DIR}/plugin/${PLUGIN_TO_EXCLUDE}" ]]; then
+ echo "Excluding plugin: $PLUGIN_TO_EXCLUDE"
+ rm -rf "${PRESTO_SERVER_DIR}/plugin/${PLUGIN_TO_EXCLUDE}"
+fi
+
export_canonical_path PRESTO_SERVER_DIR
if [[ -z "${PRESTO_CLI_JAR:-}" ]]; then
diff --git a/presto-server/src/main/provisio/presto.xml b/presto-server/src/main/provisio/presto.xml
index b14b36a768e69..d15a041c7d1f5 100644
--- a/presto-server/src/main/provisio/presto.xml
+++ b/presto-server/src/main/provisio/presto.xml
@@ -292,4 +292,10 @@
+
+
+
+
+
+
diff --git a/presto-spark-base/pom.xml b/presto-spark-base/pom.xml
index 37dc1d4fb541b..7767219a7ec65 100644
--- a/presto-spark-base/pom.xml
+++ b/presto-spark-base/pom.xml
@@ -16,7 +16,6 @@
9.4.55.v20240627
4.12.0
3.9.1
- 2
@@ -55,12 +54,6 @@
provided
-
- com.facebook.presto
- presto-spark-classloader-spark${dep.pos.classloader.module-name.suffix}
- provided
-
-
com.facebook.presto
presto-client
@@ -538,6 +531,25 @@
+
+ spark2
+
+
+ true
+
+ !spark-version
+
+
+
+
+
+ com.facebook.presto
+ presto-spark-classloader-spark2
+ ${project.version}
+
+
+
+
spark3
@@ -548,11 +560,13 @@
-
- 3
-
-
+
+ com.facebook.presto
+ presto-spark-classloader-spark3
+ ${project.version}
+
+
com.facebook.presto.spark
spark-core
diff --git a/presto-spark-classloader-interface/pom.xml b/presto-spark-classloader-interface/pom.xml
index 8f93be12818a4..a95e45827c5f1 100644
--- a/presto-spark-classloader-interface/pom.xml
+++ b/presto-spark-classloader-interface/pom.xml
@@ -13,7 +13,6 @@
${project.parent.basedir}
true
- 2
@@ -23,11 +22,6 @@
provided
-
- com.facebook.presto
- presto-spark-classloader-spark${dep.pos.classloader.module-name.suffix}
-
-
com.google.guava
guava
@@ -40,6 +34,25 @@
+
+ spark2
+
+
+ true
+
+ !spark-version
+
+
+
+
+
+ com.facebook.presto
+ presto-spark-classloader-spark2
+ ${project.version}
+
+
+
+
spark3
@@ -50,22 +63,12 @@
-
- 3
-
-
-
-
-
- com.facebook.presto.spark
- spark-core
- 3.4.1-1
- compile
-
-
-
-
+
+ com.facebook.presto
+ presto-spark-classloader-spark3
+ ${project.version}
+
org.scala-lang
scala-library