Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ public First withFilter(Expression filter) {
return new First(source(), field(), filter, sort);
}

public Expression sort() {
return sort;
}

@Override
public DataType dataType() {
return field().dataType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ public Last withFilter(Expression filter) {
return new Last(source(), field(), filter, sort);
}

public Expression sort() {
return sort;
}

@Override
public DataType dataType() {
return field().dataType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,41 +506,46 @@ private void processPageGrouping(GroupingAggregator aggregator, Page inputPage,
}

private void assertAggregatorToString(Object aggregator) {
if (optIntoToAggregatorToStringChecks() == false) {
return;
}
String expectedStart = switch (aggregator) {
case Aggregator a -> "Aggregator[aggregatorFunction=";
case GroupingAggregator a -> "GroupingAggregator[aggregatorFunction=";
default -> throw new UnsupportedOperationException("can't check toString for [" + aggregator.getClass() + "]");
};
String channels = initialInputChannels().stream().map(Object::toString).collect(Collectors.joining(", "));
String expectedEnd = switch (aggregator) {
case Aggregator a -> "AggregatorFunction[channels=[0]], mode=SINGLE]";
case GroupingAggregator a -> "GroupingAggregatorFunction[channels=[0]], mode=SINGLE]";
case Aggregator a -> "AggregatorFunction[channels=[" + channels + "]], mode=SINGLE]";
case GroupingAggregator a -> "GroupingAggregatorFunction[channels=[" + channels + "]], mode=SINGLE]";
default -> throw new UnsupportedOperationException("can't check toString for [" + aggregator.getClass() + "]");
};

String toString = aggregator.toString();
assertThat(toString, startsWith(expectedStart));
assertThat(toString.substring(expectedStart.length(), toString.length() - expectedEnd.length()), testCase.evaluatorToString());
assertThat(toString, endsWith(expectedEnd));
}

protected boolean optIntoToAggregatorToStringChecks() {
// TODO remove this when everyone has opted in
return false;
assertThat(toString.substring(expectedStart.length(), toString.length() - expectedEnd.length()), testCase.evaluatorToString());
}

protected static String standardAggregatorName(String prefix, DataType type) {
String typeName = switch (type) {
case BOOLEAN -> "Boolean";
case CARTESIAN_POINT -> "CartesianPoint";
case CARTESIAN_SHAPE -> "CartesianShape";
case GEO_POINT -> "GeoPoint";
case GEO_SHAPE -> "GeoShape";
case KEYWORD, TEXT, VERSION -> "BytesRef";
case DOUBLE -> "Double";
case INTEGER -> "Int";
case DOUBLE, COUNTER_DOUBLE -> "Double";
case INTEGER, COUNTER_INTEGER -> "Int";
case IP -> "Ip";
case DATETIME, DATE_NANOS, LONG, UNSIGNED_LONG -> "Long";
case DATETIME, DATE_NANOS, LONG, COUNTER_LONG, UNSIGNED_LONG -> "Long";
case NULL -> "Null";
default -> throw new UnsupportedOperationException("name for [" + type + "]");
};
return prefix + typeName;
}

protected static String standardAggregatorNameAllBytesTheSame(String prefix, DataType type) {
return standardAggregatorName(prefix, switch (type) {
case CARTESIAN_POINT, CARTESIAN_SHAPE, GEO_POINT, GEO_SHAPE, IP -> DataType.KEYWORD;
default -> type;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ private static TestCaseSupplier makeSupplier(

return new TestCaseSupplier.TestCase(
List.of(fieldTypedData, precisionTypedData),
"CountDistinct[field=Attribute[channel=0],precision=Attribute[channel=1]]",
standardAggregatorNameAllBytesTheSame("CountDistinct", fieldTypedData.type()),
DataType.LONG,
equalTo(result)
);
Expand All @@ -149,7 +149,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier

return new TestCaseSupplier.TestCase(
List.of(fieldTypedData),
"CountDistinct[field=Attribute[channel=0]]",
standardAggregatorNameAllBytesTheSame("CountDistinct", fieldTypedData.type()),
DataType.LONG,
equalTo(result)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public static Iterable<Object[]> parameters() {
List.of(dataType),
() -> new TestCaseSupplier.TestCase(
List.of(TestCaseSupplier.TypedData.multiRow(List.of(), dataType, "field")),
"Count[field=Attribute[channel=0]]",
"Count",
DataType.LONG,
equalTo(0L)
)
Expand All @@ -100,12 +100,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
var fieldTypedData = fieldSupplier.get();
var rowCount = fieldTypedData.multiRowData().stream().filter(Objects::nonNull).count();

return new TestCaseSupplier.TestCase(
List.of(fieldTypedData),
"Count[field=Attribute[channel=0]]",
DataType.LONG,
equalTo(rowCount)
);
return new TestCaseSupplier.TestCase(List.of(fieldTypedData), "Count", DataType.LONG, equalTo(rowCount));
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
}
return new TestCaseSupplier.TestCase(
List.of(fieldTypedData, timestampsField),
"FirstOverTime[field=Attribute[channel=0],timestamp=Attribute[channel=1]]",
standardAggregatorName("First", fieldSupplier.type()) + "ByTimestamp",
type,
equalTo(expected)
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;

import java.io.IOException;

public class FirstSerializationTests extends AbstractExpressionSerializationTests<First> {
@Override
protected First createTestInstance() {
return new First(randomSource(), randomChild(), randomChild());
}

@Override
protected First mutateInstance(First instance) throws IOException {
Expression field = instance.field();
Expression sort = instance.sort();
if (randomBoolean()) {
field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild);
} else {
sort = randomValueOtherThan(sort, AbstractExpressionSerializationTests::randomChild);
}
return new First(instance.source(), field, sort);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ static TestCaseSupplier makeSupplier(
}
return new TestCaseSupplier.TestCase(
List.of(values, sorts),
"unused",
standardAggregatorName(first ? "First" : "Last", values.type()) + "ByTimestamp",
values.type(),
anyOf(() -> Iterators.map(expected.iterator(), Matchers::equalTo))
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
}
return new TestCaseSupplier.TestCase(
List.of(fieldTypedData, timestampsField),
"LastOverTime[field=Attribute[channel=0],timestamp=Attribute[channel=1]]",
standardAggregatorName("Last", fieldSupplier.type()) + "ByTimestamp",
type,
equalTo(expected)
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;

import java.io.IOException;

public class LastSerializationTests extends AbstractExpressionSerializationTests<Last> {
@Override
protected Last createTestInstance() {
return new Last(randomSource(), randomChild(), randomChild());
}

@Override
protected Last mutateInstance(Last instance) throws IOException {
Expression field = instance.field();
Expression sort = instance.sort();
if (randomBoolean()) {
field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild);
} else {
sort = randomValueOtherThan(sort, AbstractExpressionSerializationTests::randomChild);
}
return new Last(instance.source(), field, sort);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,4 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
);
});
}

@Override
protected boolean optIntoToAggregatorToStringChecks() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier

return new TestCaseSupplier.TestCase(
List.of(fieldTypedData),
"MedianAbsoluteDeviation[number=Attribute[channel=0]]",
standardAggregatorName("MedianAbsoluteDeviation", fieldSupplier.type()),
DataType.DOUBLE,
equalTo(expected)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,4 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
}
});
}

@Override
protected boolean optIntoToAggregatorToStringChecks() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,4 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
);
});
}

