Skip to content

Commit cbd4992

Browse files
authored
Clean UnaryScalarFunction slightly (ESQL-1359)
Our base class for `UnaryScalarFunction` only takes one argument because it's, well, unary. But it was reporting type errors on that argument as though it were the first of many. That's silly. I also added some tests for the `Abs` function which extends our `UnaryScalarFunction` that would have caught this error. While I was there I ported `Length` from QL's `UnaryScalarFunction` to ours. Let's use our stuff. Even if it's wrong we can change it without bothing QL. Finally I added some javadocs and removed some unused code.
1 parent 9ba3a12 commit cbd4992

File tree

6 files changed

+137
-15
lines changed

6 files changed

+137
-15
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/UnaryScalarFunction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.expression.function.scalar;
99

1010
import org.elasticsearch.xpack.ql.expression.Expression;
11+
import org.elasticsearch.xpack.ql.expression.TypeResolutions;
1112
import org.elasticsearch.xpack.ql.expression.function.scalar.ScalarFunction;
1213
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
1314
import org.elasticsearch.xpack.ql.tree.Source;
@@ -16,7 +17,6 @@
1617
import java.util.Arrays;
1718
import java.util.Objects;
1819

19-
import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.FIRST;
2020
import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isNumeric;
2121

2222
public abstract class UnaryScalarFunction extends ScalarFunction {
@@ -33,7 +33,7 @@ protected Expression.TypeResolution resolveType() {
3333
return new Expression.TypeResolution("Unresolved children");
3434
}
3535

36-
return isNumeric(field, sourceText(), FIRST);
36+
return isNumeric(field, sourceText(), TypeResolutions.ParamOrdinal.DEFAULT);
3737
}
3838

3939
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Length.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
import org.apache.lucene.util.UnicodeUtil;
1212
import org.elasticsearch.compute.ann.Evaluator;
1313
import org.elasticsearch.compute.operator.EvalOperator;
14+
import org.elasticsearch.xpack.esql.expression.function.scalar.UnaryScalarFunction;
1415
import org.elasticsearch.xpack.esql.planner.Mappable;
1516
import org.elasticsearch.xpack.ql.expression.Expression;
16-
import org.elasticsearch.xpack.ql.expression.function.scalar.UnaryScalarFunction;
17-
import org.elasticsearch.xpack.ql.expression.gen.processor.Processor;
1817
import org.elasticsearch.xpack.ql.tree.NodeInfo;
1918
import org.elasticsearch.xpack.ql.tree.Source;
2019
import org.elasticsearch.xpack.ql.type.DataType;
2120
import org.elasticsearch.xpack.ql.type.DataTypes;
2221

22+
import java.util.List;
2323
import java.util.function.Function;
2424
import java.util.function.Supplier;
2525

@@ -62,13 +62,8 @@ static int process(BytesRef val) {
6262
}
6363

