|
12 | 12 | import org.apache.lucene.util.BytesRef; |
13 | 13 | import org.elasticsearch.common.breaker.NoopCircuitBreaker; |
14 | 14 | import org.elasticsearch.common.settings.Settings; |
| 15 | +import org.elasticsearch.common.unit.ByteSizeUnit; |
15 | 16 | import org.elasticsearch.common.util.BigArrays; |
16 | 17 | import org.elasticsearch.compute.data.Block; |
17 | 18 | import org.elasticsearch.compute.data.BlockFactory; |
|
41 | 42 | import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; |
42 | 43 | import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc; |
43 | 44 | import org.elasticsearch.xpack.esql.expression.function.scalar.math.Abs; |
| 45 | +import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo; |
44 | 46 | import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin; |
45 | 47 | import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; |
46 | 48 | import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; |
47 | 49 | import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower; |
48 | 50 | import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper; |
49 | 51 | import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; |
50 | 52 | import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; |
| 53 | +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; |
51 | 54 | import org.elasticsearch.xpack.esql.planner.Layout; |
52 | 55 | import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; |
53 | 56 | import org.elasticsearch.xpack.esql.session.Configuration; |
@@ -116,6 +119,10 @@ public class EvalBenchmark { |
116 | 119 | "long_equal_to_int", |
117 | 120 | "mv_min", |
118 | 121 | "mv_min_ascending", |
| 122 | + "round_to_4_via_case", |
| 123 | + "round_to_2", |
| 124 | + "round_to_3", |
| 125 | + "round_to_4", |
119 | 126 | "rlike", |
120 | 127 | "to_lower", |
121 | 128 | "to_lower_ords", |
@@ -228,6 +235,65 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) { |
228 | 235 | RLike rlike = new RLike(Source.EMPTY, keywordField, new RLikePattern(".ar")); |
229 | 236 | yield EvalMapper.toEvaluator(FOLD_CONTEXT, rlike, layout(keywordField)).get(driverContext); |
230 | 237 | } |
| 238 | + case "round_to_4_via_case" -> { |
| 239 | + FieldAttribute f = longField(); |
| 240 | + |
| 241 | + Expression ltkb = new LessThan(Source.EMPTY, f, kb()); |
| 242 | + Expression ltmb = new LessThan(Source.EMPTY, f, mb()); |
| 243 | + Expression ltgb = new LessThan(Source.EMPTY, f, gb()); |
| 244 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 245 | + FOLD_CONTEXT, |
| 246 | + new Case(Source.EMPTY, ltkb, List.of(b(), ltmb, kb(), ltgb, mb(), gb())), |
| 247 | + layout(f) |
| 248 | + ).get(driverContext); |
| 249 | + String desc = "CaseLazyEvaluator"; |
| 250 | + if (evaluator.toString().contains(desc) == false) { |
| 251 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 252 | + } |
| 253 | + yield evaluator; |
| 254 | + } |
| 255 | + case "round_to_2" -> { |
| 256 | + FieldAttribute f = longField(); |
| 257 | + |
| 258 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 259 | + FOLD_CONTEXT, |
| 260 | + new RoundTo(Source.EMPTY, f, List.of(b(), kb())), |
| 261 | + layout(f) |
| 262 | + ).get(driverContext); |
| 263 | + String desc = "RoundToLong2"; |
| 264 | + if (evaluator.toString().contains(desc) == false) { |
| 265 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 266 | + } |
| 267 | + yield evaluator; |
| 268 | + } |
| 269 | + case "round_to_3" -> { |
| 270 | + FieldAttribute f = longField(); |
| 271 | + |
| 272 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 273 | + FOLD_CONTEXT, |
| 274 | + new RoundTo(Source.EMPTY, f, List.of(b(), kb(), mb())), |
| 275 | + layout(f) |
| 276 | + ).get(driverContext); |
| 277 | + String desc = "RoundToLong3"; |
| 278 | + if (evaluator.toString().contains(desc) == false) { |
| 279 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 280 | + } |
| 281 | + yield evaluator; |
| 282 | + } |
| 283 | + case "round_to_4" -> { |
| 284 | + FieldAttribute f = longField(); |
| 285 | + |
| 286 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 287 | + FOLD_CONTEXT, |
| 288 | + new RoundTo(Source.EMPTY, f, List.of(b(), kb(), mb(), gb())), |
| 289 | + layout(f) |
| 290 | + ).get(driverContext); |
| 291 | + String desc = "RoundToLong4"; |
| 292 | + if (evaluator.toString().contains(desc) == false) { |
| 293 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 294 | + } |
| 295 | + yield evaluator; |
| 296 | + } |
231 | 297 | case "to_lower", "to_lower_ords" -> { |
232 | 298 | FieldAttribute keywordField = keywordField(); |
233 | 299 | ToLower toLower = new ToLower(Source.EMPTY, keywordField, configuration()); |
@@ -390,6 +456,69 @@ private static void checkExpected(String operation, Page actual) { |
390 | 456 | } |
391 | 457 | } |
392 | 458 | } |
| 459 | + case "round_to_4_via_case", "round_to_4" -> { |
| 460 | + long b = 1; |
| 461 | + long kb = ByteSizeUnit.KB.toBytes(1); |
| 462 | + long mb = ByteSizeUnit.MB.toBytes(1); |
| 463 | + long gb = ByteSizeUnit.GB.toBytes(1); |
| 464 | + |
| 465 | + LongVector f = actual.<LongBlock>getBlock(0).asVector(); |
| 466 | + LongVector result = actual.<LongBlock>getBlock(1).asVector(); |
| 467 | + for (int i = 0; i < BLOCK_LENGTH; i++) { |
| 468 | + long expected = f.getLong(i); |
| 469 | + if (expected < kb) { |
| 470 | + expected = b; |
| 471 | + } else if (expected < mb) { |
| 472 | + expected = kb; |
| 473 | + } else if (expected < gb) { |
| 474 | + expected = mb; |
| 475 | + } else { |
| 476 | + expected = gb; |
| 477 | + } |
| 478 | + if (result.getLong(i) != expected) { |
| 479 | + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); |
| 480 | + } |
| 481 | + } |
| 482 | + } |
| 483 | + case "round_to_3" -> { |
| 484 | + long b = 1; |
| 485 | + long kb = ByteSizeUnit.KB.toBytes(1); |
| 486 | + long mb = ByteSizeUnit.MB.toBytes(1); |
| 487 | + |
| 488 | + LongVector f = actual.<LongBlock>getBlock(0).asVector(); |
| 489 | + LongVector result = actual.<LongBlock>getBlock(1).asVector(); |
| 490 | + for (int i = 0; i < BLOCK_LENGTH; i++) { |
| 491 | + long expected = f.getLong(i); |
| 492 | + if (expected < kb) { |
| 493 | + expected = b; |
| 494 | + } else if (expected < mb) { |
| 495 | + expected = kb; |
| 496 | + } else { |
| 497 | + expected = mb; |
| 498 | + } |
| 499 | + if (result.getLong(i) != expected) { |
| 500 | + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); |
| 501 | + } |
| 502 | + } |
| 503 | + } |
| 504 | + case "round_to_2" -> { |
| 505 | + long b = 1; |
| 506 | + long kb = ByteSizeUnit.KB.toBytes(1); |
| 507 | + |
| 508 | + LongVector f = actual.<LongBlock>getBlock(0).asVector(); |
| 509 | + LongVector result = actual.<LongBlock>getBlock(1).asVector(); |
| 510 | + for (int i = 0; i < BLOCK_LENGTH; i++) { |
| 511 | + long expected = f.getLong(i); |
| 512 | + if (expected < kb) { |
| 513 | + expected = b; |
| 514 | + } else { |
| 515 | + expected = kb; |
| 516 | + } |
| 517 | + if (result.getLong(i) != expected) { |
| 518 | + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); |
| 519 | + } |
| 520 | + } |
| 521 | + } |
393 | 522 | case "to_lower" -> checkBytes(operation, actual, false, new BytesRef[] { new BytesRef("foo"), new BytesRef("bar") }); |
394 | 523 | case "to_lower_ords" -> checkBytes(operation, actual, true, new BytesRef[] { new BytesRef("foo"), new BytesRef("bar") }); |
395 | 524 | case "to_upper" -> checkBytes(operation, actual, false, new BytesRef[] { new BytesRef("FOO"), new BytesRef("BAR") }); |
@@ -421,7 +550,7 @@ private static void checkBytes(String operation, Page actual, boolean expectOrds |
421 | 550 |
|
422 | 551 | private static Page page(String operation) { |
423 | 552 | return switch (operation) { |
424 | | - case "abs", "add", "date_trunc", "equal_to_const" -> { |
| 553 | + case "abs", "add", "date_trunc", "equal_to_const", "round_to_4_via_case", "round_to_2", "round_to_3", "round_to_4" -> { |
425 | 554 | var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH); |
426 | 555 | for (int i = 0; i < BLOCK_LENGTH; i++) { |
427 | 556 | builder.appendLong(i * 100_000); |
@@ -511,6 +640,26 @@ private static Page page(String operation) { |
511 | 640 | }; |
512 | 641 | } |
513 | 642 |
|
| 643 | + private static Literal b() { |
| 644 | + return lit(1L); |
| 645 | + } |
| 646 | + |
| 647 | + private static Literal kb() { |
| 648 | + return lit(ByteSizeUnit.KB.toBytes(1)); |
| 649 | + } |
| 650 | + |
| 651 | + private static Literal mb() { |
| 652 | + return lit(ByteSizeUnit.MB.toBytes(1)); |
| 653 | + } |
| 654 | + |
| 655 | + private static Literal gb() { |
| 656 | + return lit(ByteSizeUnit.GB.toBytes(1)); |
| 657 | + } |
| 658 | + |
| 659 | + private static Literal lit(long v) { |
| 660 | + return new Literal(Source.EMPTY, v, DataType.LONG); |
| 661 | + } |
| 662 | + |
514 | 663 | @Benchmark |
515 | 664 | @OperationsPerInvocation(1024 * BLOCK_LENGTH) |
516 | 665 | public void run() { |
|
0 commit comments