Skip to content

Commit e74ea4b

Browse files
JoverZhangwyxxxcat
andauthored
[Enhancement][doris-future] Unify regr_sxx/syy/sxy on AggregateFunctionRegrData (#59224)
### What problem does this PR solve? Issue Number: close #38977 Problem Summary: This PR migrates regr_sxx/syy/sxy onto the shared Moment(AggregateFunctionRegrData) introduced in #55940. The original implementation and tests were done in #39187 by @wyxxxcat. This PR builds on top of that work, refactoring it to reuse the same state and merge logic. --------- Co-authored-by: wyxxxcat <[email protected]>
1 parent c7b7ad7 commit e74ea4b

File tree

14 files changed

+768
-30
lines changed

14 files changed

+768
-30
lines changed

be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,8 @@ void register_aggregate_function_regr_union(AggregateFunctionSimpleFactory& fact
6060
factory.register_function_both("regr_slope", create_aggregate_function_regr<RegrSlopeFunc>);
6161
factory.register_function_both("regr_intercept",
6262
create_aggregate_function_regr<RegrInterceptFunc>);
63+
factory.register_function_both("regr_sxx", create_aggregate_function_regr<RegrSxxFunc>);
64+
factory.register_function_both("regr_syy", create_aggregate_function_regr<RegrSyyFunc>);
65+
factory.register_function_both("regr_sxy", create_aggregate_function_regr<RegrSxyFunc>);
6366
}
6467
} // namespace doris::vectorized

be/src/vec/aggregate_functions/aggregate_function_regr_union.h

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,48 @@ struct RegrInterceptFunc : AggregateFunctionRegrData<T, true, 2, 2> {
267267
}
268268
};
269269

270+
template <PrimitiveType T>
271+
struct RegrSxxFunc : AggregateFunctionRegrData<T, false, 2, 0> {
272+
static constexpr const char* name = "regr_sxx";
273+
274+
Float64 get_result() const {
275+
if (this->n < 1) {
276+
return std::numeric_limits<Float64>::quiet_NaN();
277+
}
278+
return this->sxx();
279+
}
280+
};
281+
282+
template <PrimitiveType T>
283+
struct RegrSyyFunc : AggregateFunctionRegrData<T, false, 0, 2> {
284+
static constexpr const char* name = "regr_syy";
285+
286+
Float64 get_result() const {
287+
if (this->n < 1) {
288+
return std::numeric_limits<Float64>::quiet_NaN();
289+
}
290+
return this->syy();
291+
}
292+
};
293+
294+
template <PrimitiveType T>
295+
struct RegrSxyFunc : AggregateFunctionRegrData<T, true, 1, 1> {
296+
static constexpr const char* name = "regr_sxy";
297+
298+
Float64 get_result() const {
299+
if (this->n < 1) {
300+
return std::numeric_limits<Float64>::quiet_NaN();
301+
}
302+
return this->sxy();
303+
}
304+
};
305+
270306
template <typename RegrFunc, bool y_nullable, bool x_nullable>
271307
class AggregateFunctionRegrSimple
272308
: public IAggregateFunctionDataHelper<
273309
RegrFunc, AggregateFunctionRegrSimple<RegrFunc, y_nullable, x_nullable>> {
274310
public:
275-
using XInputCol = typename PrimitiveTypeTraits<RegrFunc::Type>::ColumnType;
276-
using YInputCol = XInputCol;
311+
using InputCol = typename PrimitiveTypeTraits<RegrFunc::Type>::ColumnType;
277312
using ResultCol = ColumnFloat64;
278313

279314
explicit AggregateFunctionRegrSimple(const DataTypes& argument_types_)
@@ -291,39 +326,20 @@ class AggregateFunctionRegrSimple
291326

292327
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
293328
Arena&) const override {
294-
bool y_null = false;
295-
bool x_null = false;
296-
const YInputCol* y_nested_column = nullptr;
297-
const XInputCol* x_nested_column = nullptr;
298-
329+
const auto* y_col = nested_or_null<y_nullable>(columns[0], row_num);
299330
if constexpr (y_nullable) {
300-
const auto& y_column_nullable =
301-
assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);
302-
y_null = y_column_nullable.is_null_at(row_num);
303-
y_nested_column = assert_cast<const YInputCol*, TypeCheckOnRelease::DISABLE>(
304-
y_column_nullable.get_nested_column_ptr().get());
305-
} else {
306-
y_nested_column = assert_cast<const YInputCol*, TypeCheckOnRelease::DISABLE>(
307-
(*columns[0]).get_ptr().get());
331+
if (y_col == nullptr) {
332+
return;
333+
}
308334
}
309-
335+
const auto* x_col = nested_or_null<x_nullable>(columns[1], row_num);
310336
if constexpr (x_nullable) {
311-
const auto& x_column_nullable =
312-
assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[1]);
313-
x_null = x_column_nullable.is_null_at(row_num);
314-
x_nested_column = assert_cast<const XInputCol*, TypeCheckOnRelease::DISABLE>(
315-
x_column_nullable.get_nested_column_ptr().get());
316-
} else {
317-
x_nested_column = assert_cast<const XInputCol*, TypeCheckOnRelease::DISABLE>(
318-
(*columns[1]).get_ptr().get());
319-
}
320-
321-
if (x_null || y_null) {
322-
return;
337+
if (x_col == nullptr) {
338+
return;
339+
}
323340
}
324341

