Skip to content

Commit 254fc59

Browse files
committed
Slightly more efficient null handling and fix null test logic
1 parent 979aac9 commit 254fc59

File tree

2 files changed

+96
-4
lines changed
  • x-pack/plugin/esql/src

2 files changed

+96
-4
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsAll.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,15 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
174174
right()
175175
);
176176
}
177+
if(supersetType == ElementType.NULL || subsetType == ElementType.NULL) {
178+
return new MvContainsAllNullEvaluator(toEvaluator.apply(right()));
179+
}
177180
return switch (supersetType) {
178181
case BOOLEAN -> new MvContainsAllBooleanEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right()));
179182
case BYTES_REF -> new MvContainsAllBytesRefEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right()));
180183
case DOUBLE -> new MvContainsAllDoubleEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right()));
181184
case INT -> new MvContainsAllIntEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right()));
182185
case LONG -> new MvContainsAllLongEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right()));
183-
case NULL -> new MvContainsAllNullEvaluator(toEvaluator.apply(right()));
184186
default -> throw EsqlIllegalArgumentException.illegalDataType(dataType());
185187
};
186188
}
@@ -287,11 +289,12 @@ interface ValueExtractor<BlockType extends Block, Type> {
287289
Type extractValue(BlockType block, int position);
288290
}
289291

290-
private record MvContainsAllNullEvaluator(ExpressionEvaluator.Factory toEvaluator) implements ExpressionEvaluator.Factory {
292+
private record MvContainsAllNullEvaluator(ExpressionEvaluator.Factory subsetFieldEvaluator) implements ExpressionEvaluator.Factory {
293+
291294
@Override
292295
public ExpressionEvaluator get(DriverContext context) {
293296
return new ExpressionEvaluator() {
294-
final ExpressionEvaluator subsetField = toEvaluator.get(context);
297+
final ExpressionEvaluator subsetField = subsetFieldEvaluator.get(context);
295298

296299
@Override
297300
public Block eval(Page page) {
@@ -305,7 +308,17 @@ public Block eval(Page page) {
305308
public void close() {
306309
Releasables.closeExpectNoException(subsetField);
307310
}
311+
312+
@Override
313+
public String toString() {
314+
return "MvContainsAllNullEvaluator[" + "subsetField=" + subsetFieldEvaluator + "]";
315+
}
308316
};
309317
}
318+
319+
@Override
320+
public String toString() {
321+
return "MvContainsAllNullEvaluator[" + "subsetField=" + subsetFieldEvaluator + "]";
322+
}
310323
}
311324
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsAllTests.java

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,20 @@
1919
import org.elasticsearch.xpack.esql.core.type.DataType;
2020
import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
2121
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
22+
import org.hamcrest.Matcher;
2223

2324
import java.util.ArrayList;
25+
import java.util.HashSet;
2426
import java.util.List;
27+
import java.util.Set;
2528
import java.util.function.Supplier;
2629

2730
import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral;
2831
import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.CARTESIAN;
2932
import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.GEO;
33+
import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.TypedData.MULTI_ROW_NULL;
34+
import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.TypedData.NULL;
35+
import static org.hamcrest.Matchers.anyOf;
3036
import static org.hamcrest.Matchers.equalTo;
3137

3238
public class MvContainsAllTests extends AbstractScalarFunctionTestCase {
@@ -42,7 +48,14 @@ public static Iterable<Object[]> parameters() {
4248
longs(suppliers);
4349
doubles(suppliers);
4450
bytesRefs(suppliers);
45-
return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(true, suppliers);
51+
52+
return parameterSuppliersFromTypedData(anyNullIsNull(
53+
suppliers,
54+
(nullPosition, nullValueDataType, original) -> false
55+
&& nullValueDataType == DataType.NULL
56+
&& original.getData().size() == 1 ? DataType.NULL : original.expectedType(),
57+
(nullPosition, nullData, original) -> original
58+
));
4659
}
4760

4861
@Override
@@ -270,4 +283,70 @@ private static void bytesRefs(List<TestCaseSupplier> suppliers) {
270283
}));
271284
}
272285

286+
// Adjusted from static method anyNullIsNull in {@code AbstractScalarFunctionTestCase#}
287+
protected static List<TestCaseSupplier> anyNullIsNull(
288+
List<TestCaseSupplier> testCaseSuppliers,
289+
ExpectedType expectedType,
290+
ExpectedEvaluatorToString evaluatorToString
291+
) {
292+
List<TestCaseSupplier> suppliers = new ArrayList<>(testCaseSuppliers.size());
293+
suppliers.addAll(testCaseSuppliers);
294+
295+
/*
296+
* For each original test case, add as many copies as there were
297+
* arguments, replacing one of the arguments with null and keeping
298+
* the others.
299+
*
300+
* Also, if this was the first time we saw the signature we copy it
301+
* *again*, replacing the argument with null, but annotating the
302+
* argument’s type as `null` explicitly.
303+
*/
304+
Set<List<DataType>> uniqueSignatures = new HashSet<>();
305+
for (TestCaseSupplier original : testCaseSuppliers) {
306+
boolean firstTimeSeenSignature = uniqueSignatures.add(original.types());
307+
for (int typeIndex = 0; typeIndex < original.types().size(); typeIndex++) {
308+
int nullPosition = typeIndex;
309+
310+
suppliers.add(new TestCaseSupplier(original.name() + " null in " + nullPosition, original.types(), () -> {
311+
TestCaseSupplier.TestCase originalTestCase = original.get();
312+
List<TestCaseSupplier.TypedData> data = new ArrayList<>(originalTestCase.getData());
313+
data.set(nullPosition, NULL);
314+
TestCaseSupplier.TypedData nulledData = originalTestCase.getData().get(nullPosition);
315+
return new TestCaseSupplier.TestCase(
316+
data,
317+
evaluatorToString.evaluatorToString(nullPosition, nulledData, originalTestCase.evaluatorToString()),
318+
expectedType.expectedType(nullPosition, nulledData.type(), originalTestCase),
319+
equalTo(nullPosition == 1)
320+
);
321+
}));
322+
323+
if (firstTimeSeenSignature) {
324+
var typesWithNull = new ArrayList<>(original.types());
325+
typesWithNull.set(nullPosition, DataType.NULL);
326+
boolean newSignature = uniqueSignatures.add(typesWithNull);
327+
if (newSignature) {
328+
suppliers.add(new TestCaseSupplier(typesWithNull, () -> {
329+
TestCaseSupplier.TestCase originalTestCase = original.get();
330+
var typeDataWithNull = new ArrayList<>(originalTestCase.getData());
331+
typeDataWithNull.set(nullPosition, typeDataWithNull.get(nullPosition).isMultiRow() ? MULTI_ROW_NULL : NULL);
332+
return new TestCaseSupplier.TestCase(
333+
typeDataWithNull,
334+
"MvContainsAllNullEvaluator[subsetField=Attribute[channel=1]]",
335+
DataType.BOOLEAN,
336+
equalTo(nullPosition == 1)
337+
);
338+
}));
339+
}
340+
}
341+
}
342+
}
343+
344+
return suppliers;
345+
}
346+
347+
// We always return a boolean.
348+
@Override
349+
protected Matcher<Object> allNullsMatcher() {
350+
return anyOf(equalTo(false),equalTo(true));
351+
}
273352
}

0 commit comments

Comments
 (0)