6464
@Override
65-
protected UnaryScalarFunction replaceChild(Expression newChild) {
66-
return new Length(source(), newChild);
67-
}
68-
69-
@Override
70-
protected Processor makeProcessor() {
71-
throw new UnsupportedOperationException();
65+
public Expression replaceChildren(List<Expression> newChildren) {
66+
return new Length(source(), newChildren.get(0));
7267
}
7368

7469
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ public static List<PlanNameRegistry.Entry> namedTypeEntries() {
266266
// UnaryScalarFunction
267267
of(QL_UNARY_SCLR_CLS, IsNotNull.class, PlanNamedTypes::writeQLUnaryScalar, PlanNamedTypes::readQLUnaryScalar),
268268
of(QL_UNARY_SCLR_CLS, IsNull.class, PlanNamedTypes::writeQLUnaryScalar, PlanNamedTypes::readQLUnaryScalar),
269-
of(QL_UNARY_SCLR_CLS, Length.class, PlanNamedTypes::writeQLUnaryScalar, PlanNamedTypes::readQLUnaryScalar),
269+
of(ESQL_UNARY_SCLR_CLS, Length.class, PlanNamedTypes::writeESQLUnaryScalar, PlanNamedTypes::readESQLUnaryScalar),
270270
of(QL_UNARY_SCLR_CLS, Not.class, PlanNamedTypes::writeQLUnaryScalar, PlanNamedTypes::readQLUnaryScalar),
271271
of(ESQL_UNARY_SCLR_CLS, Abs.class, PlanNamedTypes::writeESQLUnaryScalar, PlanNamedTypes::readESQLUnaryScalar),
272272
of(ScalarFunction.class, E.class, PlanNamedTypes::writeNoArgScalar, PlanNamedTypes::readNoArgScalar),
@@ -943,6 +943,7 @@ static void writeBinaryLogic(PlanStreamOutput out, BinaryLogic binaryLogic) thro
943943
entry(name(IsFinite.class), IsFinite::new),
944944
entry(name(IsInfinite.class), IsInfinite::new),
945945
entry(name(IsNaN.class), IsNaN::new),
946+
entry(name(Length.class), Length::new),
946947
entry(name(Metadata.class), Metadata::new),
947948
entry(name(ToBoolean.class), ToBoolean::new),
948949
entry(name(ToDatetime.class), ToDatetime::new),
@@ -989,7 +990,6 @@ static void writeNoArgScalar(PlanStreamOutput out, ScalarFunction function) {}
989990
Map.ofEntries(
990991
entry(name(IsNotNull.class), IsNotNull::new),
991992
entry(name(IsNull.class), IsNull::new),
992-
entry(name(Length.class), Length::new),
993993
entry(name(Not.class), Not::new)
994994
);
995995

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractScalarFunctionTestCase.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,27 @@
3232
* Base class for function tests.
3333
*/
3434
public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTestCase {
35-
35+
/**
36+
* Describe supported arguments. Build each argument with
37+
* {@link #required} or {@link #optional}.
38+
*/
3639
protected abstract List<ArgumentSpec> argSpec();
3740

41+
/**
42+
* The data type that applying this function to arguments of this type should produce.
43+
*/
3844
protected abstract DataType expectedType(List<DataType> argTypes);
3945

46+
/**
47+
* Define a required argument.
48+
*/
4049
protected final ArgumentSpec required(DataType... validTypes) {
4150
return new ArgumentSpec(false, withNullAndSorted(validTypes));
4251
}
4352

53+
/**
54+
* Define an optional argument.
55+
*/
4456
protected final ArgumentSpec optional(DataType... validTypes) {
4557
return new ArgumentSpec(true, withNullAndSorted(validTypes));
4658
}
@@ -52,18 +64,30 @@ private Set<DataType> withNullAndSorted(DataType[] validTypes) {
5264
return realValidTypes;
5365
}
5466

67+
/**
68+
* All string types (keyword, text, match_only_text, etc). For passing to {@link #required} or {@link #optional}.
69+
*/
5570
protected final DataType[] strings() {
5671
return EsqlDataTypes.types().stream().filter(DataTypes::isString).toArray(DataType[]::new);
5772
}
5873

74+
/**
75+
* All integer types (long, int, short, byte). For passing to {@link #required} or {@link #optional}.
76+
*/
5977
protected final DataType[] integers() {
6078
return EsqlDataTypes.types().stream().filter(DataType::isInteger).toArray(DataType[]::new);
6179
}
6280

81+
/**
82+
* All rational types (double, float, whatever). For passing to {@link #required} or {@link #optional}.
83+
*/
6384
protected final DataType[] rationals() {
6485
return EsqlDataTypes.types().stream().filter(DataType::isRational).toArray(DataType[]::new);
6586
}
6687

88+
/**
89+
* All numeric types (integers and rationals.) For passing to {@link #required} or {@link #optional}.
90+
*/
6791
protected final DataType[] numerics() {
6892
return EsqlDataTypes.types().stream().filter(DataType::isNumeric).toArray(DataType[]::new);
6993
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.scalar.math;
9+
10+
import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractScalarFunctionTestCase;
11+
import org.elasticsearch.xpack.ql.expression.Expression;
12+
import org.elasticsearch.xpack.ql.expression.Literal;
13+
import org.elasticsearch.xpack.ql.tree.Source;
14+
import org.elasticsearch.xpack.ql.type.DataType;
15+
import org.elasticsearch.xpack.ql.type.DataTypes;
16+
import org.hamcrest.Matcher;
17+
18+
import java.util.List;
19+
20+
import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
21+
import static org.hamcrest.Matchers.equalTo;
22+
23+
public class AbsTests extends AbstractScalarFunctionTestCase {
24+
@Override
25+
protected List<Object> simpleData() {
26+
return List.of(randomInt());
27+
}
28+
29+
@Override
30+
protected Expression expressionForSimpleData() {
31+
return new Abs(Source.EMPTY, field("arg", DataTypes.INTEGER));
32+
}
33+
34+
@Override
35+
protected Matcher<Object> resultMatcher(List<Object> data, DataType dataType) {
36+
Object in = data.get(0);
37+
if (dataType == DataTypes.INTEGER) {
38+
return equalTo(Math.abs(((Integer) in).intValue()));
39+
}
40+
if (dataType == DataTypes.LONG) {
41+
return equalTo(Math.abs(((Long) in).longValue()));
42+
}
43+
if (dataType == DataTypes.UNSIGNED_LONG) {
44+
return equalTo(in);
45+
}
46+
if (dataType == DataTypes.DOUBLE) {
47+
return equalTo(Math.abs(((Double) in).doubleValue()));
48+
}
49+
throw new IllegalArgumentException("can't match " + in);
50+
}
51+
52+
@Override
53+
protected String expectedEvaluatorSimpleToString() {
54+
return "AbsIntEvaluator[fieldVal=Attribute[channel=0]]";
55+
}
56+
57+
@Override
58+
protected Expression constantFoldable(List<Object> data) {
59+
return new Abs(Source.EMPTY, new Literal(Source.EMPTY, data.get(0), DataTypes.INTEGER));
60+
}
61+
62+
@Override
63+
protected Expression build(Source source, List<Literal> args) {
64+
return new Abs(source, args.get(0));
65+
}
66+
67+
@Override
68+
protected List<ArgumentSpec> argSpec() {
69+
return List.of(required(numerics()));
70+
}
71+
72+
@Override
73+
protected DataType expectedType(List<DataType> argTypes) {
74+
return argTypes.get(0);
75+
}
76+
77+
public final void testLong() {
78+
List<Object> data = List.of(randomLong());
79+
Expression expression = new Abs(Source.EMPTY, field("arg", DataTypes.LONG));
80+
Object result = toJavaObject(evaluator(expression).get().eval(row(data)), 0);
81+
assertThat(result, resultMatcher(data, DataTypes.LONG));
82+
}
83+
84+
public final void testUnsignedLong() {
85+
List<Object> data = List.of(randomLong());
86+
Expression expression = new Abs(Source.EMPTY, field("arg", DataTypes.UNSIGNED_LONG));
87+
Object result = toJavaObject(evaluator(expression).get().eval(row(data)), 0);
88+
assertThat(result, resultMatcher(data, DataTypes.UNSIGNED_LONG));
89+
}
90+
91+
public final void testInt() {
92+
List<Object> data = List.of(randomInt());
93+
Expression expression = new Abs(Source.EMPTY, field("arg", DataTypes.INTEGER));
94+
Object result = toJavaObject(evaluator(expression).get().eval(row(data)), 0);
95+
assertThat(result, resultMatcher(data, DataTypes.INTEGER));
96+
}
97+
98+
public final void testDouble() {
99+
List<Object> data = List.of(randomDouble());
100+
Expression expression = new Abs(Source.EMPTY, field("arg", DataTypes.DOUBLE));
101+
Object result = toJavaObject(evaluator(expression).get().eval(row(data)), 0);
102+
assertThat(result, resultMatcher(data, DataTypes.DOUBLE));
103+
}
104+
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ protected Expression constantFoldable(List<Object> data) {
123123

124124
@Override
125125
protected List<ArgumentSpec> argSpec() {
126-
var validDataTypes = new DataType[] { DataTypes.DOUBLE, DataTypes.LONG, DataTypes.INTEGER };
127126
return List.of(required(numerics()), required(numerics()));
128127
}
129128

0 commit comments

Comments
 (0)