Skip to content

Commit 1329e02

Browse files
authored
Add UT for PredicateAnalyzer and AggregateAnalyzer (#3612)
* Add UT for PredicateAnalyzer and AggregateAnalyzer Signed-off-by: Heng Qian <qianheng@amazon.com> * Add copyright Signed-off-by: Heng Qian <qianheng@amazon.com> --------- Signed-off-by: Heng Qian <qianheng@amazon.com>
1 parent eadeca2 commit 1329e02

6 files changed

Lines changed: 725 additions & 0 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ private static Pair<Builder, List<MetricParser>> processAggregateCalls(
147147
List<AggregateCall> aggCalls,
148148
FieldExpressionCreator fieldExpressionCreator,
149149
List<String> outputFields) {
150+
assert aggCalls.size() + groupOffset == outputFields.size()
151+
: "groups size and agg calls size should match with output fields";
150152
Builder metricBuilder = new AggregatorFactories.Builder();
151153
List<MetricParser> metricParserList = new ArrayList<>();
152154

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
import java.util.Map;
2020
import java.util.stream.Collectors;
2121
import lombok.EqualsAndHashCode;
22+
import lombok.Getter;
2223
import org.opensearch.search.aggregations.Aggregations;
2324
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation;
2425

2526
/** Composite Aggregation Parser which include composite aggregation and metric parsers. */
27+
@Getter
2628
@EqualsAndHashCode
2729
public class CompositeAggregationParser implements OpenSearchAggregationResponseParser {
2830

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
import java.util.Map;
1919
import java.util.stream.Collectors;
2020
import lombok.EqualsAndHashCode;
21+
import lombok.Getter;
2122
import lombok.RequiredArgsConstructor;
2223
import org.opensearch.search.aggregations.Aggregation;
2324
import org.opensearch.search.aggregations.Aggregations;
2425
import org.opensearch.sql.common.utils.StringUtils;
2526

2627
/** Parse multiple metrics in one bucket. */
28+
@Getter
2729
@EqualsAndHashCode
2830
@RequiredArgsConstructor
2931
public class MetricParserHelper {

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/NoBucketAggregationParser.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
import java.util.Collections;
1818
import java.util.List;
1919
import java.util.Map;
20+
import lombok.Getter;
2021
import org.opensearch.search.aggregations.Aggregations;
2122

2223
/** No Bucket Aggregation Parser which include only metric parsers. */
24+
@Getter
2325
public class NoBucketAggregationParser implements OpenSearchAggregationResponseParser {
2426

2527
private final MetricParserHelper metricsParser;
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.opensearch.request;
7+
8+
import static org.junit.jupiter.api.Assertions.assertEquals;
9+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
10+
import static org.junit.jupiter.api.Assertions.assertThrows;
11+
import static org.junit.jupiter.api.Assertions.assertTrue;
12+
import static org.mockito.Mockito.mock;
13+
import static org.mockito.Mockito.when;
14+
15+
import com.google.common.collect.ImmutableList;
16+
import java.util.List;
17+
import java.util.Map;
18+
import org.apache.calcite.rel.RelCollations;
19+
import org.apache.calcite.rel.core.Aggregate;
20+
import org.apache.calcite.rel.core.AggregateCall;
21+
import org.apache.calcite.rel.type.RelDataTypeFactory;
22+
import org.apache.calcite.rel.type.RelDataTypeSystem;
23+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
24+
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
25+
import org.apache.calcite.sql.type.SqlTypeName;
26+
import org.apache.calcite.util.ImmutableBitSet;
27+
import org.apache.commons.lang3.tuple.Pair;
28+
import org.junit.jupiter.api.Test;
29+
import org.opensearch.search.aggregations.AggregationBuilder;
30+
import org.opensearch.sql.opensearch.data.type.OpenSearchDataType;
31+
import org.opensearch.sql.opensearch.data.type.OpenSearchDataType.MappingType;
32+
import org.opensearch.sql.opensearch.request.AggregateAnalyzer.ExpressionNotAnalyzableException;
33+
import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser;
34+
import org.opensearch.sql.opensearch.response.agg.MetricParserHelper;
35+
import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser;
36+
import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser;
37+
import org.opensearch.sql.opensearch.response.agg.SingleValueParser;
38+
import org.opensearch.sql.opensearch.response.agg.StatsParser;
39+
40+
class AggregateAnalyzerTest {
41+
42+
private final RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
43+
private final List<String> schema = List.of("a", "b", "c");
44+
final Map<String, OpenSearchDataType> typeMapping =
45+
Map.of(
46+
"a",
47+
OpenSearchDataType.of(MappingType.Integer),
48+
"b",
49+
OpenSearchDataType.of(
50+
MappingType.Text, Map.of("fields", Map.of("keyword", Map.of("type", "keyword")))),
51+
"c",
52+
OpenSearchDataType.of(MappingType.Text)); // Text without keyword cannot be push down
53+
54+
@Test
55+
void analyze_aggCall_simple() throws ExpressionNotAnalyzableException {
56+
AggregateCall countCall =
57+
AggregateCall.create(
58+
SqlStdOperatorTable.COUNT,
59+
false,
60+
false,
61+
false,
62+
ImmutableList.of(),
63+
ImmutableList.of(),
64+
-1,
65+
null,
66+
RelCollations.EMPTY,
67+
typeFactory.createSqlType(SqlTypeName.INTEGER),
68+
"cnt");
69+
AggregateCall avgCall =
70+
AggregateCall.create(
71+
SqlStdOperatorTable.AVG,
72+
false,
73+
false,
74+
false,
75+
ImmutableList.of(),
76+
ImmutableList.of(0),
77+
-1,
78+
null,
79+
RelCollations.EMPTY,
80+
typeFactory.createSqlType(SqlTypeName.INTEGER),
81+
"avg");
82+
AggregateCall sumCall =
83+
AggregateCall.create(
84+
SqlStdOperatorTable.SUM,
85+
false,
86+
false,
87+
false,
88+
ImmutableList.of(),
89+
ImmutableList.of(0),
90+
-1,
91+
null,
92+
RelCollations.EMPTY,
93+
typeFactory.createSqlType(SqlTypeName.INTEGER),
94+
"sum");
95+
AggregateCall minCall =
96+
AggregateCall.create(
97+
SqlStdOperatorTable.MIN,
98+
false,
99+
false,
100+
false,
101+
ImmutableList.of(),
102+
ImmutableList.of(0),
103+
-1,
104+
null,
105+
RelCollations.EMPTY,
106+
typeFactory.createSqlType(SqlTypeName.INTEGER),
107+
"min");
108+
AggregateCall maxCall =
109+
AggregateCall.create(
110+
SqlStdOperatorTable.MAX,
111+
false,
112+
false,
113+
false,
114+
ImmutableList.of(),
115+
ImmutableList.of(0),
116+
-1,
117+
null,
118+
RelCollations.EMPTY,
119+
typeFactory.createSqlType(SqlTypeName.INTEGER),
120+
"max");
121+
122+
List<String> outputFields = List.of("cnt", "avg", "sum", "min", "max");
123+
Aggregate aggregate =
124+
createMockAggregate(
125+
List.of(countCall, avgCall, sumCall, minCall, maxCall), ImmutableBitSet.of());
126+
Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser> result =
127+
AggregateAnalyzer.analyze(aggregate, schema, typeMapping, outputFields);
128+
assertEquals(
129+
"[{\"cnt\":{\"value_count\":{\"field\":\"_index\"}}},"
130+
+ " {\"avg\":{\"avg\":{\"field\":\"a\"}}},"
131+
+ " {\"sum\":{\"sum\":{\"field\":\"a\"}}},"
132+
+ " {\"min\":{\"min\":{\"field\":\"a\"}}},"
133+
+ " {\"max\":{\"max\":{\"field\":\"a\"}}}]",
134+
result.getLeft().toString());
135+
assertInstanceOf(NoBucketAggregationParser.class, result.getRight());
136+
MetricParserHelper metricsParser =
137+
((NoBucketAggregationParser) result.getRight()).getMetricsParser();
138+
assertEquals(5, metricsParser.getMetricParserMap().size());
139+
metricsParser
140+
.getMetricParserMap()
141+
.forEach(
142+
(k, v) -> {
143+
assertTrue(outputFields.contains(k));
144+
assertInstanceOf(SingleValueParser.class, v);
145+
});
146+
}
147+
148+
@Test
149+
void analyze_aggCall_extended() throws ExpressionNotAnalyzableException {
150+
AggregateCall varSampCall =
151+
AggregateCall.create(
152+
SqlStdOperatorTable.VAR_SAMP,
153+
false,
154+
false,
155+
false,
156+
ImmutableList.of(),
157+
ImmutableList.of(0),
158+
-1,
159+
null,
160+
RelCollations.EMPTY,
161+
typeFactory.createSqlType(SqlTypeName.INTEGER),
162+
"var_samp");
163+
AggregateCall varPopCall =
164+
AggregateCall.create(
165+
SqlStdOperatorTable.VAR_POP,
166+
false,
167+
false,
168+
false,
169+
ImmutableList.of(),
170+
ImmutableList.of(0),
171+
-1,
172+
null,
173+
RelCollations.EMPTY,
174+
typeFactory.createSqlType(SqlTypeName.INTEGER),
175+
"var_pop");
176+
AggregateCall stddevSampCall =
177+
AggregateCall.create(
178+
SqlStdOperatorTable.STDDEV_SAMP,
179+
false,
180+
false,
181+
false,
182+
ImmutableList.of(),
183+
ImmutableList.of(0),
184+
-1,
185+
null,
186+
RelCollations.EMPTY,
187+
typeFactory.createSqlType(SqlTypeName.INTEGER),
188+
"stddev_samp");
189+
AggregateCall stddevPopCall =
190+
AggregateCall.create(
191+
SqlStdOperatorTable.STDDEV_SAMP,
192+
false,
193+
false,
194+
false,
195+
ImmutableList.of(),
196+
ImmutableList.of(0),
197+
-1,
198+
null,
199+
RelCollations.EMPTY,
200+
typeFactory.createSqlType(SqlTypeName.INTEGER),
201+
"stddev_pop");
202+
List<String> outputFields = List.of("var_samp", "var_pop", "stddev_samp", "stddev_pop");
203+
Aggregate aggregate =
204+
createMockAggregate(
205+
List.of(varSampCall, varPopCall, stddevSampCall, stddevPopCall), ImmutableBitSet.of());
206+
Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser> result =
207+
AggregateAnalyzer.analyze(aggregate, schema, typeMapping, outputFields);
208+
assertEquals(
209+
"[{\"var_samp\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}},"
210+
+ " {\"var_pop\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}},"
211+
+ " {\"stddev_samp\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}},"
212+
+ " {\"stddev_pop\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}}]",
213+
result.getLeft().toString());
214+
assertInstanceOf(NoBucketAggregationParser.class, result.getRight());
215+
MetricParserHelper metricsParser =
216+
((NoBucketAggregationParser) result.getRight()).getMetricsParser();
217+
assertEquals(4, metricsParser.getMetricParserMap().size());
218+
metricsParser
219+
.getMetricParserMap()
220+
.forEach(
221+
(k, v) -> {
222+
assertTrue(outputFields.contains(k));
223+
assertInstanceOf(StatsParser.class, v);
224+
});
225+
}
226+
227+
@Test
228+
void analyze_groupBy() throws ExpressionNotAnalyzableException {
229+
AggregateCall aggCall =
230+
AggregateCall.create(
231+
SqlStdOperatorTable.COUNT,
232+
false,
233+
false,
234+
false,
235+
ImmutableList.of(),
236+
ImmutableList.of(),
237+
-1,
238+
null,
239+
RelCollations.EMPTY,
240+
typeFactory.createSqlType(SqlTypeName.INTEGER),
241+
"cnt");
242+
List<String> outputFields = List.of("a", "b", "cnt");
243+
Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0, 1));
244+
Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser> result =
245+
AggregateAnalyzer.analyze(aggregate, schema, typeMapping, outputFields);
246+
247+
assertEquals(
248+
"[{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":["
249+
+ "{\"a\":{\"terms\":{\"field\":\"a\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},"
250+
+ "{\"b\":{\"terms\":{\"field\":\"b.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},"
251+
+ "\"aggregations\":{\"cnt\":{\"value_count\":{\"field\":\"_index\"}}}}}]",
252+
result.getLeft().toString());
253+
assertInstanceOf(CompositeAggregationParser.class, result.getRight());
254+
MetricParserHelper metricsParser =
255+
((CompositeAggregationParser) result.getRight()).getMetricsParser();
256+
assertEquals(1, metricsParser.getMetricParserMap().size());
257+
metricsParser
258+
.getMetricParserMap()
259+
.forEach(
260+
(k, v) -> {
261+
assertTrue(outputFields.contains(k));
262+
assertInstanceOf(SingleValueParser.class, v);
263+
});
264+
}
265+
266+
@Test
267+
void analyze_aggCall_TextWithoutKeyword() {
268+
AggregateCall aggCall =
269+
AggregateCall.create(
270+
SqlStdOperatorTable.SUM,
271+
false,
272+
false,
273+
false,
274+
ImmutableList.of(),
275+
ImmutableList.of(2),
276+
-1,
277+
null,
278+
RelCollations.EMPTY,
279+
typeFactory.createSqlType(SqlTypeName.INTEGER),
280+
"sum");
281+
Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of());
282+
ExpressionNotAnalyzableException exception =
283+
assertThrows(
284+
ExpressionNotAnalyzableException.class,
285+
() -> AggregateAnalyzer.analyze(aggregate, schema, typeMapping, List.of("sum")));
286+
assertEquals("[field] must not be null: [sum]", exception.getCause().getMessage());
287+
}
288+
289+
@Test
290+
void analyze_groupBy_TextWithoutKeyword() {
291+
AggregateCall aggCall =
292+
AggregateCall.create(
293+
SqlStdOperatorTable.COUNT,
294+
false,
295+
false,
296+
false,
297+
ImmutableList.of(),
298+
ImmutableList.of(),
299+
-1,
300+
null,
301+
RelCollations.EMPTY,
302+
typeFactory.createSqlType(SqlTypeName.INTEGER),
303+
"cnt");
304+
List<String> outputFields = List.of("c", "cnt");
305+
Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(2));
306+
ExpressionNotAnalyzableException exception =
307+
assertThrows(
308+
ExpressionNotAnalyzableException.class,
309+
() -> AggregateAnalyzer.analyze(aggregate, schema, typeMapping, outputFields));
310+
assertEquals("[field] must not be null", exception.getCause().getMessage());
311+
}
312+
313+
private Aggregate createMockAggregate(List<AggregateCall> calls, ImmutableBitSet groups) {
314+
Aggregate agg = mock(Aggregate.class);
315+
when(agg.getGroupSet()).thenReturn(groups);
316+
when(agg.getAggCallList()).thenReturn(calls);
317+
return agg;
318+
}
319+
}

0 commit comments

Comments
 (0)