Skip to content

Commit 7f5e36a

Browse files
committed
ESQL: Speed up CASE for some parameters (elastic#112295)
This speeds up the `CASE` function when it has two or three arguments and both of the arguments are constants or fields. `CASE` is normally lazy so that it can avoid warnings in cases like ``` CASE(foo != 0, 2 / foo, 1) ``` and so that, in the case where a branch is *very* slow, it can avoid the computation. But if the lhs and rhs of the `CASE` are constants or fields then there isn't any work to avoid, so they can be evaluated eagerly. The performance improvement is pretty substantial: ``` (operation) Before Error After Error Units case_1_lazy 97.422 ± 1.048 101.571 ± 0.737 ns/op case_1_eager 79.312 ± 1.190 4.601 ± 0.049 ns/op ``` The top line is a `CASE` that has to be lazy - it shouldn't change. The 4 nanos change here is noise. The eager version improves by about 94%.
1 parent b74e592 commit 7f5e36a

File tree

6 files changed

+210
-39
lines changed

6 files changed

+210
-39
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
import org.elasticsearch.compute.operator.EvalOperator;
2626
import org.elasticsearch.compute.operator.Operator;
2727
import org.elasticsearch.core.TimeValue;
28+
import org.elasticsearch.xpack.esql.core.expression.Expression;
2829
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
2930
import org.elasticsearch.xpack.esql.core.expression.Literal;
3031
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern;
3132
import org.elasticsearch.xpack.esql.core.tree.Source;
3233
import org.elasticsearch.xpack.esql.core.type.DataType;
3334
import org.elasticsearch.xpack.esql.core.type.EsField;
3435
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
36+
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
3537
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc;
3638
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Abs;
3739
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
@@ -53,6 +55,7 @@
5355

5456
import java.time.Duration;
5557
import java.util.Arrays;
58+
import java.util.List;
5659
import java.util.Map;
5760
import java.util.concurrent.TimeUnit;
5861

@@ -91,6 +94,8 @@ public class EvalBenchmark {
9194
"abs",
9295
"add",
9396
"add_double",
97+
"case_1_eager",
98+
"case_1_lazy",
9499
"date_trunc",
95100
"equal_to_const",
96101
"long_equal_to_long",
@@ -125,6 +130,18 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
125130
layout(doubleField)
126131
).get(driverContext);
127132
}
133+
case "case_1_eager", "case_1_lazy" -> {
134+
FieldAttribute f1 = longField();
135+
FieldAttribute f2 = longField();
136+
Expression condition = new Equals(Source.EMPTY, f1, new Literal(Source.EMPTY, 1L, DataType.LONG));
137+
Expression lhs = f1;
138+
Expression rhs = f2;
139+
if (operation.endsWith("lazy")) {
140+
lhs = new Add(Source.EMPTY, lhs, new Literal(Source.EMPTY, 1L, DataType.LONG));
141+
rhs = new Add(Source.EMPTY, rhs, new Literal(Source.EMPTY, 1L, DataType.LONG));
142+
}
143+
yield EvalMapper.toEvaluator(new Case(Source.EMPTY, condition, List.of(lhs, rhs)), layout(f1, f2)).get(driverContext);
144+
}
128145
case "date_trunc" -> {
129146
FieldAttribute timestamp = new FieldAttribute(
130147
Source.EMPTY,
@@ -216,6 +233,28 @@ private static void checkExpected(String operation, Page actual) {
216233
}
217234
}
218235
}
236+
case "case_1_eager" -> {
237+
LongVector f1 = actual.<LongBlock>getBlock(0).asVector();
238+
LongVector f2 = actual.<LongBlock>getBlock(1).asVector();
239+
LongVector result = actual.<LongBlock>getBlock(2).asVector();
240+
for (int i = 0; i < BLOCK_LENGTH; i++) {
241+
long expected = f1.getLong(i) == 1 ? f1.getLong(i) : f2.getLong(i);
242+
if (result.getLong(i) != expected) {
243+
throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]");
244+
}
245+
}
246+
}
247+
case "case_1_lazy" -> {
248+
LongVector f1 = actual.<LongBlock>getBlock(0).asVector();
249+
LongVector f2 = actual.<LongBlock>getBlock(1).asVector();
250+
LongVector result = actual.<LongBlock>getBlock(2).asVector();
251+
for (int i = 0; i < BLOCK_LENGTH; i++) {
252+
long expected = 1 + (f1.getLong(i) == 1 ? f1.getLong(i) : f2.getLong(i));
253+
if (result.getLong(i) != expected) {
254+
throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]");
255+
}
256+
}
257+
}
219258
case "date_trunc" -> {
220259
LongVector v = actual.<LongBlock>getBlock(1).asVector();
221260
long oneDay = TimeValue.timeValueHours(24).millis();
@@ -280,6 +319,15 @@ private static Page page(String operation) {
280319
}
281320
yield new Page(builder.build());
282321
}
322+
case "case_1_eager", "case_1_lazy" -> {
323+
var f1 = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
324+
var f2 = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
325+
for (int i = 0; i < BLOCK_LENGTH; i++) {
326+
f1.appendLong(i);
327+
f2.appendLong(-i);
328+
}
329+
yield new Page(f1.build(), f2.build());
330+
}
283331
case "long_equal_to_long" -> {
284332
var lhs = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
285333
var rhs = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);

