Skip to content

Commit 24da530

Browse files
Fix for top
1 parent 8bcb0c4 commit 24da530

File tree

3 files changed

+144
-30
lines changed
  • x-pack/plugin/esql
    • qa/testFixtures/src/main/resources
    • src
      • main/java/org/elasticsearch/xpack/esql/expression/function/aggregate
      • test/java/org/elasticsearch/xpack/esql/expression/function/aggregate

3 files changed

+144
-30
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ date:date | double:double | integer:integer | long:long
2626
[1985-02-18T00:00:00.000Z,1985-02-24T00:00:00.000Z] | [-9.81,-9.28] | [25324,25945] | [-9,-9]
2727
;
2828

29+
topFoldableExpressions
30+
required_capability: agg_top
31+
FROM employees
32+
| STATS
33+
date = TOP(hire_date, 1+1, "dEsc"),
34+
double = TOP(salary_change, 100-98, REVERSE("csed")),
35+
integer = TOP(salary, 4-(1+1), Substring("Ascending",0,3)),
36+
long = TOP(salary_change.long, 10 - 4*2, Concat("as","c"))
37+
;
38+
39+
date:date | double:double | integer:integer | long:long
40+
[1999-04-30T00:00:00.000Z, 1997-05-19T00:00:00.000Z] | [14.74, 14.68] | [25324,25945] | [-9,-9]
41+
;
42+
2943
topAllTypesDesc
3044
required_capability: agg_top
3145
FROM employees

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java

Lines changed: 128 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
import org.elasticsearch.compute.aggregation.TopIpAggregatorFunctionSupplier;
2121
import org.elasticsearch.compute.aggregation.TopLongAggregatorFunctionSupplier;
2222
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
23+
import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware;
24+
import org.elasticsearch.xpack.esql.common.Failures;
2325
import org.elasticsearch.xpack.esql.core.expression.Expression;
24-
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
2526
import org.elasticsearch.xpack.esql.core.expression.Literal;
2627
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
2728
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -39,14 +40,15 @@
3940

4041
import static java.util.Arrays.asList;
4142
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
43+
import static org.elasticsearch.xpack.esql.common.Failure.fail;
4244
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
4345
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
4446
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
45-
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
47+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
4648
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;
4749
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
4850