@Override
protected boolean optIntoToAggregatorToStringChecks() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private static TestCaseSupplier makeSupplier(

return new TestCaseSupplier.TestCase(
List.of(fieldTypedData, percentileTypedData),
"Percentile[number=Attribute[channel=0],percentile=Attribute[channel=1]]",
standardAggregatorName("Percentile", fieldSupplier.type()),
DataType.DOUBLE,
equalTo(expected)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
}
return new TestCaseSupplier.TestCase(
List.of(fieldTypedData, timestampsField),
"Rate[field=Attribute[channel=0],timestamp=Attribute[channel=1]]",
standardAggregatorName("Rate", fieldTypedData.type()),
DataType.DOUBLE,
matcher
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;

import java.io.IOException;

public class SampleSerializationTests extends AbstractExpressionSerializationTests<Sample> {
@Override
protected Sample createTestInstance() {
Source source = randomSource();
Expression field = randomChild();
Expression limit = randomChild();
return new Sample(source, field, limit);
}

@Override
protected Sample mutateInstance(Sample instance) throws IOException {
Source source = randomSource();
Expression field = instance.field();
Expression limit = instance.limitField();
switch (between(0, 1)) {
case 0 -> field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild);
case 1 -> limit = randomValueOtherThan(limit, AbstractExpressionSerializationTests::randomChild);
}
return new Sample(source, field, limit);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private static TestCaseSupplier makeSupplier(

return new TestCaseSupplier.TestCase(
List.of(fieldTypedData, limitTypedData),
"Sample[field=Attribute[channel=0], limit=Attribute[channel=1]]",
standardAggregatorNameAllBytesTheSame("Sample", fieldSupplier.type()),
fieldSupplier.type(),
subsetOfSize(rows, limit)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier

return new TestCaseSupplier.TestCase(
List.of(fieldTypedData),
"SpatialCentroid[field=Attribute[channel=0]]",
standardAggregatorName("SpatialCentroid", fieldSupplier.type()) + "SourceValues",
fieldTypedData.type(),
centroidMatches(expectedX, expectedY, 1e-14)
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;

import java.io.IOException;

public class SpatialExtentSerializationTests extends AbstractExpressionSerializationTests<SpatialExtent> {
@Override
protected SpatialExtent createTestInstance() {
return new SpatialExtent(randomSource(), randomChild());
}

@Override
protected SpatialExtent mutateInstance(SpatialExtent instance) throws IOException {
return new SpatialExtent(
instance.source(),
randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
Rectangle result = pointVisitor.getResult();
return new TestCaseSupplier.TestCase(
List.of(fieldTypedData),
"SpatialExtent[field=Attribute[channel=0]]",
standardAggregatorName("SpatialExtent", fieldSupplier.type()) + "SourceValues",
expectedType,
new WellKnownBinaryBytesRefMatcher<>(RectangleMatcher.closeToFloat(result, 1e-3, pointType.encoder()))
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;

import java.io.IOException;

public class StdDevSerializationTests extends AbstractExpressionSerializationTests<StdDev> {
@Override
protected StdDev createTestInstance() {
return new StdDev(randomSource(), randomChild());
}

@Override
protected StdDev mutateInstance(StdDev instance) throws IOException {
return new StdDev(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
var expected = Double.isFinite(result) ? result : null;
return new TestCaseSupplier.TestCase(
List.of(fieldTypedData),
"StdDev[field=Attribute[channel=0]]",
standardAggregatorName("StdDev", fieldSupplier.type()),
DataType.DOUBLE,
equalTo(expected)
);
Expand Down
Loading