1818import static org .opensearch .sql .calcite .utils .PlanUtils .ROW_NUMBER_COLUMN_NAME_MAIN ;
1919import static org .opensearch .sql .calcite .utils .PlanUtils .ROW_NUMBER_COLUMN_NAME_SUBSEARCH ;
2020import static org .opensearch .sql .calcite .utils .PlanUtils .getRelation ;
21+ import static org .opensearch .sql .calcite .utils .PlanUtils .getRexCall ;
2122import static org .opensearch .sql .calcite .utils .PlanUtils .transformPlanToAttachChild ;
2223
2324import com .google .common .base .Strings ;
5354import org .apache .calcite .rex .RexNode ;
5455import org .apache .calcite .rex .RexVisitorImpl ;
5556import org .apache .calcite .rex .RexWindowBounds ;
57+ import org .apache .calcite .sql .SqlKind ;
5658import org .apache .calcite .sql .fun .SqlStdOperatorTable ;
5759import org .apache .calcite .sql .type .SqlTypeFamily ;
5860import org .apache .calcite .sql .type .SqlTypeName ;
@@ -691,7 +693,19 @@ public RelNode visitPatterns(Patterns node, CalcitePlanContext context) {
691693 context .relBuilder .field (node .getAlias ()),
692694 context .relBuilder .field (PatternUtils .SAMPLE_LOGS ));
693695 flattenParsedPattern (node .getAlias (), parsedNode , context , false );
694- context .relBuilder .projectExcept (context .relBuilder .field (PatternUtils .SAMPLE_LOGS ));
696+ // Reorder fields for consistency with Brain's output
697+ projectPlusOverriding (
698+ List .of (
699+ context .relBuilder .field (node .getAlias ()),
700+ context .relBuilder .field (PatternUtils .PATTERN_COUNT ),
701+ context .relBuilder .field (PatternUtils .TOKENS ),
702+ context .relBuilder .field (PatternUtils .SAMPLE_LOGS )),
703+ List .of (
704+ node .getAlias (),
705+ PatternUtils .PATTERN_COUNT ,
706+ PatternUtils .TOKENS ,
707+ PatternUtils .SAMPLE_LOGS ),
708+ context );
695709 } else {
696710 RexNode parsedNode =
697711 PPLFuncImpTable .INSTANCE .resolve (
@@ -813,6 +827,23 @@ private void projectPlusOverriding(
813827 context .relBuilder .rename (expectedRenameFields );
814828 }
815829
830+ private List <List <RexInputRef >> extractInputRefList (List <RelBuilder .AggCall > aggCalls ) {
831+ return aggCalls .stream ()
832+ .map (RelBuilder .AggCall ::over )
833+ .map (RelBuilder .OverCall ::toRex )
834+ .map (node -> getRexCall (node , this ::isCountField ))
835+ .map (list -> list .isEmpty () ? null : list .getFirst ())
836+ .map (PlanUtils ::getInputRefs )
837+ .toList ();
838+ }
839+
840+ /** Is count(FIELD) */
841+ private boolean isCountField (RexCall call ) {
842+ return call .isA (SqlKind .COUNT )
843+ && call .getOperands ().size () == 1 // count(FIELD)
844+ && call .getOperands ().get (0 ) instanceof RexInputRef ;
845+ }
846+
816847 /**
817848 * Resolve the aggregation with trimming unused fields to avoid bugs in {@link
818849 * org.apache.calcite.sql2rel.RelDecorrelator#decorrelateRel(Aggregate, boolean)}
@@ -826,6 +857,72 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
826857 List <UnresolvedExpression > groupExprList ,
827858 List <UnresolvedExpression > aggExprList ,
828859 CalcitePlanContext context ) {
860+ Pair <List <RexNode >, List <AggCall >> resolved =
861+ resolveAttributesForAggregation (groupExprList , aggExprList , context );
862+ List <RexNode > resolvedGroupByList = resolved .getLeft ();
863+ List <AggCall > resolvedAggCallList = resolved .getRight ();
864+
865+ // The `doc_count` optimization requires an `isNotNull(RexInputRef)` filter under `count(FIELD)`.
866+ // It is applied only when there is no grouping and all aggregates count the same single FIELD:
867+ //
868+ // Example 1: source=t | stats count(a)
869+ // Before: Aggregate(count(a))
870+ // \- Scan t
871+ // After: Aggregate(count(a))
872+ // \- Filter(isNotNull(a))
873+ // \- Scan t
874+ //
875+ // Example 2: source=t | stats count(a), count(a)
876+ // Before: Aggregate(count(a), count(a))
877+ // \- Scan t
878+ // After: Aggregate(count(a), count(a))
879+ // \- Filter(isNotNull(a))
880+ // \- Scan t
881+ //
882+ // Example 3: source=t | stats count(a) by b
883+ // Before & After: Aggregate(count(a) by b)
884+ // \- Scan t
885+ //
886+ // Example 4: source=t | stats count()
887+ // Before & After: Aggregate(count())
888+ // \- Scan t
889+ //
890+ // Example 5: source=t | stats count(), count(a)
891+ // Before & After: Aggregate(count(), count(a))
892+ // \- Scan t
893+ //
894+ // Example 6: source=t | stats count(a), count(b)
895+ // Before & After: Aggregate(count(a), count(b))
896+ // \- Scan t
897+ //
898+ // Example 7: source=t | stats count(a+1)
899+ // Before & After: Aggregate(count(a+1))
900+ // \- Scan t
901+ if (resolvedGroupByList .isEmpty ()) {
902+ List <List <RexInputRef >> refsPerCount = extractInputRefList (resolvedAggCallList );
903+ List <RexInputRef > distinctRefsOfCounts ;
904+ if (context .relBuilder .peek () instanceof org .apache .calcite .rel .core .Project project ) {
905+ List <RexNode > mappedInProject =
906+ refsPerCount .stream ()
907+ .flatMap (List ::stream )
908+ .map (ref -> project .getProjects ().get (ref .getIndex ()))
909+ .toList ();
910+ if (mappedInProject .stream ().allMatch (RexInputRef .class ::isInstance )) {
911+ distinctRefsOfCounts =
912+ mappedInProject .stream ().map (RexInputRef .class ::cast ).distinct ().toList ();
913+ } else {
914+ distinctRefsOfCounts = List .of ();
915+ }
916+ } else {
917+ distinctRefsOfCounts = refsPerCount .stream ().flatMap (List ::stream ).distinct ().toList ();
918+ }
919+ if (distinctRefsOfCounts .size () == 1 && refsPerCount .stream ().noneMatch (List ::isEmpty )) {
920+ context .relBuilder .filter (context .relBuilder .isNotNull (distinctRefsOfCounts .getFirst ()));
921+ }
922+ }
923+
924+ // Add project before aggregate:
925+ //
829926 // Example 1: source=t | where a > 1 | stats avg(b + 1) by c
830927 // Before: Aggregate(avg(b + 1))
831928 // \- Filter(a > 1)
@@ -836,23 +933,22 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
836933 // \- Scan t
837934 //
838935 // Example 2: source=t | where a > 1 | top b by c
839- // Before: Aggregate(count)
840- // \-Filter(a > 1)
936+ // Before: Aggregate(count(b) by c )
937+ // \-Filter(a > 1 && isNotNull(b) )
841938 // \- Scan t
842- // After: Aggregate(count)
939+ // After: Aggregate(count(b) by c )
843940 // \- Project([c, b])
844- // \- Filter(a > 1)
941+ // \- Filter(a > 1 && isNotNull(b) )
845942 // \- Scan t
846- // Example 3: source=t | stats count(): no project added for count()
847- // Before: Aggregate(count)
943+ //
944+ // Example 3: source=t | stats count(): no change for count()
945+ // Before: Aggregate(count())
848946 // \- Scan t
849- // After: Aggregate(count)
947+ // After: Aggregate(count() )
850948 // \- Scan t
851- Pair <List <RexNode >, List <AggCall >> resolved =
852- resolveAttributesForAggregation (groupExprList , aggExprList , context );
853949 List <RexInputRef > trimmedRefs = new ArrayList <>();
854- trimmedRefs .addAll (PlanUtils .getInputRefs (resolved . getLeft () )); // group-by keys first
855- trimmedRefs .addAll (PlanUtils .getInputRefsFromAggCall (resolved . getRight () ));
950+ trimmedRefs .addAll (PlanUtils .getInputRefs (resolvedGroupByList )); // group-by keys first
951+ trimmedRefs .addAll (PlanUtils .getInputRefsFromAggCall (resolvedAggCallList ));
856952 context .relBuilder .project (trimmedRefs );
857953
858954 // Re-resolve all attributes based on adding trimmed Project.
@@ -2258,7 +2354,7 @@ private void flattenParsedPattern(
22582354 String originalPatternResultAlias ,
22592355 RexNode parsedNode ,
22602356 CalcitePlanContext context ,
2261- boolean flattenPatternCount ) {
2357+ boolean flattenPatternAggResult ) {
22622358 List <RexNode > fattenedNodes = new ArrayList <>();
22632359 List <String > projectNames = new ArrayList <>();
22642360 // Flatten map struct fields
@@ -2274,7 +2370,7 @@ private void flattenParsedPattern(
22742370 true );
22752371 fattenedNodes .add (context .relBuilder .alias (patternExpr , originalPatternResultAlias ));
22762372 projectNames .add (originalPatternResultAlias );
2277- if (flattenPatternCount ) {
2373+ if (flattenPatternAggResult ) {
22782374 RexNode patternCountExpr =
22792375 context .rexBuilder .makeCast (
22802376 context .rexBuilder .getTypeFactory ().createSqlType (SqlTypeName .BIGINT ),
@@ -2300,6 +2396,24 @@ private void flattenParsedPattern(
23002396 true );
23012397 fattenedNodes .add (context .relBuilder .alias (tokensExpr , PatternUtils .TOKENS ));
23022398 projectNames .add (PatternUtils .TOKENS );
2399+ if (flattenPatternAggResult ) {
2400+ RexNode sampleLogsExpr =
2401+ context .rexBuilder .makeCast (
2402+ context
2403+ .rexBuilder
2404+ .getTypeFactory ()
2405+ .createArrayType (
2406+ context .rexBuilder .getTypeFactory ().createSqlType (SqlTypeName .VARCHAR ), -1 ),
2407+ PPLFuncImpTable .INSTANCE .resolve (
2408+ context .rexBuilder ,
2409+ BuiltinFunctionName .INTERNAL_ITEM ,
2410+ parsedNode ,
2411+ context .rexBuilder .makeLiteral (PatternUtils .SAMPLE_LOGS )),
2412+ true ,
2413+ true );
2414+ fattenedNodes .add (context .relBuilder .alias (sampleLogsExpr , PatternUtils .SAMPLE_LOGS ));
2415+ projectNames .add (PatternUtils .SAMPLE_LOGS );
2416+ }
23032417 projectPlusOverriding (fattenedNodes , projectNames , context );
23042418 }
23052419
0 commit comments