Skip to content

Commit 8c7f2a8

Browse files
[presto] Move out M2Y from RegressionState for regr_slope and regr_intercept functions (#25475) (#25748)
Summary: ## Context Currently we don't enforce intermediate/return type are the same in Coordinator and Prestissimo Worker. Velox creates vectors for intermediate/return results based on a plan that comes from Coordinator. Then Prestissimo tries to use those vector and not crash. In practise we had a crash some time ago due to such a mismatch (D74199165). And I added validation to Velox to catch such kind of mismatches early: facebookincubator/velox#13322 But we wasn't able to enable it in prod, because the validation failed for "regr_slope" and "regr_intercept" functions. ## What's changed? In this diff I'm fixing "regr_slope" and "regr_intercept" intermediate types. Basically in Java `AggregationState` for all these functions is the same: ``` AggregationFunction("regr_slope") AggregationFunction("regr_intercept") AggregationFunction("regr_sxy") AggregationFunction("regr_sxx") AggregationFunction("regr_syy") AggregationFunction("regr_r2") AggregationFunction("regr_count") AggregationFunction("regr_avgy") AggregationFunction("regr_avgx") ``` But in Prestissimo the state storage is more optimal: ``` AggregationFunction("regr_slope") AggregationFunction("regr_intercept") ``` These 2 aggregation functions don't have M2Y field. And this is more efficient, because we don't waste memory and CPU on the field, that aren't needed. So I moved M2Y to extended class, the same as it works in Velox: https://github.com/facebookincubator/velox/blob/main/velox/functions/prestosql/aggregates/CovarianceAggregates.cpp?fbclid=IwY2xjawLRTetleHRuA2FlbQIxMQBicmlkETFiT0N3UFR0M2VKOHl6MHRhAR6KRQ1VUQdCkZXzwj14sMQrVZ-R9QBH1utuGJb5U_lyGzDwt8PwV317QRVNJg_aem_-ePxZ-fHO5MNgfUmayVJFA#L326-L337 No major changes, mostly just reorganized the code. ## Test plan I tested `REGR_SLOPE`, `REGR_INTERCEPT` and `REGR_R2` functions since they are heavily used in prod and cover both cases: with and without M2Y field. What my test looked like. For all 3 `REGR_*` functions I found some prod queries, then: 1. Ran them on prev Java build 2. Ran them on new (with this PR) Java build 3. Ran them on prev Prestissimo build 4. Ran them on new (with this PR) Prestissimo build And compared the output results. They all were identical. With this manual test we covered `Coordinator -> Java Worker` and `Coordinator -> Prestissimo Worker` integrations. ## Next steps In this diff I'm trying to apply the same optimization to Java. With this fix, the signatures will become the same in Java and Prestissimo and we will be able to enable the validation Differential Revision: D77625566 == NO RELEASE NOTES ==
1 parent 9c5004f commit 8c7f2a8

File tree

8 files changed

+345
-214
lines changed

8 files changed

+345
-214
lines changed

presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
import com.facebook.presto.operator.aggregation.DoubleCovarianceAggregation;
6868
import com.facebook.presto.operator.aggregation.DoubleHistogramAggregation;
6969
import com.facebook.presto.operator.aggregation.DoubleRegressionAggregation;
70+
import com.facebook.presto.operator.aggregation.DoubleRegressionExtendedAggregation;
7071
import com.facebook.presto.operator.aggregation.DoubleSumAggregation;
7172
import com.facebook.presto.operator.aggregation.EntropyAggregation;
7273
import com.facebook.presto.operator.aggregation.GeometricMeanAggregations;
@@ -84,6 +85,7 @@
8485
import com.facebook.presto.operator.aggregation.RealGeometricMeanAggregations;
8586
import com.facebook.presto.operator.aggregation.RealHistogramAggregation;
8687
import com.facebook.presto.operator.aggregation.RealRegressionAggregation;
88+
import com.facebook.presto.operator.aggregation.RealRegressionExtendedAggregation;
8789
import com.facebook.presto.operator.aggregation.RealSumAggregation;
8890
import com.facebook.presto.operator.aggregation.ReduceAggregationFunction;
8991
import com.facebook.presto.operator.aggregation.SumDataSizeForStats;
@@ -744,7 +746,9 @@ private List<? extends SqlFunction> getBuiltInFunctions(FunctionsConfig function
744746
.aggregates(DoubleCovarianceAggregation.class)
745747
.aggregates(RealCovarianceAggregation.class)
746748
.aggregates(DoubleRegressionAggregation.class)
749+
.aggregates(DoubleRegressionExtendedAggregation.class)
747750
.aggregates(RealRegressionAggregation.class)
751+
.aggregates(RealRegressionExtendedAggregation.class)
748752
.aggregates(DoubleCorrelationAggregation.class)
749753
.aggregates(RealCorrelationAggregation.class)
750754
.aggregates(BitwiseOrAggregation.class)

presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.facebook.presto.operator.aggregation.state.CentralMomentsState;
2323
import com.facebook.presto.operator.aggregation.state.CorrelationState;
2424
import com.facebook.presto.operator.aggregation.state.CovarianceState;
25+
import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState;
2526
import com.facebook.presto.operator.aggregation.state.RegressionState;
2627
import com.facebook.presto.operator.aggregation.state.VarianceState;
2728
import com.facebook.presto.spi.function.AggregationFunctionImplementation;
@@ -145,9 +146,14 @@ public static double getCorrelation(CorrelationState state)
145146
public static void updateRegressionState(RegressionState state, double x, double y)
146147
{
147148
double oldMeanX = state.getMeanX();
148-
double oldMeanY = state.getMeanY();
149149
updateCovarianceState(state, x, y);
150150
state.setM2X(state.getM2X() + (x - oldMeanX) * (x - state.getMeanX()));
151+
}
152+
153+
public static void updateExtendedRegressionState(ExtendedRegressionState state, double x, double y)
154+
{
155+
double oldMeanY = state.getMeanY();
156+
updateRegressionState(state, x, y);
151157
state.setM2Y(state.getM2Y() + (y - oldMeanY) * (y - state.getMeanY()));
152158
}
153159

@@ -189,12 +195,12 @@ public static double getRegressionSxy(RegressionState state)
189195
return state.getC2();
190196
}
191197

192-
public static double getRegressionSyy(RegressionState state)
198+
public static double getRegressionSyy(ExtendedRegressionState state)
193199
{
194200
return state.getM2Y();
195201
}
196202

197-
public static double getRegressionR2(RegressionState state)
203+
public static double getRegressionR2(ExtendedRegressionState state)
198204
{
199205
if (state.getM2X() != 0 && state.getM2Y() == 0) {
200206
return 1.0;
@@ -311,10 +317,21 @@ public static void mergeRegressionState(RegressionState state, RegressionState o
311317
long na = state.getCount();
312318
long nb = otherState.getCount();
313319
state.setM2X(state.getM2X() + otherState.getM2X() + na * nb * Math.pow(state.getMeanX() - otherState.getMeanX(), 2) / (double) (na + nb));
314-
state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb));
315320
updateCovarianceState(state, otherState);
316321
}
317322

323+
public static void mergeExtendedRegressionState(ExtendedRegressionState state, ExtendedRegressionState otherState)
324+
{
325+
if (otherState.getCount() == 0) {
326+
return;
327+
}
328+
329+
long na = state.getCount();
330+
long nb = otherState.getCount();
331+
state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb));
332+
mergeRegressionState(state, otherState);
333+
}
334+
318335
public static String generateAggregationName(String baseName, TypeSignature outputType, List<TypeSignature> inputTypes)
319336
{
320337
StringBuilder sb = new StringBuilder();

presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java

Lines changed: 0 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,8 @@
2424
import com.facebook.presto.spi.function.SqlType;
2525

2626
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
27-
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx;
28-
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy;
29-
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount;
3027
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept;
31-
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2;
3228
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope;
33-
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx;
34-
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy;
35-
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy;
3629
import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeRegressionState;
3730
import static com.facebook.presto.operator.aggregation.AggregationUtils.updateRegressionState;
3831

@@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB
7871
out.appendNull();
7972
}
8073
}
81-
82-
@AggregationFunction("regr_sxy")
83-
@OutputFunction(StandardTypes.DOUBLE)
84-
public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out)
85-
{
86-
double result = getRegressionSxy(state);
87-
double count = getRegressionCount(state);
88-
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
89-
DOUBLE.writeDouble(out, result);
90-
}
91-
else {
92-
out.appendNull();
93-
}
94-
}
95-
96-
@AggregationFunction("regr_sxx")
97-
@OutputFunction(StandardTypes.DOUBLE)
98-
public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out)
99-
{
100-
double result = getRegressionSxx(state);
101-
double count = getRegressionCount(state);
102-
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
103-
DOUBLE.writeDouble(out, result);
104-
}
105-
else {
106-
out.appendNull();
107-
}
108-
}
109-
110-
@AggregationFunction("regr_syy")
111-
@OutputFunction(StandardTypes.DOUBLE)
112-
public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out)
113-
{
114-
double result = getRegressionSyy(state);
115-
double count = getRegressionCount(state);
116-
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
117-
DOUBLE.writeDouble(out, result);
118-
}
119-
else {
120-
out.appendNull();
121-
}
122-
}
123-
124-
@AggregationFunction("regr_r2")
125-
@OutputFunction(StandardTypes.DOUBLE)
126-
public static void regrR2(@AggregationState RegressionState state, BlockBuilder out)
127-
{
128-
double result = getRegressionR2(state);
129-
if (Double.isFinite(result)) {
130-
DOUBLE.writeDouble(out, result);
131-
}
132-
else {
133-
out.appendNull();
134-
}
135-
}
136-
137-
@AggregationFunction("regr_count")
138-
@OutputFunction(StandardTypes.DOUBLE)
139-
public static void regrCount(@AggregationState RegressionState state, BlockBuilder out)
140-
{
141-
double result = getRegressionCount(state);
142-
if (Double.isFinite(result) && result > 0) {
143-
DOUBLE.writeDouble(out, result);
144-
}
145-
else {
146-
out.appendNull();
147-
}
148-
}
149-
150-
@AggregationFunction("regr_avgy")
151-
@OutputFunction(StandardTypes.DOUBLE)
152-
public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out)
153-
{
154-
double result = getRegressionAvgy(state);
155-
double count = getRegressionCount(state);
156-
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
157-
DOUBLE.writeDouble(out, result);
158-
}
159-
else {
160-
out.appendNull();
161-
}
162-
}
163-
164-
@AggregationFunction("regr_avgx")
165-
@OutputFunction(StandardTypes.DOUBLE)
166-
public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out)
167-
{
168-
double result = getRegressionAvgx(state);
169-
double count = getRegressionCount(state);
170-
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
171-
DOUBLE.writeDouble(out, result);
172-
}
173-
else {
174-
out.appendNull();
175-
}
176-
}
17774
}
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.operator.aggregation;
15+
16+
import com.facebook.presto.common.block.BlockBuilder;
17+
import com.facebook.presto.common.type.StandardTypes;
18+
import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState;
19+
import com.facebook.presto.spi.function.AggregationFunction;
20+
import com.facebook.presto.spi.function.AggregationState;
21+
import com.facebook.presto.spi.function.CombineFunction;
22+
import com.facebook.presto.spi.function.InputFunction;
23+
import com.facebook.presto.spi.function.OutputFunction;
24+
import com.facebook.presto.spi.function.SqlType;
25+
26+
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
27+
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx;
28+
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy;
29+
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount;
30+
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2;
31+
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx;
32+
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy;
33+
import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy;
34+
import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeExtendedRegressionState;
35+
import static com.facebook.presto.operator.aggregation.AggregationUtils.updateExtendedRegressionState;
36+
37+
@AggregationFunction
38+
public class DoubleRegressionExtendedAggregation
39+
{
40+
private DoubleRegressionExtendedAggregation() {}
41+
42+
@InputFunction
43+
public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.DOUBLE) double dependentValue, @SqlType(StandardTypes.DOUBLE) double independentValue)
44+
{
45+
updateExtendedRegressionState(state, independentValue, dependentValue);
46+
}
47+
48+
@CombineFunction
49+
public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState)
50+
{
51+
mergeExtendedRegressionState(state, otherState);
52+
}
53+
54+
@AggregationFunction("regr_sxy")
55+
@OutputFunction(StandardTypes.DOUBLE)
56+
public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out)
57+
{
58+
double result = getRegressionSxy(state);
59+
double count = getRegressionCount(state);
60+
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
61+
DOUBLE.writeDouble(out, result);
62+
}
63+
else {
64+
out.appendNull();
65+
}
66+
}
67+
68+
@AggregationFunction("regr_sxx")
69+
@OutputFunction(StandardTypes.DOUBLE)
70+
public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out)
71+
{
72+
double result = getRegressionSxx(state);
73+
double count = getRegressionCount(state);
74+
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
75+
DOUBLE.writeDouble(out, result);
76+
}
77+
else {
78+
out.appendNull();
79+
}
80+
}
81+
82+
@AggregationFunction("regr_syy")
83+
@OutputFunction(StandardTypes.DOUBLE)
84+
public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out)
85+
{
86+
double result = getRegressionSyy(state);
87+
double count = getRegressionCount(state);
88+
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
89+
DOUBLE.writeDouble(out, result);
90+
}
91+
else {
92+
out.appendNull();
93+
}
94+
}
95+
96+
@AggregationFunction("regr_r2")
97+
@OutputFunction(StandardTypes.DOUBLE)
98+
public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out)
99+
{
100+
double result = getRegressionR2(state);
101+
if (Double.isFinite(result)) {
102+
DOUBLE.writeDouble(out, result);
103+
}
104+
else {
105+
out.appendNull();
106+
}
107+
}
108+
109+
@AggregationFunction("regr_count")
110+
@OutputFunction(StandardTypes.DOUBLE)
111+
public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out)
112+
{
113+
double result = getRegressionCount(state);
114+
if (Double.isFinite(result) && result > 0) {
115+
DOUBLE.writeDouble(out, result);
116+
}
117+
else {
118+
out.appendNull();
119+
}
120+
}
121+
122+
@AggregationFunction("regr_avgy")
123+
@OutputFunction(StandardTypes.DOUBLE)
124+
public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out)
125+
{
126+
double result = getRegressionAvgy(state);
127+
double count = getRegressionCount(state);
128+
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
129+
DOUBLE.writeDouble(out, result);
130+
}
131+
else {
132+
out.appendNull();
133+
}
134+
}
135+
136+
@AggregationFunction("regr_avgx")
137+
@OutputFunction(StandardTypes.DOUBLE)
138+
public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out)
139+
{
140+
double result = getRegressionAvgx(state);
141+
double count = getRegressionCount(state);
142+
if (Double.isFinite(result) && Double.isFinite(count) && count > 0) {
143+
DOUBLE.writeDouble(out, result);
144+
}
145+
else {
146+
out.appendNull();
147+
}
148+
}
149+
}

0 commit comments

Comments
 (0)