Skip to content

Commit 84dccbf

Browse files
committed
Support nested aggregation
Signed-off-by: Lantao Jin <ltjin@amazon.com>
1 parent f241f34 commit 84dccbf

File tree

14 files changed

+492
-19
lines changed

14 files changed

+492
-19
lines changed

core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.List;
1313
import java.util.stream.Collectors;
1414
import org.opensearch.sql.planner.logical.LogicalPlan;
15+
import org.opensearch.sql.planner.optimizer.rule.EliminateNested;
1516
import org.opensearch.sql.planner.optimizer.rule.MergeFilterAndFilter;
1617
import org.opensearch.sql.planner.optimizer.rule.PushFilterUnderSort;
1718
import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder;
@@ -58,7 +59,11 @@ public static LogicalPlanOptimizer create() {
5859
TableScanPushDown.PUSH_DOWN_HIGHLIGHT,
5960
TableScanPushDown.PUSH_DOWN_NESTED,
6061
TableScanPushDown.PUSH_DOWN_PROJECT,
61-
new CreateTableWriteBuilder()));
62+
new CreateTableWriteBuilder(),
63+
/*
64+
* Phase 3: Transformations for others
65+
*/
66+
new EliminateNested()));
6267
}
6368

6469
/** Optimize {@link LogicalPlan}. */
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.planner.optimizer.rule;
7+
8+
import static com.facebook.presto.matching.Pattern.typeOf;
9+
import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source;
10+
11+
import com.facebook.presto.matching.Capture;
12+
import com.facebook.presto.matching.Captures;
13+
import com.facebook.presto.matching.Pattern;
14+
import lombok.Getter;
15+
import lombok.experimental.Accessors;
16+
import org.opensearch.sql.planner.logical.LogicalAggregation;
17+
import org.opensearch.sql.planner.logical.LogicalNested;
18+
import org.opensearch.sql.planner.logical.LogicalPlan;
19+
import org.opensearch.sql.planner.optimizer.Rule;
20+
21+
/**
22+
* Eliminate LogicalNested if its child is LogicalAggregation.<br>
23+
* LogicalNested - LogicalAggregation - Child --> LogicalAggregation - Child<br>
24+
* E.g. count(nested(foo.bar, foo))
25+
*/
26+
public class EliminateNested implements Rule<LogicalNested> {
27+
28+
private final Capture<LogicalAggregation> capture;
29+
30+
@Accessors(fluent = true)
31+
@Getter
32+
private final Pattern<LogicalNested> pattern;
33+
34+
public EliminateNested() {
35+
this.capture = Capture.newCapture();
36+
this.pattern =
37+
typeOf(LogicalNested.class)
38+
.with(source().matching(typeOf(LogicalAggregation.class).capturedAs(capture)));
39+
}
40+
41+
@Override
42+
public LogicalPlan apply(LogicalNested plan, Captures captures) {
43+
return captures.get(capture);
44+
}
45+
}

core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.sql.planner.optimizer;
77

8+
import static java.util.Collections.emptyList;
89
import static org.junit.jupiter.api.Assertions.assertEquals;
910
import static org.junit.jupiter.api.Assertions.assertThrows;
1011
import static org.mockito.ArgumentMatchers.any;
@@ -122,6 +123,43 @@ void multiple_filter_should_eventually_be_merged() {
122123
DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))))));
123124
}
124125

