|
15 | 15 | import org.elasticsearch.compute.data.BooleanBlock;
|
16 | 16 | import org.elasticsearch.compute.data.ElementType;
|
17 | 17 | import org.elasticsearch.compute.data.Page;
|
| 18 | +import org.elasticsearch.compute.data.ToMask; |
18 | 19 | import org.elasticsearch.compute.operator.DriverContext;
|
19 | 20 | import org.elasticsearch.compute.operator.EvalOperator;
|
20 | 21 | import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
|
@@ -311,25 +312,16 @@ private Expression finishPartialFold(List<Expression> newChildren) {
|
311 | 312 |
|
312 | 313 | @Override
|
313 | 314 | public ExpressionEvaluator.Factory toEvaluator(Function<Expression, ExpressionEvaluator.Factory> toEvaluator) {
|
314 |
| - ElementType resultType = PlannerUtils.toElementType(dataType()); |
315 | 315 | List<ConditionEvaluatorSupplier> conditionsFactories = conditions.stream().map(c -> c.toEvaluator(toEvaluator)).toList();
|
316 | 316 | ExpressionEvaluator.Factory elseValueFactory = toEvaluator.apply(elseValue);
|
317 |
| - return new ExpressionEvaluator.Factory() { |
318 |
| - @Override |
319 |
| - public ExpressionEvaluator get(DriverContext context) { |
320 |
| - return new CaseEvaluator( |
321 |
| - context.blockFactory(), |
322 |
| - resultType, |
323 |
| - conditionsFactories.stream().map(x -> x.apply(context)).toList(), |
324 |
| - elseValueFactory.get(context) |
325 |
| - ); |
326 |
| - } |
| 317 | + ElementType resultType = PlannerUtils.toElementType(dataType()); |
327 | 318 |
|
328 |
| - @Override |
329 |
| - public String toString() { |
330 |
| - return "CaseEvaluator[conditions=" + conditionsFactories + ", elseVal=" + elseValueFactory + ']'; |
331 |
| - } |
332 |
| - }; |
| 319 | + if (conditionsFactories.size() == 1 |
| 320 | + && conditionsFactories.get(0).value.eagerEvalSafeInLazy() |
| 321 | + && elseValueFactory.eagerEvalSafeInLazy()) { |
| 322 | + return new CaseEagerEvaluatorFactory(resultType, conditionsFactories.get(0), elseValueFactory); |
| 323 | + } |
| 324 | + return new CaseLazyEvaluatorFactory(resultType, conditionsFactories, elseValueFactory); |
333 | 325 | }
|
334 | 326 |
|
335 | 327 | record ConditionEvaluatorSupplier(Source conditionSource, ExpressionEvaluator.Factory condition, ExpressionEvaluator.Factory value)
|
@@ -375,9 +367,42 @@ public void close() {
|
375 | 367 | public String toString() {
|
376 | 368 | return "ConditionEvaluator[condition=" + condition + ", value=" + value + ']';
|
377 | 369 | }
|
| 370 | + |
| 371 | + public void registerMultivalue() { |
| 372 | + conditionWarnings.registerException(new IllegalArgumentException("CASE expects a single-valued boolean")); |
| 373 | + } |
378 | 374 | }
|
379 | 375 |
|
380 |
| - private record CaseEvaluator( |
| 376 | + private record CaseLazyEvaluatorFactory( |
| 377 | + ElementType resultType, |
| 378 | + List<ConditionEvaluatorSupplier> conditionsFactories, |
| 379 | + ExpressionEvaluator.Factory elseValueFactory |
| 380 | + ) implements ExpressionEvaluator.Factory { |
| 381 | + @Override |
| 382 | + public ExpressionEvaluator get(DriverContext context) { |
| 383 | + List<ConditionEvaluator> conditions = new ArrayList<>(conditionsFactories.size()); |
| 384 | + ExpressionEvaluator elseValue = null; |
| 385 | + try { |
| 386 | + for (ConditionEvaluatorSupplier cond : conditionsFactories) { |
| 387 | + conditions.add(cond.apply(context)); |
| 388 | + } |
| 389 | + elseValue = elseValueFactory.get(context); |
| 390 | + ExpressionEvaluator result = new CaseLazyEvaluator(context.blockFactory(), resultType, conditions, elseValue); |
| 391 | + conditions = null; |
| 392 | + elseValue = null; |
| 393 | + return result; |
| 394 | + } finally { |
| 395 | + Releasables.close(conditions == null ? () -> {} : Releasables.wrap(conditions), elseValue); |
| 396 | + } |
| 397 | + } |
| 398 | + |
| 399 | + @Override |
| 400 | + public String toString() { |
| 401 | + return "CaseLazyEvaluator[conditions=" + conditionsFactories + ", elseVal=" + elseValueFactory + ']'; |
| 402 | + } |
| 403 | + } |
| 404 | + |
| 405 | + private record CaseLazyEvaluator( |
381 | 406 | BlockFactory blockFactory,
|
382 | 407 | ElementType resultType,
|
383 | 408 | List<ConditionEvaluator> conditions,
|
@@ -409,9 +434,7 @@ public Block eval(Page page) {
|
409 | 434 | continue;
|
410 | 435 | }
|
411 | 436 | if (b.getValueCount(0) > 1) {
|
412 |
| - condition.conditionWarnings.registerException( |
413 |
| - new IllegalArgumentException("CASE expects a single-valued boolean") |
414 |
| - ); |
| 437 | + condition.registerMultivalue(); |
415 | 438 | continue;
|
416 | 439 | }
|
417 | 440 | if (false == b.getBoolean(b.getFirstValueIndex(0))) {
|
@@ -439,7 +462,80 @@ public void close() {
|
439 | 462 |
|
440 | 463 | @Override
|
441 | 464 | public String toString() {
|
442 |
| - return "CaseEvaluator[conditions=" + conditions + ", elseVal=" + elseVal + ']'; |
| 465 | + return "CaseLazyEvaluator[conditions=" + conditions + ", elseVal=" + elseVal + ']'; |
| 466 | + } |
| 467 | + } |
| 468 | + |
| 469 | + private record CaseEagerEvaluatorFactory( |
| 470 | + ElementType resultType, |
| 471 | + ConditionEvaluatorSupplier conditionFactory, |
| 472 | + ExpressionEvaluator.Factory elseValueFactory |
| 473 | + ) implements ExpressionEvaluator.Factory { |
| 474 | + @Override |
| 475 | + public ExpressionEvaluator get(DriverContext context) { |
| 476 | + ConditionEvaluator conditionEvaluator = conditionFactory.apply(context); |
| 477 | + ExpressionEvaluator elseValue = null; |
| 478 | + try { |
| 479 | + elseValue = elseValueFactory.get(context); |
| 480 | + ExpressionEvaluator result = new CaseEagerEvaluator(resultType, context.blockFactory(), conditionEvaluator, elseValue); |
| 481 | + conditionEvaluator = null; |
| 482 | + elseValue = null; |
| 483 | + return result; |
| 484 | + } finally { |
| 485 | + Releasables.close(conditionEvaluator, elseValue); |
| 486 | + } |
| 487 | + } |
| 488 | + |
| 489 | + @Override |
| 490 | + public String toString() { |
| 491 | + return "CaseEagerEvaluator[conditions=[" + conditionFactory + "], elseVal=" + elseValueFactory + ']'; |
| 492 | + } |
| 493 | + } |
| 494 | + |
| 495 | + private record CaseEagerEvaluator( |
| 496 | + ElementType resultType, |
| 497 | + BlockFactory blockFactory, |
| 498 | + ConditionEvaluator condition, |
| 499 | + EvalOperator.ExpressionEvaluator elseVal |
| 500 | + ) implements EvalOperator.ExpressionEvaluator { |
| 501 | + @Override |
| 502 | + public Block eval(Page page) { |
| 503 | + try (BooleanBlock lhsOrRhsBlock = (BooleanBlock) condition.condition.eval(page); ToMask lhsOrRhs = lhsOrRhsBlock.toMask()) { |
| 504 | + if (lhsOrRhs.hadMultivaluedFields()) { |
| 505 | + condition.registerMultivalue(); |
| 506 | + } |
| 507 | + if (lhsOrRhs.mask().isConstant()) { |
| 508 | + if (lhsOrRhs.mask().getBoolean(0)) { |
| 509 | + return condition.value.eval(page); |
| 510 | + } else { |
| 511 | + return elseVal.eval(page); |
| 512 | + } |
| 513 | + } |
| 514 | + try ( |
| 515 | + Block lhs = condition.value.eval(page); |
| 516 | + Block rhs = elseVal.eval(page); |
| 517 | + Block.Builder builder = resultType.newBlockBuilder(lhs.getTotalValueCount(), blockFactory) |
| 518 | + ) { |
| 519 | + for (int p = 0; p < lhs.getPositionCount(); p++) { |
| 520 | + if (lhsOrRhs.mask().getBoolean(p)) { |
| 521 | + builder.copyFrom(lhs, p, p + 1); |
| 522 | + } else { |
| 523 | + builder.copyFrom(rhs, p, p + 1); |
| 524 | + } |
| 525 | + } |
| 526 | + return builder.build(); |
| 527 | + } |
| 528 | + } |
| 529 | + } |
| 530 | + |
| 531 | + @Override |
| 532 | + public void close() { |
| 533 | + Releasables.closeExpectNoException(condition, elseVal); |
| 534 | + } |
| 535 | + |
| 536 | + @Override |
| 537 | + public String toString() { |
| 538 | + return "CaseEagerEvaluator[conditions=[" + condition + "], elseVal=" + elseVal + ']'; |
443 | 539 | }
|
444 | 540 | }
|
445 | 541 | }
|
0 commit comments