Skip to content

Commit 9ff34e2

Browse files
authored
Keep aggregation in Calcite consistent with current PPL behavior (#3405)
* Keep aggregation in Calcite consistent with current PPL behavior Signed-off-by: Lantao Jin <ltjin@amazon.com> * remove unrelated code Signed-off-by: Lantao Jin <ltjin@amazon.com> * revert some code Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix issue 3404 Signed-off-by: Lantao Jin <ltjin@amazon.com> * add more tests Signed-off-by: Lantao Jin <ltjin@amazon.com> * address comments Signed-off-by: Lantao Jin <ltjin@amazon.com> * add more tests Signed-off-by: Lantao Jin <ltjin@amazon.com> --------- Signed-off-by: Lantao Jin <ltjin@amazon.com>
1 parent b7f6cf7 commit 9ff34e2

30 files changed

+1297
-255
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,22 @@
2525
import org.apache.calcite.rel.core.JoinRelType;
2626
import org.apache.calcite.rex.RexCall;
2727
import org.apache.calcite.rex.RexCorrelVariable;
28+
import org.apache.calcite.rex.RexInputRef;
2829
import org.apache.calcite.rex.RexLiteral;
2930
import org.apache.calcite.rex.RexNode;
31+
import org.apache.calcite.sql.SqlKind;
3032
import org.apache.calcite.tools.RelBuilder;
3133
import org.apache.calcite.tools.RelBuilder.AggCall;
3234
import org.apache.calcite.util.Holder;
3335
import org.checkerframework.checker.nullness.qual.Nullable;
3436
import org.opensearch.sql.ast.AbstractNodeVisitor;
3537
import org.opensearch.sql.ast.Node;
38+
import org.opensearch.sql.ast.expression.Alias;
3639
import org.opensearch.sql.ast.expression.AllFields;
3740
import org.opensearch.sql.ast.expression.Argument;
3841
import org.opensearch.sql.ast.expression.Field;
3942
import org.opensearch.sql.ast.expression.Let;
43+
import org.opensearch.sql.ast.expression.Map;
4044
import org.opensearch.sql.ast.expression.QualifiedName;
4145
import org.opensearch.sql.ast.expression.UnresolvedExpression;
4246
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
@@ -48,10 +52,12 @@
4852
import org.opensearch.sql.ast.tree.Lookup;
4953
import org.opensearch.sql.ast.tree.Project;
5054
import org.opensearch.sql.ast.tree.Relation;
55+
import org.opensearch.sql.ast.tree.Rename;
5156
import org.opensearch.sql.ast.tree.Sort;
5257
import org.opensearch.sql.ast.tree.SubqueryAlias;
5358
import org.opensearch.sql.ast.tree.UnresolvedPlan;
5459
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
60+
import org.opensearch.sql.exception.SemanticCheckException;
5561

5662
public class CalciteRelNodeVisitor extends AbstractNodeVisitor<RelNode, CalcitePlanContext> {
5763

@@ -146,6 +152,30 @@ public RelNode visitProject(Project node, CalcitePlanContext context) {
146152
return context.relBuilder.peek();
147153
}
148154

155+
@Override
156+
public RelNode visitRename(Rename node, CalcitePlanContext context) {
157+
visitChildren(node, context);
158+
List<String> originalNames = context.relBuilder.peek().getRowType().getFieldNames();
159+
List<String> newNames = new ArrayList<>(originalNames);
160+
for (Map renameMap : node.getRenameList()) {
161+
if (renameMap.getTarget() instanceof Field t) {
162+
String newName = t.getField().toString();
163+
RexNode check = rexVisitor.analyze(renameMap.getOrigin(), context);
164+
if (check instanceof RexInputRef ref) {
165+
newNames.set(ref.getIndex(), newName);
166+
} else {
167+
throw new SemanticCheckException(
168+
String.format("the original field %s cannot be resolved", renameMap.getOrigin()));
169+
}
170+
} else {
171+
throw new SemanticCheckException(
172+
String.format("the target expected to be field, but is %s", renameMap.getTarget()));
173+
}
174+
}
175+
context.relBuilder.rename(newNames);
176+
return context.relBuilder.peek();
177+
}
178+
149179
@Override
150180
public RelNode visitSort(Sort node, CalcitePlanContext context) {
151181
visitChildren(node, context);
@@ -256,21 +286,84 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
256286
node.getAggExprList().stream()
257287
.map(expr -> aggVisitor.analyze(expr, context))
258288
.collect(Collectors.toList());
259-
List<RexNode> groupByList =
260-
node.getGroupExprList().stream()
261-
.map(expr -> rexVisitor.analyze(expr, context))
262-
.collect(Collectors.toList());
263-
289+
// The span column is always the first column in the result, regardless of
290+
// whether span appears first or last in the query
291+
List<RexNode> groupByList = new ArrayList<>();
264292
UnresolvedExpression span = node.getSpan();
265293
if (!Objects.isNull(span)) {
266294
RexNode spanRex = rexVisitor.analyze(span, context);
267295
groupByList.add(spanRex);
268296
// add span's group alias field (most recent added expression)
269297
}
298+
groupByList.addAll(
299+
node.getGroupExprList().stream().map(expr -> rexVisitor.analyze(expr, context)).toList());
300+
270301
context.relBuilder.aggregate(context.relBuilder.groupKey(groupByList), aggList);
302+
303+
// handle normal aggregate
304+
// TODO Should we keep alignment with V2 behaviour in new Calcite implementation?
305+
// TODO how about add a legacy enable config to control behaviour in Calcite?
306+
// Some behaviours between PPL and Databases are different.
307+
// As an example, in command `stats count() by colA, colB`:
308+
// 1. the sequence of output schema is different:
309+
// In PPL v2, the sequence of output schema is "count, colA, colB".
310+
// But in most databases, the sequence of output schema is "colA, colB, count".
311+
// 2. the output order is different:
312+
// In PPL v2, the order of output results is ordered by "colA + colB".
313+
// But in most databases, the output order is random.
314+
// User must add ORDER BY clause after GROUP BY clause to keep the results aligning.
315+
// Following logic is to align with the PPL legacy behaviour.
316+
317+
// alignment for 2. the output order: adding order-by
318+
// we use the groupByList instead of node.getSortExprList as input because
319+
// the groupByList may include span column.
320+
node.getGroupExprList()
321+
.forEach(
322+
g -> {
323+
// node.getGroupExprList() should all be instance of Alias
324+
// which defined in AstBuilder.
325+
assert g instanceof Alias;
326+
});
327+
List<String> aliasesFromGroupByList =
328+
groupByList.stream()
329+
.map(this::extractAliasLiteral)
330+
.flatMap(Optional::stream)
331+
.map(ref -> ((RexLiteral) ref).getValueAs(String.class))
332+
.toList();
333+
List<RexNode> aliasedGroupByList =
334+
aliasesFromGroupByList.stream()
335+
.map(context.relBuilder::field)
336+
.map(f -> (RexNode) f)
337+
.toList();
338+
context.relBuilder.sort(aliasedGroupByList);
339+
340+
// alignment for 1. the sequence of output schema: schema reordering
341+
List<RexNode> outputFields = context.relBuilder.fields();
342+
int numOfOutputFields = outputFields.size();
343+
int numOfAggList = aggList.size();
344+
List<RexNode> reordered = new ArrayList<>(numOfOutputFields);
345+
// Add aggregation results first
346+
List<RexNode> aggRexList =
347+
outputFields.subList(numOfOutputFields - numOfAggList, numOfOutputFields);
348+
reordered.addAll(aggRexList);
349+
// Add group by columns
350+
reordered.addAll(aliasedGroupByList);
351+
context.relBuilder.project(reordered);
352+
271353
return context.relBuilder.peek();
272354
}
273355

356+
/** extract the RexLiteral of Alias from a node */
357+
private Optional<RexLiteral> extractAliasLiteral(RexNode node) {
358+
if (node == null) {
359+
return Optional.empty();
360+
} else if (node.getKind() == SqlKind.AS) {
361+
return Optional.of((RexLiteral) ((RexCall) node).getOperands().get(1));
362+
} else {
363+
return Optional.empty();
364+
}
365+
}
366+
274367
@Override
275368
public RelNode visitJoin(Join node, CalcitePlanContext context) {
276369
List<UnresolvedPlan> children = node.getChildren();

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,16 +235,17 @@ public RexNode visitSpan(Span node, CalcitePlanContext context) {
235235
return context.rexBuilder.makeIntervalLiteral(new BigDecimal(millis), intervalQualifier);
236236
} else {
237237
// if the unit is not time base - create a math expression to bucket the span partitions
238+
SqlTypeName type = field.getType().getSqlTypeName();
238239
return context.rexBuilder.makeCall(
239-
typeFactory.createSqlType(SqlTypeName.DOUBLE),
240+
typeFactory.createSqlType(type),
240241
SqlStdOperatorTable.MULTIPLY,
241242
List.of(
242243
context.rexBuilder.makeCall(
243-
typeFactory.createSqlType(SqlTypeName.DOUBLE),
244+
typeFactory.createSqlType(type),
244245
SqlStdOperatorTable.FLOOR,
245246
List.of(
246247
context.rexBuilder.makeCall(
247-
typeFactory.createSqlType(SqlTypeName.DOUBLE),
248+
typeFactory.createSqlType(type),
248249
SqlStdOperatorTable.DIVIDE,
249250
List.of(field, value)))),
250251
value));
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
/*
7+
* This file contains code from the Apache Calcite project (original license below).
8+
* It contains modifications, which are licensed as above:
9+
*/
10+
11+
/*
12+
* Licensed to the Apache Software Foundation (ASF) under one or more
13+
* contributor license agreements. See the NOTICE file distributed with
14+
* this work for additional information regarding copyright ownership.
15+
* The ASF licenses this file to you under the Apache License, Version 2.0
16+
* (the "License"); you may not use this file except in compliance with
17+
* the License. You may obtain a copy of the License at
18+
*
19+
* http://www.apache.org/licenses/LICENSE-2.0
20+
*
21+
* Unless required by applicable law or agreed to in writing, software
22+
* distributed under the License is distributed on an "AS IS" BASIS,
23+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24+
* See the License for the specific language governing permissions and
25+
* limitations under the License.
26+
*/
27+
28+
package org.opensearch.sql.calcite.udf.udaf;
29+
30+
import static com.google.common.base.Preconditions.checkArgument;
31+
32+
import org.apache.calcite.sql.SqlAggFunction;
33+
import org.apache.calcite.sql.SqlFunctionCategory;
34+
import org.apache.calcite.sql.SqlKind;
35+
import org.apache.calcite.sql.fun.SqlAvgAggFunction;
36+
import org.apache.calcite.sql.type.OperandTypes;
37+
import org.apache.calcite.sql.type.ReturnTypes;
38+
import org.apache.calcite.sql.type.SqlTypeTransforms;
39+
import org.apache.calcite.util.Optionality;
40+
41+
public class NullableSqlAvgAggFunction extends SqlAggFunction {
42+
43+
// ~ Constructors -----------------------------------------------------------
44+
45+
/** Creates a NullableSqlAvgAggFunction. */
46+
public NullableSqlAvgAggFunction(SqlKind kind) {
47+
this(kind.name(), kind);
48+
}
49+
50+
NullableSqlAvgAggFunction(String name, SqlKind kind) {
51+
super(
52+
name,
53+
null,
54+
kind,
55+
ReturnTypes.AVG_AGG_FUNCTION.andThen(SqlTypeTransforms.FORCE_NULLABLE), // modified here
56+
null,
57+
OperandTypes.NUMERIC,
58+
SqlFunctionCategory.NUMERIC,
59+
false,
60+
false,
61+
Optionality.FORBIDDEN);
62+
checkArgument(SqlKind.AVG_AGG_FUNCTIONS.contains(kind), "unsupported sql kind");
63+
}
64+
65+
// ~ Methods ----------------------------------------------------------------
66+
67+
/**
68+
* Returns the specific function, e.g. AVG or STDDEV_POP.
69+
*
70+
* @return Subtype
71+
*/
72+
@Deprecated // to be removed before 2.0
73+
public SqlAvgAggFunction.Subtype getSubtype() {
74+
return SqlAvgAggFunction.Subtype.valueOf(kind.name());
75+
}
76+
77+
/** Sub-type of aggregate function. */
78+
@Deprecated // to be removed before 2.0
79+
public enum Subtype {
80+
AVG,
81+
STDDEV_POP,
82+
STDDEV_SAMP,
83+
VAR_POP,
84+
VAR_SAMP
85+
}
86+
}
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
/*
7+
* This file contains code from the Apache Calcite project (original license below).
8+
* It contains modifications, which are licensed as above:
9+
*/
10+
11+
/*
12+
* Licensed to the Apache Software Foundation (ASF) under one or more
13+
* contributor license agreements. See the NOTICE file distributed with
14+
* this work for additional information regarding copyright ownership.
15+
* The ASF licenses this file to you under the Apache License, Version 2.0
16+
* (the "License"); you may not use this file except in compliance with
17+
* the License. You may obtain a copy of the License at
18+
*
19+
* http://www.apache.org/licenses/LICENSE-2.0
20+
*
21+
* Unless required by applicable law or agreed to in writing, software
22+
* distributed under the License is distributed on an "AS IS" BASIS,
23+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24+
* See the License for the specific language governing permissions and
25+
* limitations under the License.
26+
*/
27+
28+
package org.opensearch.sql.calcite.udf.udaf;
29+
30+
import com.google.common.collect.ImmutableList;
31+
import java.util.List;
32+
import org.apache.calcite.rel.type.RelDataType;
33+
import org.apache.calcite.rel.type.RelDataTypeFactory;
34+
import org.apache.calcite.sql.SqlAggFunction;
35+
import org.apache.calcite.sql.SqlFunctionCategory;
36+
import org.apache.calcite.sql.SqlKind;
37+
import org.apache.calcite.sql.SqlSplittableAggFunction;
38+
import org.apache.calcite.sql.type.OperandTypes;
39+
import org.apache.calcite.sql.type.ReturnTypes;
40+
import org.apache.calcite.sql.type.SqlTypeTransforms;
41+
import org.apache.calcite.util.Optionality;
42+
import org.checkerframework.checker.nullness.qual.Nullable;
43+
44+
public class NullableSqlSumAggFunction extends SqlAggFunction {
45+
46+
// ~ Instance fields --------------------------------------------------------
47+
48+
@Deprecated // to be removed before 2.0
49+
private final RelDataType type;
50+
51+
// ~ Constructors -----------------------------------------------------------
52+
53+
public NullableSqlSumAggFunction(RelDataType type) {
54+
super(
55+
"SUM",
56+
null,
57+
SqlKind.SUM,
58+
ReturnTypes.AGG_SUM.andThen(SqlTypeTransforms.FORCE_NULLABLE), // modified here
59+
null,
60+
OperandTypes.NUMERIC,
61+
SqlFunctionCategory.NUMERIC,
62+
false,
63+
false,
64+
Optionality.FORBIDDEN);
65+
this.type = type;
66+
}
67+
68+
// ~ Methods ----------------------------------------------------------------
69+
70+
@SuppressWarnings("deprecation")
71+
@Override
72+
public List<RelDataType> getParameterTypes(RelDataTypeFactory typeFactory) {
73+
return ImmutableList.of(type);
74+
}
75+
76+
@Deprecated // to be removed before 2.0
77+
public RelDataType getType() {
78+
return type;
79+
}
80+
81+
@SuppressWarnings("deprecation")
82+
@Override
83+
public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
84+
return type;
85+
}
86+
87+
@Override
88+
public <T extends Object> @Nullable T unwrap(Class<T> clazz) {
89+
if (clazz.isInstance(SqlSplittableAggFunction.SumSplitter.INSTANCE)) {
90+
return clazz.cast(SqlSplittableAggFunction.SumSplitter.INSTANCE);
91+
}
92+
return super.unwrap(clazz);
93+
}
94+
95+
@Override
96+
public SqlAggFunction getRollup() {
97+
return this;
98+
}
99+
}

0 commit comments

Comments
 (0)