126+
@Test
127+
void eliminate_nested_in_aggregation() {
128+
List<Map<String, ReferenceExpression>> nestedArgs =
129+
ImmutableList.of(
130+
Map.of(
131+
"field", new ReferenceExpression("message.info", STRING),
132+
"path", new ReferenceExpression("message", STRING)));
133+
List<NamedExpression> projectList =
134+
ImmutableList.of(
135+
DSL.named(
136+
"count(nested(message.info, message))",
137+
DSL.ref("count(nested(message.info, message))", INTEGER)));
138+
139+
assertEquals(
140+
aggregation(
141+
tableScanBuilder,
142+
ImmutableList.of(
143+
DSL.named(
144+
"count(nested(message.info, message))",
145+
DSL.count(
146+
DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", ARRAY))))),
147+
emptyList()),
148+
optimize(
149+
nested(
150+
aggregation(
151+
relation("schema", table),
152+
ImmutableList.of(
153+
DSL.named(
154+
"count(nested(message.info, message))",
155+
DSL.count(
156+
DSL.nested(
157+
DSL.ref("message.info", STRING), DSL.ref("message", ARRAY))))),
158+
emptyList()),
159+
nestedArgs,
160+
projectList)));
161+
}
162+
125163
@Test
126164
void default_table_scan_builder_should_not_push_down_anything() {
127165
LogicalPlan[] plans = {

docs/user/dql/aggregations.rst

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ The aggregation could has expression as arguments::
126126
| M | 202 |
127127
+----------+--------+
128128

129-
COUNT Aggregations
129+
COUNT Aggregation
130130
------------------
131131

132132
Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments such as ``*`` or literals like ``1``. The meaning of these different forms are as follows:
@@ -135,6 +135,30 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments
135135
2. ``COUNT(*)`` will count the number of all its input rows.
136136
3. ``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count.
137137

138+
NESTED Aggregation
139+
------------------
140+
The nested aggregation lets you aggregate on fields inside a nested object. You can use ``nested`` function to return a nested field, ref :ref:`nested function <nested_function_label>`.
141+
142+
To understand why we need nested aggregations, read `Nested Aggregations DSL doc <https://opensearch.org/docs/latest/aggregations/bucket/nested/>`_ to get more details.
143+
144+
The nested aggregation could be expression::
145+
146+
os> SELECT count(nested(message.info, message)) FROM nested;
147+
fetched rows / total rows = 1/1
148+
+----------------------------------------+
149+
| count(nested(message.info, message)) |
150+
|----------------------------------------|
151+
| 2 |
152+
+----------------------------------------+
153+
154+
os> SELECT count(nested(message.info)) FROM nested;
155+
fetched rows / total rows = 1/1
156+
+-------------------------------+
157+
| count(nested(message.info)) |
158+
|-------------------------------|
159+
| 2 |
160+
+-------------------------------+
161+
138162
Aggregation Functions
139163
=====================
140164

docs/user/dql/functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4419,6 +4419,7 @@ Another example to show how to set custom values for the optional parameters::
44194419
+-------------------------------------------+
44204420

44214421

4422+
.. _nested_function_label:
44224423
NESTED
44234424
------
44244425

integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import org.json.JSONArray;
2222
import org.json.JSONObject;
2323
import org.junit.Test;
24-
import org.junit.jupiter.api.Disabled;
2524
import org.opensearch.sql.legacy.SQLIntegTestCase;
2625

2726
public class NestedIT extends SQLIntegTestCase {
@@ -75,20 +74,18 @@ public void nested_function_in_select_test() {
7574
rows("zz", "bb", 6));
7675
}
7776

78-
// Has to be tested with JSON format when https://github.com/opensearch-project/sql/issues/1317
79-
// gets resolved
80-
@Disabled // TODO fix me when aggregation is supported
77+
@Test
8178
public void nested_function_in_an_aggregate_function_in_select_test() {
8279
String query =
83-
"SELECT sum(nested(message.dayOfWeek)) FROM " + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS;
80+
"SELECT sum(nested(message.dayOfWeek, message)) FROM "
81+
+ TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS;
8482
JSONObject result = executeJdbcRequest(query);
8583
verifyDataRows(result, rows(14));
8684
}
8785

88-
// TODO Enable me when nested aggregation is supported
89-
@Disabled
86+
@Test
9087
public void nested_function_with_arrays_in_an_aggregate_function_in_select_test() {
91-
String query = "SELECT sum(nested(message.dayOfWeek)) FROM " + TEST_INDEX_NESTED_TYPE;
88+
String query = "SELECT sum(nested(message.dayOfWeek, message)) FROM " + TEST_INDEX_NESTED_TYPE;
9289
JSONObject result = executeJdbcRequest(query);
9390
verifyDataRows(result, rows(19));
9491
}

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import lombok.RequiredArgsConstructor;
2222
import org.opensearch.search.aggregations.Aggregation;
2323
import org.opensearch.search.aggregations.Aggregations;
24+
import org.opensearch.search.aggregations.bucket.nested.InternalNested;
2425
import org.opensearch.sql.common.utils.StringUtils;
2526

2627
/** Parse multiple metrics in one bucket. */
@@ -44,6 +45,9 @@ public MetricParserHelper(List<MetricParser> metricParserList) {
4445
public Map<String, Object> parse(Aggregations aggregations) {
4546
Map<String, Object> resultMap = new HashMap<>();
4647
for (Aggregation aggregation : aggregations) {
48+
if (aggregation instanceof InternalNested) {
49+
aggregation = ((InternalNested) aggregation).getAggregations().asList().getFirst();
50+
}
4751
if (metricParserMap.containsKey(aggregation.getName())) {
4852
resultMap.putAll(metricParserMap.get(aggregation.getName()).parse(aggregation));
4953
} else {

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ public Integer getMaxResultWindow() {
155155
@Override
156156
public PhysicalPlan implement(LogicalPlan plan) {
157157
// TODO: Leave it here to avoid impact Prometheus and AD operators. Need to move to Planner.
158-
return plan.accept(new OpenSearchDefaultImplementor(client), null);
158+
PhysicalPlan pp = plan.accept(new OpenSearchDefaultImplementor(client), null);
159+
return pp;
159160
}
160161

161162
@Override

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
import static org.opensearch.script.Script.DEFAULT_SCRIPT_TYPE;
1010
import static org.opensearch.sql.opensearch.storage.script.ExpressionScriptEngine.EXPRESSION_LANG_NAME;
1111

12+
import java.util.List;
1213
import java.util.function.Function;
1314
import lombok.RequiredArgsConstructor;
1415
import org.opensearch.script.Script;
16+
import org.opensearch.search.aggregations.AggregationBuilder;
17+
import org.opensearch.search.aggregations.AggregationBuilders;
1518
import org.opensearch.sql.expression.Expression;
1619
import org.opensearch.sql.expression.FunctionExpression;
1720
import org.opensearch.sql.expression.LiteralExpression;
1821
import org.opensearch.sql.expression.ReferenceExpression;
22+
import org.opensearch.sql.expression.function.BuiltinFunctionName;
1923
import org.opensearch.sql.opensearch.data.type.OpenSearchTextType;
2024
import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer;
2125

@@ -25,18 +29,59 @@ public class AggregationBuilderHelper {
2529

2630
private final ExpressionSerializer serializer;
2731

32+
/** Build Composite Builder from Expression. */
33+
public <T> T buildComposite(
34+
Expression expression, Function<String, T> fieldBuilder, Function<Script, T> scriptBuilder) {
35+
if (expression instanceof ReferenceExpression) {
36+
String fieldName = ((ReferenceExpression) expression).getAttr();
37+
return fieldBuilder.apply(
38+
OpenSearchTextType.convertTextToKeyword(fieldName, expression.type()));
39+
} else if (expression instanceof FunctionExpression
40+
|| expression instanceof LiteralExpression) {
41+
return scriptBuilder.apply(
42+
new Script(
43+
DEFAULT_SCRIPT_TYPE,
44+
EXPRESSION_LANG_NAME,
45+
serializer.serialize(expression),
46+
emptyMap()));
47+
} else {
48+
throw new IllegalStateException(
49+
String.format("bucket aggregation doesn't support " + "expression %s", expression));
50+
}
51+
}
52+
2853
/**
2954
* Build AggregationBuilder from Expression.
3055
*
3156
* @param expression Expression
3257
* @return AggregationBuilder
3358
*/
34-
public <T> T build(
35-
Expression expression, Function<String, T> fieldBuilder, Function<Script, T> scriptBuilder) {
59+
public AggregationBuilder build(
60+
Expression expression,
61+
Function<String, AggregationBuilder> fieldBuilder,
62+
Function<Script, AggregationBuilder> scriptBuilder) {
3663
if (expression instanceof ReferenceExpression) {
3764
String fieldName = ((ReferenceExpression) expression).getAttr();
3865
return fieldBuilder.apply(
3966
OpenSearchTextType.convertTextToKeyword(fieldName, expression.type()));
67+
} else if (expression instanceof FunctionExpression
68+
&& ((FunctionExpression) expression)
69+
.getFunctionName()
70+
.equals(BuiltinFunctionName.NESTED.getName())) {
71+
List<Expression> args = ((FunctionExpression) expression).getArguments();
72+
// NestedAnalyzer has validated the number of arguments.
73+
// Here we can safety invoke args.getFirst().
74+
String fieldName = ((ReferenceExpression) args.getFirst()).getAttr();
75+
if (fieldName.contains("*")) {
76+
throw new IllegalArgumentException("Nested aggregation doesn't support multiple fields");
77+
}
78+
String path =
79+
args.size() == 2
80+
? ((ReferenceExpression) args.get(1)).getAttr()
81+
: fieldName.substring(0, fieldName.lastIndexOf("."));
82+
AggregationBuilder subAgg =
83+
fieldBuilder.apply(OpenSearchTextType.convertTextToKeyword(fieldName, expression.type()));
84+
return AggregationBuilders.nested(path + "_nested", path).subAggregation(subAgg);
4085
} else if (expression instanceof FunctionExpression
4186
|| expression instanceof LiteralExpression) {
4287
return scriptBuilder.apply(

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ private CompositeValuesSourceBuilder<?> buildCompositeValuesSourceBuilder(
6868
if (List.of(TIMESTAMP, TIME, DATE).contains(expr.getDelegated().type())) {
6969
sourceBuilder.userValuetypeHint(ValueType.LONG);
7070
}
71-
return helper.build(expr.getDelegated(), sourceBuilder::field, sourceBuilder::script);
71+
return helper.buildComposite(
72+
expr.getDelegated(), sourceBuilder::field, sourceBuilder::script);
7273
}
7374
}
7475

0 commit comments

Comments
 (0)