49-
public class Top extends AggregateFunction implements ToAggregator, SurrogateExpression {
51+
public class Top extends AggregateFunction implements ToAggregator, SurrogateExpression, PostOptimizationVerificationAware {
5052
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Top", Top::new);
5153

5254
private static final String ORDER_ASC = "ASC";
@@ -116,16 +118,26 @@ Expression orderField() {
116118
return parameters().get(1);
117119
}
118120

119-
private int limitValue() {
120-
return (int) limitField().fold(FoldContext.small() /* TODO remove me */);
121-
}
122-
123-
private String orderRawValue() {
124-
return BytesRefs.toString(orderField().fold(FoldContext.small() /* TODO remove me */));
121+
private Integer limitValue() {
122+
if (limitField() instanceof Literal literal) {
123+
Object value = literal.value();
124+
if (value instanceof Integer intValue) {
125+
return intValue;
126+
}
127+
}
128+
throw new EsqlIllegalArgumentException(
129+
format(null, "Limit value must be an integer in [{}], found [{}]", sourceText(), limitField())
130+
);
125131
}
126132

127133
private boolean orderValue() {
128-
return orderRawValue().equalsIgnoreCase(ORDER_ASC);
134+
if (orderField() instanceof Literal literal) {
135+
String order = BytesRefs.toString(literal.value());
136+
if (ORDER_ASC.equalsIgnoreCase(order) || ORDER_DESC.equalsIgnoreCase(order)) {
137+
return order.equalsIgnoreCase(ORDER_ASC);
138+
}
139+
}
140+
throw new EsqlIllegalArgumentException("Order value must be a literal, found: " + orderField());
129141
}
130142

131143
@Override
@@ -148,31 +160,114 @@ protected TypeResolution resolveType() {
148160
"ip",
149161
"string",
150162
"numeric except unsigned_long or counter types"
151-
).and(isNotNullAndFoldable(limitField(), sourceText(), SECOND))
163+
).and(isNotNull(limitField(), sourceText(), SECOND))
152164
.and(isType(limitField(), dt -> dt == DataType.INTEGER, sourceText(), SECOND, "integer"))
153-
.and(isNotNullAndFoldable(orderField(), sourceText(), THIRD))
165+
.and(isNotNull(orderField(), sourceText(), THIRD))
154166
.and(isString(orderField(), sourceText(), THIRD));
155167

156168
if (typeResolution.unresolved()) {
157169
return typeResolution;
158170
}
159171

160-
var limit = limitValue();
161-
var order = orderRawValue();
162-
163-
if (limit <= 0) {
164-
return new TypeResolution(format(null, "Limit must be greater than 0 in [{}], found [{}]", sourceText(), limit));
172+
TypeResolution result = resolveTypeLimit();
173+
if (result.equals(TypeResolution.TYPE_RESOLVED) == false) {
174+
return result;
175+
}
176+
result = resolveTypeOrder();
177+
if (result.equals(TypeResolution.TYPE_RESOLVED) == false) {
178+
return result;
165179
}
180+
return TypeResolution.TYPE_RESOLVED;
181+
}
166182

167-
if (order.equalsIgnoreCase(ORDER_ASC) == false && order.equalsIgnoreCase(ORDER_DESC) == false) {
168-
return new TypeResolution(
169-
format(null, "Invalid order value in [{}], expected [{}, {}] but got [{}]", sourceText(), ORDER_ASC, ORDER_DESC, order)
170-
);
183+
/**
184+
* We check that the limit is not null and that if it is a literal, it is a positive integer
185+
* We will do a more thorough check in the postOptimizationVerification once folding is done.
186+
*/
187+
private TypeResolution resolveTypeLimit() {
188+
Expression limit = limitField();
189+
if (limit == null) {
190+
return new TypeResolution(format(null, "Limit must be a constant integer in [{}], found [{}]", sourceText(), limit));
191+
}
192+
if (limit instanceof Literal literal) {
193+
if (literal.value() == null) {
194+
return new TypeResolution(format(null, "Limit must be a constant integer in [{}], found [{}]", sourceText(), limit));
195+
}
196+
int value = (Integer) literal.value();
197+
if (value <= 0) {
198+
return new TypeResolution(format(null, "Limit must be greater than 0 in [{}], found [{}]", sourceText(), value));
199+
}
171200
}
201+
return TypeResolution.TYPE_RESOLVED;
202+
}
172203

204+
/**
205+
* We check that the order is not null and that if it is a literal, it is one of the two valid values: "asc" or "desc".
206+
* We will do a more thorough check in the postOptimizationVerification once folding is done.
207+
*/
208+
private TypeResolution resolveTypeOrder() {
209+
Expression order = orderField();
210+
if (order == null) {
211+
return new TypeResolution(format(null, "Order must be a valid string in [{}], found [{}]", sourceText(), order));
212+
}
213+
if (order instanceof Literal literal) {
214+
if (literal.value() == null) {
215+
return new TypeResolution(
216+
format(null, "Invalid order value in [{}], expected [{}, {}] but got [{}]", sourceText(), ORDER_ASC, ORDER_DESC, order)
217+
);
218+
}
219+
String value = BytesRefs.toString(literal.value());
220+
if (value == null || value.equalsIgnoreCase(ORDER_ASC) == false && value.equalsIgnoreCase(ORDER_DESC) == false) {
221+
return new TypeResolution(
222+
format(null, "Invalid order value in [{}], expected [{}, {}] but got [{}]", sourceText(), ORDER_ASC, ORDER_DESC, order)
223+
);
224+
}
225+
}
173226
return TypeResolution.TYPE_RESOLVED;
174227
}
175228

229+
@Override
230+
public void postOptimizationVerification(Failures failures) {
231+
postOptimizationVerificationLimit(failures);
232+
postOptimizationVerificationOrder(failures);
233+
}
234+
235+
private void postOptimizationVerificationLimit(Failures failures) {
236+
Expression limit = limitField();
237+
if (limit == null) {
238+
failures.add(fail(limit, "Limit must be a constant integer in [{}], found [{}]", sourceText(), limit));
239+
}
240+
if (limit instanceof Literal literal) {
241+
int value = (Integer) literal.value();
242+
if (value <= 0) {
243+
failures.add(fail(limit, "Limit must be greater than 0 in [{}], found [{}]", sourceText(), value));
244+
}
245+
} else {
246+
// it is expected that the expression is a literal after folding
247+
// we fail if it is not a literal
248+
failures.add(fail(limit, "Limit must be a constant integer in [{}], found [{}]", sourceText(), limit));
249+
}
250+
}
251+
252+
private void postOptimizationVerificationOrder(Failures failures) {
253+
Expression order = orderField();
254+
if (order == null) {
255+
failures.add(fail(order, "Order must be a valid string in [{}], found [{}]", sourceText(), order));
256+
}
257+
if (order instanceof Literal literal) {
258+
String value = BytesRefs.toString(literal.value());
259+
if (value == null || value.equalsIgnoreCase(ORDER_ASC) == false && value.equalsIgnoreCase(ORDER_DESC) == false) {
260+
failures.add(
261+
fail(order, "Invalid order value in [{}], expected [{}, {}] but got [{}]", sourceText(), ORDER_ASC, ORDER_DESC, order)
262+
);
263+
}
264+
} else {
265+
// it is expected that the expression is a literal after folding
266+
// we fail if it is not a literal
267+
failures.add(fail(order, "Order must be a valid string in [{}], found [{}]", sourceText(), order));
268+
}
269+
}
270+
176271
@Override
177272
public DataType dataType() {
178273
return field().dataType().noText();
@@ -215,15 +310,20 @@ public AggregatorFunctionSupplier supplier() {
215310
@Override
216311
public Expression surrogate() {
217312
var s = source();
218-
219-
if (limitValue() == 1) {
220-
if (orderValue()) {
221-
return new Min(s, field());
222-
} else {
223-
return new Max(s, field());
313+
try {
314+
if (limitValue() == 1) {
315+
if (orderValue()) {
316+
return new Min(s, field());
317+
} else {
318+
return new Max(s, field());
319+
}
224320
}
321+
} catch (EsqlIllegalArgumentException e) {
322+
// If the limit is not a literal or is not a positive integer, we cannot create a surrogate
323+
// so we return null to indicate that no surrogate can be created.
324+
// This is possible if the limit is an expression, and folding has not been done yet.
325+
return null;
225326
}
226-
227327
return null;
228328
}
229329
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ public static Iterable<Object[]> parameters() {
262262
new TestCaseSupplier.TypedData(null, DataType.INTEGER, "limit").forceLiteral(),
263263
new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral()
264264
),
265-
"second argument of [source] cannot be null, received [limit]"
265+
"Limit must be a constant integer in [source], found [null]"
266266
)
267267
),
268268
new TestCaseSupplier(
@@ -273,7 +273,7 @@ public static Iterable<Object[]> parameters() {
273273
new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(),
274274
new TestCaseSupplier.TypedData(null, DataType.KEYWORD, "order").forceLiteral()
275275
),
276-
"third argument of [source] cannot be null, received [order]"
276+
"Invalid order value in [source], expected [ASC, DESC] but got [null]"
277277
)
278278
)
279279
)

0 commit comments

Comments
 (0)