Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice tests!

  • I'd add one with ROW foo = CASE([true, true], ...) just to make it obvious that CASE doesn't have implicit ANY or ALL logic.
  • Should we add ROW a = [true, false] | EVAL c = CASE(MV_SLICE(a, 0, 2), "foo", "bar") as well? This one failed differently from others, so I guess it exercised a different code path? A test case with an mv expression inside CASE is probably a good idea, anyway. Or, similarly but more realistically, using a conversion function like ROW a = ["true", "false"] | EVAL c = CASE(a::boolean, "foo", "bar")
  • If we want to be supremely paranoid, we could also check the interaction with multivalued union types, e.g. FROM idx* | EVAL x = CASE(field::boolean, 1, 2); I'd only do that if we expect weird interactions with block loading, though, which probably is not the case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally if you ask "should we add this csv-spec test?" the answer is just yes. It's cheap.

Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ M |10
;

caseOnMv
required_capability: case_mv

FROM employees
| WHERE emp_no == 10010
| EVAL foo = CASE(still_hired, "still", is_rehired, "rehired", "not")
Expand All @@ -106,7 +108,9 @@ still_hired:boolean | is_rehired:boolean | foo:keyword
false | [false, false, true, true] | not
;

caseOnConstantMv
caseOnConstantMvFalseTrue
required_capability: case_mv

ROW foo = CASE([false, true], "a", "b");
warning:Line 1:16: evaluation of [[false, true]] failed, treating result as false. Only first 20 failures recorded.
warning:Line 1:16: java.lang.IllegalArgumentException: CASE expects a single-valued boolean
Expand All @@ -115,6 +119,57 @@ foo:keyword
b
;

caseOnConstantMvTrueTrue
required_capability: case_mv

ROW foo = CASE([true, true], "a", "b");
warning:Line 1:16: evaluation of [[true, true]] failed, treating result as false. Only first 20 failures recorded.
warning:Line 1:16: java.lang.IllegalArgumentException: CASE expects a single-valued boolean

foo:keyword
b
;

caseOnMvSliceMv
required_capability: case_mv

ROW foo = [true, false, false] | EVAL foo = CASE(MV_SLICE(foo, 0, 1), "a", "b");
warning:Line 1:50: evaluation of [MV_SLICE(foo, 0, 1)] failed, treating result as false. Only first 20 failures recorded.
warning:Line 1:50: java.lang.IllegalArgumentException: CASE expects a single-valued boolean

foo:keyword
b
;

caseOnMvSliceSv
required_capability: case_mv

ROW foo = [true, false, false] | EVAL foo = CASE(MV_SLICE(foo, 0), "a", "b");

foo:keyword
a
;

caseOnConvertMvSliceMv
required_capability: case_mv

ROW foo = ["true", "false", "false"] | EVAL foo = CASE(MV_SLICE(foo::BOOLEAN, 0, 1), "a", "b");
warning:Line 1:56: evaluation of [MV_SLICE(foo::BOOLEAN, 0, 1)] failed, treating result as false. Only first 20 failures recorded.
warning:Line 1:56: java.lang.IllegalArgumentException: CASE expects a single-valued boolean

foo:keyword
b
;

caseOnConvertMvSliceSv
required_capability: case_mv

ROW foo = ["true", "false", "false"] | EVAL foo = CASE(MV_SLICE(foo::BOOLEAN, 0), "a", "b");

foo:keyword
a
;

