Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc;
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Abs;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
Expand All @@ -53,6 +55,7 @@

import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -91,6 +94,8 @@ public class EvalBenchmark {
"abs",
"add",
"add_double",
"case_1_eager",
"case_1_lazy",
"date_trunc",
"equal_to_const",
"long_equal_to_long",
Expand Down Expand Up @@ -125,6 +130,18 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
layout(doubleField)
).get(driverContext);
}
case "case_1_eager", "case_1_lazy" -> {
FieldAttribute f1 = longField();
FieldAttribute f2 = longField();
Expression condition = new Equals(Source.EMPTY, f1, new Literal(Source.EMPTY, 1L, DataType.LONG));
Expression lhs = f1;
Expression rhs = f2;
if (operation.endsWith("lazy")) {
lhs = new Add(Source.EMPTY, lhs, new Literal(Source.EMPTY, 1L, DataType.LONG));
rhs = new Add(Source.EMPTY, rhs, new Literal(Source.EMPTY, 1L, DataType.LONG));
}
yield EvalMapper.toEvaluator(new Case(Source.EMPTY, condition, List.of(lhs, rhs)), layout(f1, f2)).get(driverContext);
}
case "date_trunc" -> {
FieldAttribute timestamp = new FieldAttribute(
Source.EMPTY,
Expand Down Expand Up @@ -216,6 +233,28 @@ private static void checkExpected(String operation, Page actual) {
}
}
}
case "case_1_eager" -> {
LongVector f1 = actual.<LongBlock>getBlock(0).asVector();
LongVector f2 = actual.<LongBlock>getBlock(1).asVector();
LongVector result = actual.<LongBlock>getBlock(2).asVector();
for (int i = 0; i < BLOCK_LENGTH; i++) {
long expected = f1.getLong(i) == 1 ? f1.getLong(i) : f2.getLong(i);
if (result.getLong(i) != expected) {
throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]");
}
}
}
case "case_1_lazy" -> {
LongVector f1 = actual.<LongBlock>getBlock(0).asVector();
LongVector f2 = actual.<LongBlock>getBlock(1).asVector();
LongVector result = actual.<LongBlock>getBlock(2).asVector();
for (int i = 0; i < BLOCK_LENGTH; i++) {
long expected = 1 + (f1.getLong(i) == 1 ? f1.getLong(i) : f2.getLong(i));
if (result.getLong(i) != expected) {
throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]");
}
}
}
case "date_trunc" -> {
LongVector v = actual.<LongBlock>getBlock(1).asVector();
long oneDay = TimeValue.timeValueHours(24).millis();
Expand Down Expand Up @@ -280,6 +319,15 @@ private static Page page(String operation) {
}
yield new Page(builder.build());
}
case "case_1_eager", "case_1_lazy" -> {
var f1 = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
var f2 = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
for (int i = 0; i < BLOCK_LENGTH; i++) {
f1.appendLong(i);
f2.appendLong(-i);
}
yield new Page(f1.build(), f2.build());
}
case "long_equal_to_long" -> {
var lhs = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
var rhs = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
Expand Down
5 changes: 5 additions & 0 deletions docs/changelog/112295.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 112295
summary: "ESQL: Speed up CASE for some parameters"
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ public interface ExpressionEvaluator extends Releasable {
/** A Factory for creating ExpressionEvaluators. */
interface Factory {
ExpressionEvaluator get(DriverContext context);

/**
* {@code true} if it is safe and fast to evaluate this expression eagerly
* in {@link ExpressionEvaluator}s that need to be lazy, like {@code CASE}.
* This defaults to {@code false}, but expressions
* that evaluate quickly and can not produce warnings may override this to
* {@code true} to get a significant speed-up in {@code CASE}-like operations.
*/
default boolean eagerEvalSafeInLazy() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another alternative is to be extract this as a marking interface that gets implemented by certain factories.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's a lot less readable.

return false;
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this is something that we could derive from the Expression - but it felt simpler to put it here. And it's just a boolean at this point - though maybe it should be a cost estimate at some point.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case of constants, the unvisited branches can be removed and the case simplified.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That works if the expression is constant, but not if the values are constant. We still have to evaluate in that case. And when we do we can do it the fast way with something like this.

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ public ExpressionEvaluator get(DriverContext driverContext) {
public String toString() {
return "Attribute[channel=" + channel + "]";
}

@Override
public boolean eagerEvalSafeInLazy() {
return true;
}
}
return new AttributeFactory(layout.get(attr.id()).channel());
}
Expand Down Expand Up @@ -209,6 +214,11 @@ public ExpressionEvaluator get(DriverContext driverContext) {
public String toString() {
return "LiteralsEvaluator[lit=" + lit + "]";
}

@Override
public boolean eagerEvalSafeInLazy() {
return true;
}
}
return new LiteralsEvaluatorFactory(lit);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.ToMask;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
Expand Down Expand Up @@ -311,25 +312,16 @@ private Expression finishPartialFold(List<Expression> newChildren) {

@Override
public ExpressionEvaluator.Factory toEvaluator(Function<Expression, ExpressionEvaluator.Factory> toEvaluator) {
ElementType resultType = PlannerUtils.toElementType(dataType());
List<ConditionEvaluatorSupplier> conditionsFactories = conditions.stream().map(c -> c.toEvaluator(toEvaluator)).toList();
ExpressionEvaluator.Factory elseValueFactory = toEvaluator.apply(elseValue);
return new ExpressionEvaluator.Factory() {
@Override
public ExpressionEvaluator get(DriverContext context) {
return new CaseEvaluator(
context.blockFactory(),
resultType,
conditionsFactories.stream().map(x -> x.apply(context)).toList(),
elseValueFactory.get(context)
);
}
ElementType resultType = PlannerUtils.toElementType(dataType());

@Override
public String toString() {
return "CaseEvaluator[conditions=" + conditionsFactories + ", elseVal=" + elseValueFactory + ']';
}
};
if (conditionsFactories.size() == 1
&& conditionsFactories.get(0).value.eagerEvalSafeInLazy()
&& elseValueFactory.eagerEvalSafeInLazy()) {
return new CaseEagerEvaluatorFactory(resultType, conditionsFactories.get(0), elseValueFactory);
}
return new CaseLazyEvaluatorFactory(resultType, conditionsFactories, elseValueFactory);
}

record ConditionEvaluatorSupplier(Source conditionSource, ExpressionEvaluator.Factory condition, ExpressionEvaluator.Factory value)
Expand Down Expand Up @@ -375,9 +367,42 @@ public void close() {
public String toString() {
return "ConditionEvaluator[condition=" + condition + ", value=" + value + ']';
}

public void registerMultivalue() {
conditionWarnings.registerException(new IllegalArgumentException("CASE expects a single-valued boolean"));
}
}

private record CaseEvaluator(
private record CaseLazyEvaluatorFactory(
ElementType resultType,
List<ConditionEvaluatorSupplier> conditionsFactories,
ExpressionEvaluator.Factory elseValueFactory
) implements ExpressionEvaluator.Factory {
@Override
public ExpressionEvaluator get(DriverContext context) {
List<ConditionEvaluator> conditions = new ArrayList<>(conditionsFactories.size());
ExpressionEvaluator elseValue = null;
try {
for (ConditionEvaluatorSupplier cond : conditionsFactories) {
conditions.add(cond.apply(context));
}
elseValue = elseValueFactory.get(context);
ExpressionEvaluator result = new CaseLazyEvaluator(context.blockFactory(), resultType, conditions, elseValue);
conditions = null;
elseValue = null;
return result;
} finally {
Releasables.close(conditions == null ? () -> {} : Releasables.wrap(conditions), elseValue);
}
}

@Override
public String toString() {
return "CaseLazyEvaluator[conditions=" + conditionsFactories + ", elseVal=" + elseValueFactory + ']';
}
}

private record CaseLazyEvaluator(
BlockFactory blockFactory,
ElementType resultType,
List<ConditionEvaluator> conditions,
Expand Down Expand Up @@ -409,9 +434,7 @@ public Block eval(Page page) {
continue;
}
if (b.getValueCount(0) > 1) {
condition.conditionWarnings.registerException(
new IllegalArgumentException("CASE expects a single-valued boolean")
);
condition.registerMultivalue();
continue;
}
if (false == b.getBoolean(b.getFirstValueIndex(0))) {
Expand Down Expand Up @@ -439,7 +462,80 @@ public void close() {

@Override
public String toString() {
return "CaseEvaluator[conditions=" + conditions + ", elseVal=" + elseVal + ']';
return "CaseLazyEvaluator[conditions=" + conditions + ", elseVal=" + elseVal + ']';
}
}

private record CaseEagerEvaluatorFactory(
ElementType resultType,
ConditionEvaluatorSupplier conditionFactory,
ExpressionEvaluator.Factory elseValueFactory
) implements ExpressionEvaluator.Factory {
@Override
public ExpressionEvaluator get(DriverContext context) {
ConditionEvaluator conditionEvaluator = conditionFactory.apply(context);
ExpressionEvaluator elseValue = null;
try {
elseValue = elseValueFactory.get(context);
ExpressionEvaluator result = new CaseEagerEvaluator(resultType, context.blockFactory(), conditionEvaluator, elseValue);
conditionEvaluator = null;
elseValue = null;
return result;
} finally {
Releasables.close(conditionEvaluator, elseValue);
}
}

@Override
public String toString() {
return "CaseEagerEvaluator[conditions=[" + conditionFactory + "], elseVal=" + elseValueFactory + ']';
}
}

private record CaseEagerEvaluator(
ElementType resultType,
BlockFactory blockFactory,
ConditionEvaluator condition,
EvalOperator.ExpressionEvaluator elseVal
) implements EvalOperator.ExpressionEvaluator {
@Override
public Block eval(Page page) {
try (BooleanBlock lhsOrRhsBlock = (BooleanBlock) condition.condition.eval(page); ToMask lhsOrRhs = lhsOrRhsBlock.toMask()) {
if (lhsOrRhs.hadMultivaluedFields()) {
condition.registerMultivalue();
}
if (lhsOrRhs.mask().isConstant()) {
if (lhsOrRhs.mask().getBoolean(0)) {
return condition.value.eval(page);
} else {
return elseVal.eval(page);
}
}
try (
Block lhs = condition.value.eval(page);
Block rhs = elseVal.eval(page);
Block.Builder builder = resultType.newBlockBuilder(lhs.getTotalValueCount(), blockFactory)
) {
for (int p = 0; p < lhs.getPositionCount(); p++) {
if (lhsOrRhs.mask().getBoolean(p)) {
builder.copyFrom(lhs, p, p + 1);
} else {
builder.copyFrom(rhs, p, p + 1);
}
}
return builder.build();
}
}
}

@Override
public void close() {
Releasables.closeExpectNoException(condition, elseVal);
}

@Override
public String toString() {
return "CaseEagerEvaluator[conditions=[" + condition + "], elseVal=" + elseVal + ']';
}
}
}
Loading