325-
this->data(place).add(y_nested_column->get_data()[row_num],
326-
x_nested_column->get_data()[row_num]);
342+
this->data(place).add(y_col->get_data()[row_num], x_col->get_data()[row_num]);
327343
}
328344

329345
void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
@@ -355,6 +371,21 @@ class AggregateFunctionRegrSimple
355371
dst_column.get_data().push_back(result);
356372
}
357373
}
374+
375+
private:
376+
template <bool Nullable>
377+
static ALWAYS_INLINE const InputCol* nested_or_null(const IColumn* col, ssize_t row_num) {
378+
if constexpr (Nullable) {
379+
const auto& c = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*col);
380+
if (c.is_null_at(row_num)) {
381+
return nullptr;
382+
}
383+
return assert_cast<const InputCol*, TypeCheckOnRelease::DISABLE>(
384+
c.get_nested_column_ptr().get());
385+
} else {
386+
return assert_cast<const InputCol*, TypeCheckOnRelease::DISABLE>(col->get_ptr().get());
387+
}
388+
}
358389
};
359390
} // namespace doris::vectorized
360391

fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@
7777
import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
7878
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrIntercept;
7979
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSlope;
80+
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSxx;
81+
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSxy;
82+
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSyy;
8083
import org.apache.doris.nereids.trees.expressions.functions.agg.Retention;
8184
import org.apache.doris.nereids.trees.expressions.functions.agg.Sem;
8285
import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount;
@@ -173,6 +176,9 @@ private BuiltinAggregateFunctions() {
173176
agg(QuantileUnion.class, "quantile_union"),
174177
agg(RegrIntercept.class, "regr_intercept"),
175178
agg(RegrSlope.class, "regr_slope"),
179+
agg(RegrSxx.class, "regr_sxx"),
180+
agg(RegrSxy.class, "regr_sxy"),
181+
agg(RegrSyy.class, "regr_syy"),
176182
agg(Retention.class, "retention"),
177183
agg(Sem.class, "sem"),
178184
agg(SequenceCount.class, "sequence_count"),

fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ public boolean isNullResultWithOneNullParamFunctions(String funcName) {
104104

105105
public static final String REGR_SLOPE = "regr_slope";
106106

107+
public static final String REGR_SXX = "regr_sxx";
108+
109+
public static final String REGR_SXY = "regr_sxy";
110+
111+
public static final String REGR_SYY = "regr_syy";
112+
107113
public static final String SEQUENCE_COUNT = "sequence_count";
108114

109115
public static final String GROUP_ARRAY_INTERSECT = "group_array_intersect";
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.trees.expressions.functions.agg;
19+
20+
import org.apache.doris.catalog.FunctionSignature;
21+
import org.apache.doris.catalog.Type;
22+
import org.apache.doris.nereids.exceptions.AnalysisException;
23+
import org.apache.doris.nereids.trees.expressions.Expression;
24+
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
25+
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
26+
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
27+
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
28+
import org.apache.doris.nereids.types.DataType;
29+
import org.apache.doris.nereids.types.DoubleType;
30+
31+
import com.google.common.base.Preconditions;
32+
import com.google.common.collect.ImmutableList;
33+
34+
import java.util.List;
35+
36+
/** regr_sxx agg function. */
37+
public class RegrSxx extends AggregateFunction
38+
implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
39+
40+
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
41+
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE));
42+
43+
public RegrSxx(Expression arg0, Expression arg1) {
44+
this(false, arg0, arg1);
45+
}
46+
47+
public RegrSxx(boolean distinct, Expression arg0, Expression arg1) {
48+
super("regr_sxx", distinct, arg0, arg1);
49+
}
50+
51+
public RegrSxx(AggregateFunctionParams functionParams) {
52+
super(functionParams);
53+
}
54+
55+
@Override
56+
public void checkLegalityBeforeTypeCoercion() {
57+
DataType yType = left().getDataType();
58+
DataType xType = right().getDataType();
59+
if (yType.isOnlyMetricType() || xType.isOnlyMetricType()) {
60+
throw new AnalysisException(Type.OnlyMetricTypeErrorMsg);
61+
}
62+
if (!yType.isNumericType() && !yType.isNullType()) {
63+
throw new AnalysisException("regr_sxx requires numeric for first parameter: " + toSql());
64+
}
65+
if (!xType.isNumericType() && !xType.isNullType()) {
66+
throw new AnalysisException("regr_sxx requires numeric for second parameter: " + toSql());
67+
}
68+
}
69+
70+
@Override
71+
public RegrSxx withDistinctAndChildren(boolean distinct, List<Expression> children) {
72+
Preconditions.checkArgument(children.size() == 2);
73+
return new RegrSxx(getFunctionParams(distinct, children));
74+
}
75+
76+
@Override
77+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
78+
return visitor.visitRegrSxx(this, context);
79+
}
80+
81+
@Override
82+
public List<FunctionSignature> getSignatures() {
83+
return SIGNATURES;
84+
}
85+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.trees.expressions.functions.agg;
19+
20+
import org.apache.doris.catalog.FunctionSignature;
21+
import org.apache.doris.catalog.Type;
22+
import org.apache.doris.nereids.exceptions.AnalysisException;
23+
import org.apache.doris.nereids.trees.expressions.Expression;
24+
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
25+
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
26+
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
27+
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
28+
import org.apache.doris.nereids.types.DataType;
29+
import org.apache.doris.nereids.types.DoubleType;
30+
31+
import com.google.common.base.Preconditions;
32+
import com.google.common.collect.ImmutableList;
33+
34+
import java.util.List;
35+
36+
/** regr_sxy agg function. */
37+
public class RegrSxy extends AggregateFunction
38+
implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
39+
40+
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
41+
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE));
42+
43+
public RegrSxy(Expression arg0, Expression arg1) {
44+
this(false, arg0, arg1);
45+
}
46+
47+
public RegrSxy(boolean distinct, Expression arg0, Expression arg1) {
48+
super("regr_sxy", distinct, arg0, arg1);
49+
}
50+
51+
public RegrSxy(AggregateFunctionParams functionParams) {
52+
super(functionParams);
53+
}
54+
55+
@Override
56+
public void checkLegalityBeforeTypeCoercion() {
57+
DataType yType = left().getDataType();
58+
DataType xType = right().getDataType();
59+
if (yType.isOnlyMetricType() || xType.isOnlyMetricType()) {
60+
throw new AnalysisException(Type.OnlyMetricTypeErrorMsg);
61+
}
62+
if (!yType.isNumericType() && !yType.isNullType()) {
63+
throw new AnalysisException("regr_sxy requires numeric for first parameter: " + toSql());
64+
}
65+
if (!xType.isNumericType() && !xType.isNullType()) {
66+
throw new AnalysisException("regr_sxy requires numeric for second parameter: " + toSql());
67+
}
68+
}
69+
70+
@Override
71+
public RegrSxy withDistinctAndChildren(boolean distinct, List<Expression> children) {
72+
Preconditions.checkArgument(children.size() == 2);
73+
return new RegrSxy(getFunctionParams(distinct, children));
74+
}
75+
76+
@Override
77+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
78+
return visitor.visitRegrSxy(this, context);
79+
}
80+
81+
@Override
82+
public List<FunctionSignature> getSignatures() {
83+
return SIGNATURES;
84+
}
85+
}

0 commit comments

Comments
 (0)