docsCaseSuccessRate
// tag::docsCaseSuccessRate[]
FROM sample_data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ synopsis:keyword
"double avg(number:double|integer|long)"
"double|date bin(field:integer|long|double|date, buckets:integer|long|double|date_period|time_duration, ?from:integer|long|double|date|keyword|text, ?to:integer|long|double|date|keyword|text)"
"double|date bucket(field:integer|long|double|date, buckets:integer|long|double|date_period|time_duration, ?from:integer|long|double|date|keyword|text, ?to:integer|long|double|date|keyword|text)"
"boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version case(condition:boolean, trueValue...:boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version)"
"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version case(condition:boolean, trueValue...:boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version)"
"double cbrt(number:double|integer|long|unsigned_long)"
"double|integer|long|unsigned_long ceil(number:double|integer|long|unsigned_long)"
"boolean cidr_match(ip:ip, blockX...:keyword|text)"
Expand Down Expand Up @@ -135,7 +135,7 @@ atan2 |[y_coordinate, x_coordinate] |["double|integer|long|unsign
avg |number |"double|integer|long" |[""]
bin |[field, buckets, from, to] |["integer|long|double|date", "integer|long|double|date_period|time_duration", "integer|long|double|date|keyword|text", "integer|long|double|date|keyword|text"] |[Numeric or date expression from which to derive buckets., Target number of buckets\, or desired bucket size if `from` and `to` parameters are omitted., Start of the range. Can be a number\, a date or a date expressed as a string., End of the range. Can be a number\, a date or a date expressed as a string.]
bucket |[field, buckets, from, to] |["integer|long|double|date", "integer|long|double|date_period|time_duration", "integer|long|double|date|keyword|text", "integer|long|double|date|keyword|text"] |[Numeric or date expression from which to derive buckets., Target number of buckets\, or desired bucket size if `from` and `to` parameters are omitted., Start of the range. Can be a number\, a date or a date expressed as a string., End of the range. Can be a number\, a date or a date expressed as a string.]
case |[condition, trueValue] |[boolean, "boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version"] |[A condition., The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches.]
case |[condition, trueValue] |[boolean, "boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version"] |[A condition., The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches.]
cbrt |number |"double|integer|long|unsigned_long" |"Numeric expression. If `null`, the function returns `null`."
ceil |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`.
cidr_match |[ip, blockX] |[ip, "keyword|text"] |[IP address of type `ip` (both IPv4 and IPv6 are supported)., CIDR block to test the IP against.]
Expand Down Expand Up @@ -385,7 +385,7 @@ atan2 |double
avg |double |false |false |true
bin |"double|date" |[false, false, true, true] |false |false
bucket |"double|date" |[false, false, true, true] |false |false
case |"boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version" |[false, false] |true |false
case |"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version" |[false, false] |true |false
cbrt |double |false |false |false
ceil |"double|integer|long|unsigned_long" |false |false |false
cidr_match |boolean |[false, false] |true |false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ public enum Cap {
*/
AGG_TOP_IP_SUPPORT,

/**
* {@code CASE} properly handling multivalue conditions.
*/
CASE_MV,

/**
* Optimization for ST_CENTROID changed some results in cartesian data. #108713
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ public void registerException(Exception exception) {
};

/**
* Create a new warnings object based on the given mode
* Create a new warnings object based on the given mode which warns that
* it treats the result as {@code null}.
* @param warningsMode The warnings collection strategy to use
* @param source used to indicate where in the query the warning occurred
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've given us a little more flexibility here, specifically so I can tell folks that multivalued conditions in CASE are treated as false. If it's unclear why we do that, read the PR description. It's still a bit confusing, but the description tries to explain it.

* @return A warnings collector object
Expand All @@ -41,6 +42,17 @@ public static Warnings createWarnings(DriverContext.WarningsMode warningsMode, S
return createWarnings(warningsMode, source, "treating result as null");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: to avoid drift, we could put "treating result as null" into a string constant.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}

/**
* Create a new warnings object based on the given mode which warns that
* it treats the result as {@code false}.
* <p>
* This is a convenience wrapper: it delegates to the three-argument
* {@code createWarnings}, passing the fixed text
* {@code "treating result as false"} as the third argument.
* NOTE(review): presumably that text is appended to the emitted warning
* (e.g. "evaluation of [...] failed, treating result as false") - confirm
* against the three-argument overload.
* </p>
* @param warningsMode The warnings collection strategy to use
* @param source used to indicate where in the query the warning occurred
* @return A warnings collector object
*/
public static Warnings createWarningsTreatedAsFalse(DriverContext.WarningsMode warningsMode, Source source) {
return createWarnings(warningsMode, source, "treating result as false");
}

/**
* Create a new warnings object based on the given mode
* @param warningsMode The warnings collection strategy to use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,12 @@ public Case(
type = {
"boolean",
"cartesian_point",
"cartesian_shape",
"date",
"date_nanos",
"double",
"geo_point",
"geo_shape",
"integer",
"ip",
"keyword",
Expand Down Expand Up @@ -224,18 +227,13 @@ public boolean foldable() {
if (condition.condition.foldable() == false) {
return false;
}
Object o = condition.condition.fold();
if (o instanceof List) {
if (Boolean.TRUE.equals(condition.condition.fold())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++

/*
* multivalued fields fold to null which folds to false.
* So they *are* foldable if the value is foldable.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the comment here means the converse.

Suggested change
* So they *are* foldable if the value is foldable.
* So they *are* foldable if remaining list of conditions is foldable.

Same in line 271.

*/
return condition.value.foldable();
}
Boolean b = (Boolean) o;
if (b != null && b) {
return condition.value.foldable();
}
}
return elseValue.foldable();
}
Expand Down Expand Up @@ -267,13 +265,11 @@ public Expression partiallyFold() {
continue;
}
modified = true;
Object o = condition.condition.fold();
if (o instanceof List) {
// multivalued field folds to null which folds to false
continue;
}
Boolean b = (Boolean) condition.condition.fold();
if (b != null && b) {
if (Boolean.TRUE.equals(condition.condition.fold())) {
/*
* multivalued fields fold to null which folds to false.
* So they *are* foldable if the value is foldable.
*/
newChildren.add(condition.value);
return finishPartialFold(newChildren);
}
Expand All @@ -288,10 +284,11 @@ public Expression partiallyFold() {
}

private Expression finishPartialFold(List<Expression> newChildren) {
if (newChildren.size() == 1) {
return newChildren.get(0);
}
return replaceChildren(newChildren);
return switch (newChildren.size()) {
case 0 -> new Literal(source(), null, dataType());
case 1 -> newChildren.get(0);
default -> replaceChildren(newChildren);
};
}

@Override
Expand Down Expand Up @@ -334,7 +331,7 @@ public ConditionEvaluator apply(DriverContext driverContext) {
* Rather than go into depth about this in the warning message,
* we just say "false".
*/
Warnings.createWarnings(driverContext.warningsMode(), conditionSource, "treating result as false"),
Warnings.createWarningsTreatedAsFalse(driverContext.warningsMode(), conditionSource),
condition.get(driverContext),
value.get(driverContext)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ protected static List<TestCaseSupplier> withNoRowsExpectingNull(List<TestCaseSup
testCase.expectedType(),
nullValue(),
null,
null,
testCase.getExpectedTypeError(),
null,
null,
null
);
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,10 @@ protected static List<TestCaseSupplier> anyNullIsNull(
expectedType.expectedType(finalNullPosition, nulledData.type(), oc),
nullValue(),
null,
null,
oc.getExpectedTypeError(),
null,
null,
null
);
}));
Expand Down Expand Up @@ -246,8 +248,10 @@ protected static List<TestCaseSupplier> anyNullIsNull(
expectedType.expectedType(finalNullPosition, DataType.NULL, oc),
nullValue(),
null,
null,
oc.getExpectedTypeError(),
null,
null,
null
);
}));
Expand Down Expand Up @@ -642,9 +646,11 @@ protected static List<TestCaseSupplier> randomizeBytesRefsOffset(List<TestCaseSu
testCase.expectedType(),
testCase.getMatcher(),
testCase.getExpectedWarnings(),
testCase.getExpectedBuildEvaluatorWarnings(),
testCase.getExpectedTypeError(),
testCase.foldingExceptionClass(),
testCase.foldingExceptionMessage()
testCase.foldingExceptionMessage(),
testCase.extra()
);
})).toList();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
Expand Down Expand Up @@ -120,6 +119,9 @@ public final void testEvaluate() {

Object result;
try (ExpressionEvaluator evaluator = evaluator(expression).get(driverContext())) {
if (testCase.getExpectedBuildEvaluatorWarnings() != null) {
assertWarnings(testCase.getExpectedBuildEvaluatorWarnings());
}
try (Block block = evaluator.eval(row(testCase.getDataValues()))) {
assertThat(block.getPositionCount(), is(1));
result = toJavaObjectUnsignedLongAware(block, 0);
Expand Down Expand Up @@ -177,6 +179,10 @@ public final void testEvaluateBlockWithNulls() {
*/
public final void testCrankyEvaluateBlockWithoutNulls() {
assumeTrue("sometimes the cranky breaker silences warnings, just skip these cases", testCase.getExpectedWarnings() == null);
assumeTrue(
"sometimes the cranky breaker silences warnings, just skip these cases",
testCase.getExpectedBuildEvaluatorWarnings() == null
);
try {
testEvaluateBlock(driverContext().blockFactory(), crankyContext(), false);
} catch (CircuitBreakingException ex) {
Expand All @@ -190,6 +196,10 @@ public final void testCrankyEvaluateBlockWithoutNulls() {
*/
public final void testCrankyEvaluateBlockWithNulls() {
assumeTrue("sometimes the cranky breaker silences warnings, just skip these cases", testCase.getExpectedWarnings() == null);
assumeTrue(
"sometimes the cranky breaker silences warnings, just skip these cases",
testCase.getExpectedBuildEvaluatorWarnings() == null
);
try {
testEvaluateBlock(driverContext().blockFactory(), crankyContext(), true);
} catch (CircuitBreakingException ex) {
Expand Down Expand Up @@ -242,10 +252,13 @@ private void testEvaluateBlock(BlockFactory inputBlockFactory, DriverContext con
ExpressionEvaluator eval = evaluator(expression).get(context);
Block block = eval.eval(new Page(positions, manyPositionsBlocks))
) {
if (testCase.getExpectedBuildEvaluatorWarnings() != null) {
assertWarnings(testCase.getExpectedBuildEvaluatorWarnings());
}
assertThat(block.getPositionCount(), is(positions));
for (int p = 0; p < positions; p++) {
if (nullPositions.contains(p)) {
assertThat(toJavaObject(block, p), allNullsMatcher());
assertThat(toJavaObjectUnsignedLongAware(block, p), allNullsMatcher());
continue;
}
assertThat(toJavaObjectUnsignedLongAware(block, p), testCase.getMatcher());
Expand Down Expand Up @@ -275,6 +288,9 @@ public final void testEvaluateInManyThreads() throws ExecutionException, Interru
int count = 10_000;
int threads = 5;
var evalSupplier = evaluator(expression);
if (testCase.getExpectedBuildEvaluatorWarnings() != null) {
assertWarnings(testCase.getExpectedBuildEvaluatorWarnings());
}
ExecutorService exec = Executors.newFixedThreadPool(threads);
try {
List<Future<?>> futures = new ArrayList<>();
Expand Down Expand Up @@ -310,6 +326,9 @@ public final void testEvaluatorToString() {
assumeTrue("Can't build evaluator", testCase.canBuildEvaluator());
var factory = evaluator(expression);
try (ExpressionEvaluator ev = factory.get(driverContext())) {
if (testCase.getExpectedBuildEvaluatorWarnings() != null) {
assertWarnings(testCase.getExpectedBuildEvaluatorWarnings());
}
assertThat(ev.toString(), testCase.evaluatorToString());
}
}
Expand All @@ -322,6 +341,9 @@ public final void testFactoryToString() {
}
assumeTrue("Can't build evaluator", testCase.canBuildEvaluator());
var factory = evaluator(buildFieldExpression(testCase));
if (testCase.getExpectedBuildEvaluatorWarnings() != null) {
assertWarnings(testCase.getExpectedBuildEvaluatorWarnings());
}
assertThat(factory.toString(), testCase.evaluatorToString());
}

Expand All @@ -342,6 +364,9 @@ public final void testFold() {
result = NumericUtils.unsignedLongAsBigInteger((Long) result);
}
assertThat(result, testCase.getMatcher());
if (testCase.getExpectedBuildEvaluatorWarnings() != null) {
assertWarnings(testCase.getExpectedBuildEvaluatorWarnings());
}
if (testCase.getExpectedWarnings() != null) {
assertWarnings(testCase.getExpectedWarnings());
}
Expand Down
Loading