docs/changelog/112295.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 112295
2+
summary: "ESQL: Speed up CASE for some parameters"
3+
area: ES|QL
4+
type: enhancement
5+
issues: []

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@ public interface ExpressionEvaluator extends Releasable {
6363
/** A Factory for creating ExpressionEvaluators. */
6464
interface Factory {
6565
/** Creates an {@link ExpressionEvaluator} for the given {@link DriverContext}. */
ExpressionEvaluator get(DriverContext context);
66+
67+
/**
68+
* {@code true} if it is safe and fast to evaluate this expression eagerly
69+
* in {@link ExpressionEvaluator}s that need to be lazy, like {@code CASE}.
70+
* This defaults to {@code false}, but expressions
71+
* that evaluate quickly and can not produce warnings may override this to
72+
* {@code true} to get a significant speed-up in {@code CASE}-like operations.
73+
*/
74+
default boolean eagerEvalSafeInLazy() {
75+
return false;
76+
}
6677
}
6778

6879
/**

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ public ExpressionEvaluator get(DriverContext driverContext) {
176176
public String toString() {
177177
return "Attribute[channel=" + channel + "]";
178178
}
179+
180+
// Attribute evaluators opt in to eager evaluation inside lazy evaluators such as CASE.
@Override
181+
public boolean eagerEvalSafeInLazy() {
182+
return true;
183+
}
179184
}
180185
return new AttributeFactory(layout.get(attr.id()).channel());
181186
}
@@ -209,6 +214,11 @@ public ExpressionEvaluator get(DriverContext driverContext) {
209214
public String toString() {
210215
return "LiteralsEvaluator[lit=" + lit + "]";
211216
}
217+
218+
// Literal evaluators opt in to eager evaluation inside lazy evaluators such as CASE.
@Override
219+
public boolean eagerEvalSafeInLazy() {
220+
return true;
221+
}
212222
}
213223
return new LiteralsEvaluatorFactory(lit);
214224
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java

Lines changed: 117 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.compute.data.BooleanBlock;
1616
import org.elasticsearch.compute.data.ElementType;
1717
import org.elasticsearch.compute.data.Page;
18+
import org.elasticsearch.compute.data.ToMask;
1819
import org.elasticsearch.compute.operator.DriverContext;
1920
import org.elasticsearch.compute.operator.EvalOperator;
2021
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
@@ -311,25 +312,16 @@ private Expression finishPartialFold(List<Expression> newChildren) {
311312

312313
@Override
313314
public ExpressionEvaluator.Factory toEvaluator(Function<Expression, ExpressionEvaluator.Factory> toEvaluator) {
314-
ElementType resultType = PlannerUtils.toElementType(dataType());
315315
List<ConditionEvaluatorSupplier> conditionsFactories = conditions.stream().map(c -> c.toEvaluator(toEvaluator)).toList();
316316
ExpressionEvaluator.Factory elseValueFactory = toEvaluator.apply(elseValue);
317-
return new ExpressionEvaluator.Factory() {
318-
@Override
319-
public ExpressionEvaluator get(DriverContext context) {
320-
return new CaseEvaluator(
321-
context.blockFactory(),
322-
resultType,
323-
conditionsFactories.stream().map(x -> x.apply(context)).toList(),
324-
elseValueFactory.get(context)
325-
);
326-
}
317+
ElementType resultType = PlannerUtils.toElementType(dataType());
327318

328-
@Override
329-
public String toString() {
330-
return "CaseEvaluator[conditions=" + conditionsFactories + ", elseVal=" + elseValueFactory + ']';
331-
}
332-
};
319+
if (conditionsFactories.size() == 1
320+
&& conditionsFactories.get(0).value.eagerEvalSafeInLazy()
321+
&& elseValueFactory.eagerEvalSafeInLazy()) {
322+
return new CaseEagerEvaluatorFactory(resultType, conditionsFactories.get(0), elseValueFactory);
323+
}
324+
return new CaseLazyEvaluatorFactory(resultType, conditionsFactories, elseValueFactory);
333325
}
334326

335327
record ConditionEvaluatorSupplier(Source conditionSource, ExpressionEvaluator.Factory condition, ExpressionEvaluator.Factory value)
@@ -375,9 +367,42 @@ public void close() {
375367
public String toString() {
376368
return "ConditionEvaluator[condition=" + condition + ", value=" + value + ']';
377369
}
370+
371+
/**
 * Registers the "CASE expects a single-valued boolean" exception with this
 * condition's warnings. The CASE evaluators call it when the condition
 * produced more than one value.
 */
public void registerMultivalue() {
372+
conditionWarnings.registerException(new IllegalArgumentException("CASE expects a single-valued boolean"));
373+
}
378374
}
379375

