Skip to content

Commit 8ea3372

Browse files
committed
Merge remote-tracking branch 'upstream/main' into feature/pre-fetch-batches-in-enumeration
2 parents 320bf8e + 5be225e commit 8ea3372

File tree

96 files changed

+3133
-318
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+3133
-318
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 128 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME_MAIN;
1919
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME_SUBSEARCH;
2020
import static org.opensearch.sql.calcite.utils.PlanUtils.getRelation;
21+
import static org.opensearch.sql.calcite.utils.PlanUtils.getRexCall;
2122
import static org.opensearch.sql.calcite.utils.PlanUtils.transformPlanToAttachChild;
2223

2324
import com.google.common.base.Strings;
@@ -53,6 +54,7 @@
5354
import org.apache.calcite.rex.RexNode;
5455
import org.apache.calcite.rex.RexVisitorImpl;
5556
import org.apache.calcite.rex.RexWindowBounds;
57+
import org.apache.calcite.sql.SqlKind;
5658
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
5759
import org.apache.calcite.sql.type.SqlTypeFamily;
5860
import org.apache.calcite.sql.type.SqlTypeName;
@@ -691,7 +693,19 @@ public RelNode visitPatterns(Patterns node, CalcitePlanContext context) {
691693
context.relBuilder.field(node.getAlias()),
692694
context.relBuilder.field(PatternUtils.SAMPLE_LOGS));
693695
flattenParsedPattern(node.getAlias(), parsedNode, context, false);
694-
context.relBuilder.projectExcept(context.relBuilder.field(PatternUtils.SAMPLE_LOGS));
696+
// Reorder fields for consistency with Brain's output
697+
projectPlusOverriding(
698+
List.of(
699+
context.relBuilder.field(node.getAlias()),
700+
context.relBuilder.field(PatternUtils.PATTERN_COUNT),
701+
context.relBuilder.field(PatternUtils.TOKENS),
702+
context.relBuilder.field(PatternUtils.SAMPLE_LOGS)),
703+
List.of(
704+
node.getAlias(),
705+
PatternUtils.PATTERN_COUNT,
706+
PatternUtils.TOKENS,
707+
PatternUtils.SAMPLE_LOGS),
708+
context);
695709
} else {
696710
RexNode parsedNode =
697711
PPLFuncImpTable.INSTANCE.resolve(
@@ -813,6 +827,23 @@ private void projectPlusOverriding(
813827
context.relBuilder.rename(expectedRenameFields);
814828
}
815829

830+
private List<List<RexInputRef>> extractInputRefList(List<RelBuilder.AggCall> aggCalls) {
831+
return aggCalls.stream()
832+
.map(RelBuilder.AggCall::over)
833+
.map(RelBuilder.OverCall::toRex)
834+
.map(node -> getRexCall(node, this::isCountField))
835+
.map(list -> list.isEmpty() ? null : list.getFirst())
836+
.map(PlanUtils::getInputRefs)
837+
.toList();
838+
}
839+
840+
/** Is count(FIELD) */
841+
private boolean isCountField(RexCall call) {
842+
return call.isA(SqlKind.COUNT)
843+
&& call.getOperands().size() == 1 // count(FIELD)
844+
&& call.getOperands().get(0) instanceof RexInputRef;
845+
}
846+
816847
/**
817848
* Resolve the aggregation with trimming unused fields to avoid bugs in {@link
818849
* org.apache.calcite.sql2rel.RelDecorrelator#decorrelateRel(Aggregate, boolean)}
@@ -826,6 +857,72 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
826857
List<UnresolvedExpression> groupExprList,
827858
List<UnresolvedExpression> aggExprList,
828859
CalcitePlanContext context) {
860+
Pair<List<RexNode>, List<AggCall>> resolved =
861+
resolveAttributesForAggregation(groupExprList, aggExprList, context);
862+
List<RexNode> resolvedGroupByList = resolved.getLeft();
863+
List<AggCall> resolvedAggCallList = resolved.getRight();
864+
865+
// `doc_count` optimization required a filter `isNotNull(RexInputRef)` for the
866+
// `count(FIELD)` aggregation which only can be applied to single FIELD without grouping:
867+
//
868+
// Example 1: source=t | stats count(a)
869+
// Before: Aggregate(count(a))
870+
// \- Scan t
871+
// After: Aggregate(count(a))
872+
// \- Filter(isNotNull(a))
873+
// \- Scan t
874+
//
875+
// Example 2: source=t | stats count(a), count(a)
876+
// Before: Aggregate(count(a), count(a))
877+
// \- Scan t
878+
// After: Aggregate(count(a), count(a))
879+
// \- Filter(isNotNull(a))
880+
// \- Scan t
881+
//
882+
// Example 3: source=t | stats count(a) by b
883+
// Before & After: Aggregate(count(a) by b)
884+
// \- Scan t
885+
//
886+
// Example 4: source=t | stats count()
887+
// Before & After: Aggregate(count())
888+
// \- Scan t
889+
//
890+
// Example 5: source=t | stats count(), count(a)
891+
// Before & After: Aggregate(count(), count(a))
892+
// \- Scan t
893+
//
894+
// Example 6: source=t | stats count(a), count(b)
895+
// Before & After: Aggregate(count(a), count(b))
896+
// \- Scan t
897+
//
898+
// Example 7: source=t | stats count(a+1)
899+
// Before & After: Aggregate(count(a+1))
900+
// \- Scan t
901+
if (resolvedGroupByList.isEmpty()) {
902+
List<List<RexInputRef>> refsPerCount = extractInputRefList(resolvedAggCallList);
903+
List<RexInputRef> distinctRefsOfCounts;
904+
if (context.relBuilder.peek() instanceof org.apache.calcite.rel.core.Project project) {
905+
List<RexNode> mappedInProject =
906+
refsPerCount.stream()
907+
.flatMap(List::stream)
908+
.map(ref -> project.getProjects().get(ref.getIndex()))
909+
.toList();
910+
if (mappedInProject.stream().allMatch(RexInputRef.class::isInstance)) {
911+
distinctRefsOfCounts =
912+
mappedInProject.stream().map(RexInputRef.class::cast).distinct().toList();
913+
} else {
914+
distinctRefsOfCounts = List.of();
915+
}
916+
} else {
917+
distinctRefsOfCounts = refsPerCount.stream().flatMap(List::stream).distinct().toList();
918+
}
919+
if (distinctRefsOfCounts.size() == 1 && refsPerCount.stream().noneMatch(List::isEmpty)) {
920+
context.relBuilder.filter(context.relBuilder.isNotNull(distinctRefsOfCounts.getFirst()));
921+
}
922+
}
923+
924+
// Add project before aggregate:
925+
//
829926
// Example 1: source=t | where a > 1 | stats avg(b + 1) by c
830927
// Before: Aggregate(avg(b + 1))
831928
// \- Filter(a > 1)
@@ -836,23 +933,22 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
836933
// \- Scan t
837934
//
838935
// Example 2: source=t | where a > 1 | top b by c
839-
// Before: Aggregate(count)
840-
// \-Filter(a > 1)
936+
// Before: Aggregate(count(b) by c)
937+
// \-Filter(a > 1 && isNotNull(b))
841938
// \- Scan t
842-
// After: Aggregate(count)
939+
// After: Aggregate(count(b) by c)
843940
// \- Project([c, b])
844-
// \- Filter(a > 1)
941+
// \- Filter(a > 1 && isNotNull(b))
845942
// \- Scan t
846-
// Example 3: source=t | stats count(): no project added for count()
847-
// Before: Aggregate(count)
943+
//
944+
// Example 3: source=t | stats count(): no change for count()
945+
// Before: Aggregate(count())
848946
// \- Scan t
849-
// After: Aggregate(count)
947+
// After: Aggregate(count())
850948
// \- Scan t
851-
Pair<List<RexNode>, List<AggCall>> resolved =
852-
resolveAttributesForAggregation(groupExprList, aggExprList, context);
853949
List<RexInputRef> trimmedRefs = new ArrayList<>();
854-
trimmedRefs.addAll(PlanUtils.getInputRefs(resolved.getLeft())); // group-by keys first
855-
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolved.getRight()));
950+
trimmedRefs.addAll(PlanUtils.getInputRefs(resolvedGroupByList)); // group-by keys first
951+
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolvedAggCallList));
856952
context.relBuilder.project(trimmedRefs);
857953

858954
// Re-resolve all attributes based on adding trimmed Project.
@@ -2258,7 +2354,7 @@ private void flattenParsedPattern(
22582354
String originalPatternResultAlias,
22592355
RexNode parsedNode,
22602356
CalcitePlanContext context,
2261-
boolean flattenPatternCount) {
2357+
boolean flattenPatternAggResult) {
22622358
List<RexNode> fattenedNodes = new ArrayList<>();
22632359
List<String> projectNames = new ArrayList<>();
22642360
// Flatten map struct fields
@@ -2274,7 +2370,7 @@ private void flattenParsedPattern(
22742370
true);
22752371
fattenedNodes.add(context.relBuilder.alias(patternExpr, originalPatternResultAlias));
22762372
projectNames.add(originalPatternResultAlias);
2277-
if (flattenPatternCount) {
2373+
if (flattenPatternAggResult) {
22782374
RexNode patternCountExpr =
22792375
context.rexBuilder.makeCast(
22802376
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT),
@@ -2300,6 +2396,24 @@ private void flattenParsedPattern(
23002396
true);
23012397
fattenedNodes.add(context.relBuilder.alias(tokensExpr, PatternUtils.TOKENS));
23022398
projectNames.add(PatternUtils.TOKENS);
2399+
if (flattenPatternAggResult) {
2400+
RexNode sampleLogsExpr =
2401+
context.rexBuilder.makeCast(
2402+
context
2403+
.rexBuilder
2404+
.getTypeFactory()
2405+
.createArrayType(
2406+
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), -1),
2407+
PPLFuncImpTable.INSTANCE.resolve(
2408+
context.rexBuilder,
2409+
BuiltinFunctionName.INTERNAL_ITEM,
2410+
parsedNode,
2411+
context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS)),
2412+
true,
2413+
true);
2414+
fattenedNodes.add(context.relBuilder.alias(sampleLogsExpr, PatternUtils.SAMPLE_LOGS));
2415+
projectNames.add(PatternUtils.SAMPLE_LOGS);
2416+
}
23032417
projectPlusOverriding(fattenedNodes, projectNames, context);
23042418
}
23052419

core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ public Object value(Object... argList) {
184184
PatternUtils.PATTERN,
185185
parseResult.toTokenOrderString(PatternUtils.WILDCARD_PREFIX),
186186
PatternUtils.PATTERN_COUNT, count,
187-
PatternUtils.TOKENS, tokensMap);
187+
PatternUtils.TOKENS, tokensMap,
188+
PatternUtils.SAMPLE_LOGS, sampleLogs);
188189
})
189190
.collect(Collectors.toList());
190191
}

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import java.util.ArrayList;
1616
import java.util.List;
1717
import java.util.Objects;
18+
import java.util.function.Predicate;
1819
import java.util.stream.Collectors;
1920
import javax.annotation.Nullable;
2021
import org.apache.calcite.plan.RelOptTable;
@@ -255,6 +256,9 @@ static RelBuilder.AggCall makeAggCall(
255256

256257
/** Get all uniq input references from a RexNode. */
257258
static List<RexInputRef> getInputRefs(RexNode node) {
259+
if (node == null) {
260+
return List.of();
261+
}
258262
List<RexInputRef> inputRefs = new ArrayList<>();
259263
node.accept(
260264
new RexVisitorImpl<Void>(true) {
@@ -274,6 +278,26 @@ static List<RexInputRef> getInputRefs(List<RexNode> nodes) {
274278
return nodes.stream().flatMap(node -> getInputRefs(node).stream()).toList();
275279
}
276280

281+
/** Get all uniq RexCall from RexNode with a predicate */
282+
static List<RexCall> getRexCall(RexNode node, Predicate<RexCall> predicate) {
283+
List<RexCall> list = new ArrayList<>();
284+
node.accept(
285+
new RexVisitorImpl<Void>(true) {
286+
@Override
287+
public Void visitCall(RexCall inputCall) {
288+
if (predicate.test(inputCall)) {
289+
if (!list.contains(inputCall)) {
290+
list.add(inputCall);
291+
}
292+
} else {
293+
inputCall.getOperands().forEach(call -> call.accept(this));
294+
}
295+
return null;
296+
}
297+
});
298+
return list;
299+
}
300+
277301
/** Get all uniq input references from a list of agg calls. */
278302
static List<RexInputRef> getInputRefsFromAggCall(List<RelBuilder.AggCall> aggCalls) {
279303
return aggCalls.stream()

docs/category.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
"user/ppl/cmd/subquery.rst",
3232
"user/ppl/general/identifiers.rst",
3333
"user/ppl/general/datatypes.rst",
34-
"user/ppl/functions/condition.rst",
3534
"user/ppl/functions/datetime.rst",
3635
"user/ppl/functions/expressions.rst",
3736
"user/ppl/functions/ip.rst",
@@ -56,6 +55,7 @@
5655
],
5756
"ppl_cli_calcite": [
5857
"user/ppl/cmd/append.rst",
58+
"user/ppl/functions/condition.rst",
5959
"user/ppl/cmd/eventstats.rst",
6060
"user/ppl/cmd/fields.rst",
6161
"user/ppl/cmd/regex.rst",

0 commit comments

Comments
 (0)