Skip to content

Commit db601d7

Browse files
committed
ESQL: Check a few toStrings for aggs
Adds `toString` checking for aggregators to the generic aggs test cases so we can make sure they spit out sensible looking results. We have this for scalar functions but it isn't plugged in for aggs and I noticed it while working on elastic#132603 where I stuck `asdf` for the toString thinking I'd fix it when the test failed. It didn't. There's to many changes to grab this in one go so I've made a hook that tests can opt into. We'll drop the hook once everything has opted into it.
1 parent 2f68ab1 commit db601d7

File tree

4 files changed

+86
-28
lines changed

4 files changed

+86
-28
lines changed

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@
4242
import java.util.stream.IntStream;
4343

4444
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
45+
import static org.hamcrest.Matchers.endsWith;
4546
import static org.hamcrest.Matchers.equalTo;
4647
import static org.hamcrest.Matchers.instanceOf;
4748
import static org.hamcrest.Matchers.is;
4849
import static org.hamcrest.Matchers.lessThan;
4950
import static org.hamcrest.Matchers.not;
5051
import static org.hamcrest.Matchers.nullValue;
5152
import static org.hamcrest.Matchers.oneOf;
53+
import static org.hamcrest.Matchers.startsWith;
5254

5355
/**
5456
* Base class for aggregation tests.
@@ -164,6 +166,7 @@ public void testFold() {
164166
private void aggregateSingleMode(Expression expression) {
165167
Object result;
166168
try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {
169+
assertAggregatorToString(aggregator);
167170
for (Page inputPage : rows(testCase.getMultiRowFields())) {
168171
try (
169172
BooleanVector noMasking = driverContext().blockFactory().newConstantBooleanVector(true, inputPage.getPositionCount())
@@ -187,6 +190,7 @@ private void aggregateGroupingSingleMode(Expression expression) {
187190
assumeFalse("Grouping aggregations must receive data to check results", pages.isEmpty());
188191

189192
try (var aggregator = groupingAggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {
193+
assertAggregatorToString(aggregator);
190194
var groupCount = randomIntBetween(1, 1000);
191195
for (Page inputPage : pages) {
192196
processPageGrouping(aggregator, inputPage, groupCount);
@@ -482,4 +486,43 @@ private void processPageGrouping(GroupingAggregator aggregator, Page inputPage,
482486
}
483487
}
484488
}
489+
490+
private void assertAggregatorToString(Object aggregator) {
491+
if (optIntoToAggregatorToStringChecks() == false) {
492+
return;
493+
}
494+
String expectedStart = switch (aggregator) {
495+
case Aggregator a -> "Aggregator[aggregatorFunction=";
496+
case GroupingAggregator a -> "GroupingAggregator[aggregatorFunction=";
497+
default -> throw new UnsupportedOperationException("can't check toString for [" + aggregator.getClass() + "]");
498+
};
499+
String expectedEnd = switch (aggregator) {
500+
case Aggregator a -> "AggregatorFunction[channels=[0]], mode=SINGLE]";
501+
case GroupingAggregator a -> "GroupingAggregatorFunction[channels=[0]], mode=SINGLE]";
502+
default -> throw new UnsupportedOperationException("can't check toString for [" + aggregator.getClass() + "]");
503+
};
504+
505+
String toString = aggregator.toString();
506+
assertThat(toString, startsWith(expectedStart));
507+
assertThat(toString.substring(expectedStart.length(), toString.length() - expectedEnd.length()), testCase.evaluatorToString());
508+
assertThat(toString, endsWith(expectedEnd));
509+
}
510+
511+
protected boolean optIntoToAggregatorToStringChecks() {
512+
// TODO remove this when everyone has opted in
513+
return false;
514+
}
515+
516+
protected static String standardAggregatorName(String prefix, DataType type) {
517+
String typeName = switch (type) {
518+
case BOOLEAN -> "Boolean";
519+
case KEYWORD, TEXT, VERSION -> "BytesRef";
520+
case DOUBLE -> "Double";
521+
case INTEGER -> "Int";
522+
case IP -> "Ip";
523+
case DATETIME, DATE_NANOS, LONG, UNSIGNED_LONG -> "Long";
524+
default -> throw new UnsupportedOperationException("name for [" + type + "]");
525+
};
526+
return prefix + typeName;
527+
}
485528
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public static Iterable<Object[]> parameters() {
6060
List.of(DataType.INTEGER),
6161
() -> new TestCaseSupplier.TestCase(
6262
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")),
63-
"Max[field=Attribute[channel=0]]",
63+
standardAggregatorName("Max", DataType.INTEGER),
6464
DataType.INTEGER,
6565
equalTo(200)
6666
)
@@ -69,7 +69,7 @@ public static Iterable<Object[]> parameters() {
6969
List.of(DataType.LONG),
7070
() -> new TestCaseSupplier.TestCase(
7171
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field")),
72-
"Max[field=Attribute[channel=0]]",
72+
standardAggregatorName("Max", DataType.LONG),
7373
DataType.LONG,
7474
equalTo(200L)
7575
)
@@ -78,7 +78,7 @@ public static Iterable<Object[]> parameters() {
7878
List.of(DataType.UNSIGNED_LONG),
7979
() -> new TestCaseSupplier.TestCase(
8080
List.of(TestCaseSupplier.TypedData.multiRow(List.of(new BigInteger("200")), DataType.UNSIGNED_LONG, "field")),
81-
"Max[field=Attribute[channel=0]]",
81+
standardAggregatorName("Max", DataType.UNSIGNED_LONG),
8282
DataType.UNSIGNED_LONG,
8383
equalTo(new BigInteger("200"))
8484
)
@@ -87,7 +87,7 @@ public static Iterable<Object[]> parameters() {
8787
List.of(DataType.DOUBLE),
8888
() -> new TestCaseSupplier.TestCase(
8989
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field")),
90-
"Max[field=Attribute[channel=0]]",
90+
standardAggregatorName("Max", DataType.DOUBLE),
9191
DataType.DOUBLE,
9292
equalTo(200.)
9393
)
@@ -96,7 +96,7 @@ public static Iterable<Object[]> parameters() {
9696
List.of(DataType.DATETIME),
9797
() -> new TestCaseSupplier.TestCase(
9898
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field")),
99-
"Max[field=Attribute[channel=0]]",
99+
standardAggregatorName("Max", DataType.DATETIME),
100100
DataType.DATETIME,
101101
equalTo(200L)
102102
)
@@ -105,7 +105,7 @@ public static Iterable<Object[]> parameters() {
105105
List.of(DataType.DATE_NANOS),
106106
() -> new TestCaseSupplier.TestCase(
107107
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATE_NANOS, "field")),
108-
"Max[field=Attribute[channel=0]]",
108+
standardAggregatorName("Max", DataType.DATE_NANOS),
109109
DataType.DATE_NANOS,
110110
equalTo(200L)
111111
)
@@ -114,7 +114,7 @@ public static Iterable<Object[]> parameters() {
114114
List.of(DataType.BOOLEAN),
115115
() -> new TestCaseSupplier.TestCase(
116116
List.of(TestCaseSupplier.TypedData.multiRow(List.of(true), DataType.BOOLEAN, "field")),
117-
"Max[field=Attribute[channel=0]]",
117+
standardAggregatorName("Max", DataType.BOOLEAN),
118118
DataType.BOOLEAN,
119119
equalTo(true)
120120
)
@@ -129,7 +129,7 @@ public static Iterable<Object[]> parameters() {
129129
"field"
130130
)
131131
),
132-
"Max[field=Attribute[channel=0]]",
132+
standardAggregatorName("Max", DataType.IP),
133133
DataType.IP,
134134
equalTo(new BytesRef(InetAddressPoint.encode(InetAddresses.forString("127.0.0.1"))))
135135
)
@@ -138,7 +138,7 @@ public static Iterable<Object[]> parameters() {
138138
var value = new BytesRef(randomAlphaOfLengthBetween(0, 50));
139139
return new TestCaseSupplier.TestCase(
140140
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.KEYWORD, "field")),
141-
"Max[field=Attribute[channel=0]]",
141+
standardAggregatorName("Max", DataType.KEYWORD),
142142
DataType.KEYWORD,
143143
equalTo(value)
144144
);
@@ -147,7 +147,7 @@ public static Iterable<Object[]> parameters() {
147147
var value = new BytesRef(randomAlphaOfLengthBetween(0, 50));
148148
return new TestCaseSupplier.TestCase(
149149
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.TEXT, "field")),
150-
"Max[field=Attribute[channel=0]]",
150+
standardAggregatorName("Max", DataType.TEXT),
151151
DataType.KEYWORD,
152152
equalTo(value)
153153
);
@@ -159,7 +159,7 @@ public static Iterable<Object[]> parameters() {
159159
.toBytesRef();
160160
return new TestCaseSupplier.TestCase(
161161
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.VERSION, "field")),
162-
"Max[field=Attribute[channel=0]]",
162+
standardAggregatorName("Max", DataType.VERSION),
163163
DataType.VERSION,
164164
equalTo(value)
165165
);
@@ -187,10 +187,15 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
187187

188188
return new TestCaseSupplier.TestCase(
189189
List.of(fieldTypedData),
190-
"Max[field=Attribute[channel=0]]",
190+
standardAggregatorName("Max", fieldSupplier.type()),
191191
fieldSupplier.type(),
192192
equalTo(expected)
193193
);
194194
});
195195
}
196+
197+
@Override
198+
protected boolean optIntoToAggregatorToStringChecks() {
199+
return true;
200+
}
196201
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianTests.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public static Iterable<Object[]> parameters() {
4747
List.of(DataType.INTEGER),
4848
() -> new TestCaseSupplier.TestCase(
4949
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "number")),
50-
"Median[field=Attribute[channel=0]]",
50+
standardAggregatorName("Percentile", DataType.INTEGER),
5151
DataType.DOUBLE,
5252
equalTo(200.)
5353
)
@@ -56,7 +56,7 @@ public static Iterable<Object[]> parameters() {
5656
List.of(DataType.LONG),
5757
() -> new TestCaseSupplier.TestCase(
5858
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "number")),
59-
"Median[field=Attribute[channel=0]]",
59+
standardAggregatorName("Percentile", DataType.LONG),
6060
DataType.DOUBLE,
6161
equalTo(200.)
6262
)
@@ -65,7 +65,7 @@ public static Iterable<Object[]> parameters() {
6565
List.of(DataType.DOUBLE),
6666
() -> new TestCaseSupplier.TestCase(
6767
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "number")),
68-
"Median[field=Attribute[channel=0]]",
68+
standardAggregatorName("Percentile", DataType.DOUBLE),
6969
DataType.DOUBLE,
7070
equalTo(200.)
7171
)
@@ -94,11 +94,16 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
9494

9595
return new TestCaseSupplier.TestCase(
9696
List.of(fieldTypedData),
97-
"Median[number=Attribute[channel=0]]",
97+
standardAggregatorName("Percentile", fieldSupplier.type()),
9898
DataType.DOUBLE,
9999
equalTo(expected)
100100
);
101101
}
102102
});
103103
}
104+
105+
@Override
106+
protected boolean optIntoToAggregatorToStringChecks() {
107+
return true;
108+
}
104109
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public static Iterable<Object[]> parameters() {
6060
List.of(DataType.INTEGER),
6161
() -> new TestCaseSupplier.TestCase(
6262
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")),
63-
"Min[field=Attribute[channel=0]]",
63+
standardAggregatorName("Min", DataType.INTEGER),
6464
DataType.INTEGER,
6565
equalTo(200)
6666
)
@@ -69,7 +69,7 @@ public static Iterable<Object[]> parameters() {
6969
List.of(DataType.LONG),
7070
() -> new TestCaseSupplier.TestCase(
7171
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field")),
72-
"Min[field=Attribute[channel=0]]",
72+
standardAggregatorName("Min", DataType.LONG),
7373
DataType.LONG,
7474
equalTo(200L)
7575
)
@@ -78,7 +78,7 @@ public static Iterable<Object[]> parameters() {
7878
List.of(DataType.UNSIGNED_LONG),
7979
() -> new TestCaseSupplier.TestCase(
8080
List.of(TestCaseSupplier.TypedData.multiRow(List.of(new BigInteger("200")), DataType.UNSIGNED_LONG, "field")),
81-
"Max[field=Attribute[channel=0]]",
81+
standardAggregatorName("Min", DataType.UNSIGNED_LONG),
8282
DataType.UNSIGNED_LONG,
8383
equalTo(new BigInteger("200"))
8484
)
@@ -87,7 +87,7 @@ public static Iterable<Object[]> parameters() {
8787
List.of(DataType.DOUBLE),
8888
() -> new TestCaseSupplier.TestCase(
8989
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field")),
90-
"Min[field=Attribute[channel=0]]",
90+
standardAggregatorName("Min", DataType.DOUBLE),
9191
DataType.DOUBLE,
9292
equalTo(200.)
9393
)
@@ -96,7 +96,7 @@ public static Iterable<Object[]> parameters() {
9696
List.of(DataType.DATETIME),
9797
() -> new TestCaseSupplier.TestCase(
9898
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field")),
99-
"Min[field=Attribute[channel=0]]",
99+
standardAggregatorName("Min", DataType.DATETIME),
100100
DataType.DATETIME,
101101
equalTo(200L)
102102
)
@@ -105,7 +105,7 @@ public static Iterable<Object[]> parameters() {
105105
List.of(DataType.DATE_NANOS),
106106
() -> new TestCaseSupplier.TestCase(
107107
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATE_NANOS, "field")),
108-
"Min[field=Attribute[channel=0]]",
108+
standardAggregatorName("Min", DataType.DATE_NANOS),
109109
DataType.DATE_NANOS,
110110
equalTo(200L)
111111
)
@@ -114,7 +114,7 @@ public static Iterable<Object[]> parameters() {
114114
List.of(DataType.BOOLEAN),
115115
() -> new TestCaseSupplier.TestCase(
116116
List.of(TestCaseSupplier.TypedData.multiRow(List.of(true), DataType.BOOLEAN, "field")),
117-
"Min[field=Attribute[channel=0]]",
117+
standardAggregatorName("Min", DataType.BOOLEAN),
118118
DataType.BOOLEAN,
119119
equalTo(true)
120120
)
@@ -129,7 +129,7 @@ public static Iterable<Object[]> parameters() {
129129
"field"
130130
)
131131
),
132-
"Min[field=Attribute[channel=0]]",
132+
standardAggregatorName("Min", DataType.IP),
133133
DataType.IP,
134134
equalTo(new BytesRef(InetAddressPoint.encode(InetAddresses.forString("127.0.0.1"))))
135135
)
@@ -138,7 +138,7 @@ public static Iterable<Object[]> parameters() {
138138
var value = new BytesRef(randomAlphaOfLengthBetween(0, 50));
139139
return new TestCaseSupplier.TestCase(
140140
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.KEYWORD, "field")),
141-
"Min[field=Attribute[channel=0]]",
141+
standardAggregatorName("Min", DataType.KEYWORD),
142142
DataType.KEYWORD,
143143
equalTo(value)
144144
);
@@ -147,7 +147,7 @@ public static Iterable<Object[]> parameters() {
147147
var value = new BytesRef(randomAlphaOfLengthBetween(0, 50));
148148
return new TestCaseSupplier.TestCase(
149149
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.TEXT, "field")),
150-
"Min[field=Attribute[channel=0]]",
150+
standardAggregatorName("Min", DataType.TEXT),
151151
DataType.KEYWORD,
152152
equalTo(value)
153153
);
@@ -159,7 +159,7 @@ public static Iterable<Object[]> parameters() {
159159
.toBytesRef();
160160
return new TestCaseSupplier.TestCase(
161161
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.VERSION, "field")),
162-
"Min[field=Attribute[channel=0]]",
162+
standardAggregatorName("Min", DataType.VERSION),
163163
DataType.VERSION,
164164
equalTo(value)
165165
);
@@ -187,10 +187,15 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
187187

188188
return new TestCaseSupplier.TestCase(
189189
List.of(fieldTypedData),
190-
"Min[field=Attribute[channel=0]]",
190+
standardAggregatorName("Min", fieldSupplier.type()),
191191
fieldSupplier.type(),
192192
equalTo(expected)
193193
);
194194
});
195195
}
196+
197+
@Override
198+
protected boolean optIntoToAggregatorToStringChecks() {
199+
return true;
200+
}
196201
}

0 commit comments

Comments
 (0)