380-
private record CaseEvaluator(
376+
/**
 * Factory for {@link CaseLazyEvaluator}. Holds one supplier per condition/value
 * pair plus the factory for the else value, and builds all of their evaluators
 * when {@link #get} is called.
 */
private record CaseLazyEvaluatorFactory(
377+
ElementType resultType,
378+
List<ConditionEvaluatorSupplier> conditionsFactories,
379+
ExpressionEvaluator.Factory elseValueFactory
380+
) implements ExpressionEvaluator.Factory {
381+
/**
 * Builds every condition evaluator and the else-value evaluator, then wraps
 * them in a {@link CaseLazyEvaluator}. If anything throws before the result is
 * created, the evaluators built so far are closed in the {@code finally} block.
 */
@Override
382+
public ExpressionEvaluator get(DriverContext context) {
383+
List<ConditionEvaluator> conditions = new ArrayList<>(conditionsFactories.size());
384+
ExpressionEvaluator elseValue = null;
385+
try {
386+
for (ConditionEvaluatorSupplier cond : conditionsFactories) {
387+
conditions.add(cond.apply(context));
388+
}
389+
elseValue = elseValueFactory.get(context);
390+
ExpressionEvaluator result = new CaseLazyEvaluator(context.blockFactory(), resultType, conditions, elseValue);
391+
// The result now holds the evaluators; null the locals so the finally block does not close them.
conditions = null;
392+
elseValue = null;
393+
return result;
394+
} finally {
395+
// On the success path both locals are null, so a no-op releasable is passed instead of the wrapped list.
Releasables.close(conditions == null ? () -> {} : Releasables.wrap(conditions), elseValue);
396+
}
397+
}
398+
399+
// Renders with the same "CaseLazyEvaluator[" prefix that the evaluator's own toString uses.
@Override
400+
public String toString() {
401+
return "CaseLazyEvaluator[conditions=" + conditionsFactories + ", elseVal=" + elseValueFactory + ']';
402+
}
403+
}
404+
405+
private record CaseLazyEvaluator(
381406
BlockFactory blockFactory,
382407
ElementType resultType,
383408
List<ConditionEvaluator> conditions,
@@ -409,9 +434,7 @@ public Block eval(Page page) {
409434
continue;
410435
}
411436
if (b.getValueCount(0) > 1) {
412-
condition.conditionWarnings.registerException(
413-
new IllegalArgumentException("CASE expects a single-valued boolean")
414-
);
437+
condition.registerMultivalue();
415438
continue;
416439
}
417440
if (false == b.getBoolean(b.getFirstValueIndex(0))) {
@@ -439,7 +462,80 @@ public void close() {
439462

440463
@Override
441464
public String toString() {
442-
return "CaseEvaluator[conditions=" + conditions + ", elseVal=" + elseVal + ']';
465+
return "CaseLazyEvaluator[conditions=" + conditions + ", elseVal=" + elseVal + ']';
466+
}
467+
}
468+
469+
/**
 * Factory for {@link CaseEagerEvaluator}. Holds a single condition/value
 * supplier plus the factory for the else value. {@code toEvaluator} picks this
 * factory when there is exactly one condition and both the value and the
 * else value report {@code eagerEvalSafeInLazy()}.
 */
private record CaseEagerEvaluatorFactory(
470+
ElementType resultType,
471+
ConditionEvaluatorSupplier conditionFactory,
472+
ExpressionEvaluator.Factory elseValueFactory
473+
) implements ExpressionEvaluator.Factory {
474+
/**
 * Builds the condition evaluator and the else-value evaluator, then wraps them
 * in a {@link CaseEagerEvaluator}. If building the else value or the result
 * throws, whatever was already built is closed in the {@code finally} block.
 */
@Override
475+
public ExpressionEvaluator get(DriverContext context) {
476+
ConditionEvaluator conditionEvaluator = conditionFactory.apply(context);
477+
ExpressionEvaluator elseValue = null;
478+
try {
479+
elseValue = elseValueFactory.get(context);
480+
ExpressionEvaluator result = new CaseEagerEvaluator(resultType, context.blockFactory(), conditionEvaluator, elseValue);
481+
// The result now holds the evaluators; null the locals so the finally block does not close them.
conditionEvaluator = null;
482+
elseValue = null;
483+
return result;
484+
} finally {
485+
Releasables.close(conditionEvaluator, elseValue);
486+
}
487+
}
488+
489+
// Renders with the same "CaseEagerEvaluator[" prefix that the evaluator's own toString uses.
@Override
490+
public String toString() {
491+
return "CaseEagerEvaluator[conditions=[" + conditionFactory + "], elseVal=" + elseValueFactory + ']';
492+
}
493+
}
494+
495+
/**
 * Evaluator for a {@code CASE} with a single condition whose value and else
 * value may both be evaluated for the whole page. Unlike the lazy evaluator,
 * it evaluates the condition once per page and then picks, position by
 * position, between the two already-evaluated branches.
 */
private record CaseEagerEvaluator(
496+
ElementType resultType,
497+
BlockFactory blockFactory,
498+
ConditionEvaluator condition,
499+
EvalOperator.ExpressionEvaluator elseVal
500+
) implements EvalOperator.ExpressionEvaluator {
501+
/**
 * Evaluates the condition over {@code page} and returns a block that holds,
 * at each position, the value branch where the mask is {@code true} and the
 * else branch otherwise.
 */
@Override
502+
public Block eval(Page page) {
503+
// Evaluate the condition for the whole page and convert it to a mask; both are closed when this method exits.
try (BooleanBlock lhsOrRhsBlock = (BooleanBlock) condition.condition.eval(page); ToMask lhsOrRhs = lhsOrRhsBlock.toMask()) {
504+
// The mask conversion reported multivalued input: register the CASE warning on the condition.
if (lhsOrRhs.hadMultivaluedFields()) {
505+
condition.registerMultivalue();
506+
}
507+
// Constant mask: every position selects the same branch, so only that branch is evaluated.
if (lhsOrRhs.mask().isConstant()) {
508+
if (lhsOrRhs.mask().getBoolean(0)) {
509+
return condition.value.eval(page);
510+
} else {
511+
return elseVal.eval(page);
512+
}
513+
}
514+
// Mixed mask: evaluate both branches and copy one position at a time from whichever the mask selects.
try (
515+
Block lhs = condition.value.eval(page);
516+
Block rhs = elseVal.eval(page);
517+
// lhs's total value count is passed as the builder's size argument.
Block.Builder builder = resultType.newBlockBuilder(lhs.getTotalValueCount(), blockFactory)
518+
) {
519+
for (int p = 0; p < lhs.getPositionCount(); p++) {
520+
if (lhsOrRhs.mask().getBoolean(p)) {
521+
builder.copyFrom(lhs, p, p + 1);
522+
} else {
523+
builder.copyFrom(rhs, p, p + 1);
524+
}
525+
}
526+
return builder.build();
527+
}
528+
}
529+
}
530+
531+
// Closes the condition evaluator and the else-value evaluator.
@Override
532+
public void close() {
533+
Releasables.closeExpectNoException(condition, elseVal);
534+
}
535+
536+
@Override
537+
public String toString() {
538+
return "CaseEagerEvaluator[conditions=[" + condition + "], elseVal=" + elseVal + ']';
443539
}
444540
}
445541
}

0 commit comments

Comments
 (0)