|
12 | 12 | import org.apache.lucene.util.BytesRef;
|
13 | 13 | import org.elasticsearch.common.breaker.NoopCircuitBreaker;
|
14 | 14 | import org.elasticsearch.common.settings.Settings;
|
| 15 | +import org.elasticsearch.common.unit.ByteSizeUnit; |
15 | 16 | import org.elasticsearch.common.util.BigArrays;
|
16 | 17 | import org.elasticsearch.compute.data.Block;
|
17 | 18 | import org.elasticsearch.compute.data.BlockFactory;
|
|
41 | 42 | import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
|
42 | 43 | import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc;
|
43 | 44 | import org.elasticsearch.xpack.esql.expression.function.scalar.math.Abs;
|
| 45 | +import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo; |
44 | 46 | import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
|
45 | 47 | import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
|
46 | 48 | import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike;
|
47 | 49 | import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower;
|
48 | 50 | import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
|
49 | 51 | import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
|
50 | 52 | import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
|
| 53 | +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; |
51 | 54 | import org.elasticsearch.xpack.esql.planner.Layout;
|
52 | 55 | import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
|
53 | 56 | import org.elasticsearch.xpack.esql.session.Configuration;
|
@@ -116,6 +119,10 @@ public class EvalBenchmark {
|
116 | 119 | "long_equal_to_int",
|
117 | 120 | "mv_min",
|
118 | 121 | "mv_min_ascending",
|
| 122 | + "round_to_4_via_case", |
| 123 | + "round_to_2", |
| 124 | + "round_to_3", |
| 125 | + "round_to_4", |
119 | 126 | "rlike",
|
120 | 127 | "to_lower",
|
121 | 128 | "to_lower_ords",
|
@@ -228,6 +235,65 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
|
228 | 235 | RLike rlike = new RLike(Source.EMPTY, keywordField, new RLikePattern(".ar"));
|
229 | 236 | yield EvalMapper.toEvaluator(FOLD_CONTEXT, rlike, layout(keywordField)).get(driverContext);
|
230 | 237 | }
|
| 238 | + case "round_to_4_via_case" -> { |
| 239 | + FieldAttribute f = longField(); |
| 240 | + |
| 241 | + Expression ltkb = new LessThan(Source.EMPTY, f, kb()); |
| 242 | + Expression ltmb = new LessThan(Source.EMPTY, f, mb()); |
| 243 | + Expression ltgb = new LessThan(Source.EMPTY, f, gb()); |
| 244 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 245 | + FOLD_CONTEXT, |
| 246 | + new Case(Source.EMPTY, ltkb, List.of(b(), ltmb, kb(), ltgb, mb(), gb())), |
| 247 | + layout(f) |
| 248 | + ).get(driverContext); |
| 249 | + String desc = "CaseLazyEvaluator"; |
| 250 | + if (evaluator.toString().contains(desc) == false) { |
| 251 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 252 | + } |
| 253 | + yield evaluator; |
| 254 | + } |
| 255 | + case "round_to_2" -> { |
| 256 | + FieldAttribute f = longField(); |
| 257 | + |
| 258 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 259 | + FOLD_CONTEXT, |
| 260 | + new RoundTo(Source.EMPTY, f, List.of(b(), kb())), |
| 261 | + layout(f) |
| 262 | + ).get(driverContext); |
| 263 | + String desc = "RoundToLong2"; |
| 264 | + if (evaluator.toString().contains(desc) == false) { |
| 265 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 266 | + } |
| 267 | + yield evaluator; |
| 268 | + } |
| 269 | + case "round_to_3" -> { |
| 270 | + FieldAttribute f = longField(); |
| 271 | + |
| 272 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 273 | + FOLD_CONTEXT, |
| 274 | + new RoundTo(Source.EMPTY, f, List.of(b(), kb(), mb())), |
| 275 | + layout(f) |
| 276 | + ).get(driverContext); |
| 277 | + String desc = "RoundToLong3"; |
| 278 | + if (evaluator.toString().contains(desc) == false) { |
| 279 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 280 | + } |
| 281 | + yield evaluator; |
| 282 | + } |
| 283 | + case "round_to_4" -> { |
| 284 | + FieldAttribute f = longField(); |
| 285 | + |
| 286 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 287 | + FOLD_CONTEXT, |
| 288 | + new RoundTo(Source.EMPTY, f, List.of(b(), kb(), mb(), gb())), |
| 289 | + layout(f) |
| 290 | + ).get(driverContext); |
| 291 | + String desc = "RoundToLong4"; |
| 292 | + if (evaluator.toString().contains(desc) == false) { |
| 293 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 294 | + } |
| 295 | + yield evaluator; |
| 296 | + } |
231 | 297 | case "to_lower", "to_lower_ords" -> {
|
232 | 298 | FieldAttribute keywordField = keywordField();
|
233 | 299 | ToLower toLower = new ToLower(Source.EMPTY, keywordField, configuration());
|
@@ -390,6 +456,69 @@ private static void checkExpected(String operation, Page actual) {
|
390 | 456 | }
|
391 | 457 | }
|
392 | 458 | }
|
| 459 | + case "round_to_4_via_case", "round_to_4" -> { |
| 460 | + long b = 1; |
| 461 | + long kb = ByteSizeUnit.KB.toBytes(1); |
| 462 | + long mb = ByteSizeUnit.MB.toBytes(1); |
| 463 | + long gb = ByteSizeUnit.GB.toBytes(1); |
| 464 | + |
| 465 | + LongVector f = actual.<LongBlock>getBlock(0).asVector(); |
| 466 | + LongVector result = actual.<LongBlock>getBlock(1).asVector(); |
| 467 | + for (int i = 0; i < BLOCK_LENGTH; i++) { |
| 468 | + long expected = f.getLong(i); |
| 469 | + if (expected < kb) { |
| 470 | + expected = b; |
| 471 | + } else if (expected < mb) { |
| 472 | + expected = kb; |
| 473 | + } else if (expected < gb) { |
| 474 | + expected = mb; |
| 475 | + } else { |
| 476 | + expected = gb; |
| 477 | + } |
| 478 | + if (result.getLong(i) != expected) { |
| 479 | + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); |
| 480 | + } |
| 481 | + } |
| 482 | + } |
| 483 | + case "round_to_3" -> { |
| 484 | + long b = 1; |
| 485 | + long kb = ByteSizeUnit.KB.toBytes(1); |
| 486 | + long mb = ByteSizeUnit.MB.toBytes(1); |
| 487 | + |
| 488 | + LongVector f = actual.<LongBlock>getBlock(0).asVector(); |
| 489 | + LongVector result = actual.<LongBlock>getBlock(1).asVector(); |
| 490 | + for (int i = 0; i < BLOCK_LENGTH; i++) { |
| 491 | + long expected = f.getLong(i); |
| 492 | + if (expected < kb) { |
| 493 | + expected = b; |
| 494 | + } else if (expected < mb) { |
| 495 | + expected = kb; |
| 496 | + } else { |
| 497 | + expected = mb; |
| 498 | + } |
| 499 | + if (result.getLong(i) != expected) { |
| 500 | + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); |
| 501 | + } |
| 502 | + } |
| 503 | + } |
| 504 | + case "round_to_2" -> { |
| 505 | + long b = 1; |
| 506 | + long kb = ByteSizeUnit.KB.toBytes(1); |
| 507 | + |
| 508 | + LongVector f = actual.<LongBlock>getBlock(0).asVector(); |
| 509 | + LongVector result = actual.<LongBlock>getBlock(1).asVector(); |
| 510 | + for (int i = 0; i < BLOCK_LENGTH; i++) { |
| 511 | + long expected = f.getLong(i); |
| 512 | + if (expected < kb) { |
| 513 | + expected = b; |
| 514 | + } else { |
| 515 | + expected = kb; |
| 516 | + } |
| 517 | + if (result.getLong(i) != expected) { |
| 518 | + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); |
| 519 | + } |
| 520 | + } |
| 521 | + } |
393 | 522 | case "to_lower" -> checkBytes(operation, actual, false, new BytesRef[] { new BytesRef("foo"), new BytesRef("bar") });
|
394 | 523 | case "to_lower_ords" -> checkBytes(operation, actual, true, new BytesRef[] { new BytesRef("foo"), new BytesRef("bar") });
|
395 | 524 | case "to_upper" -> checkBytes(operation, actual, false, new BytesRef[] { new BytesRef("FOO"), new BytesRef("BAR") });
|
@@ -421,7 +550,7 @@ private static void checkBytes(String operation, Page actual, boolean expectOrds
|
421 | 550 |
|
422 | 551 | private static Page page(String operation) {
|
423 | 552 | return switch (operation) {
|
424 |
| - case "abs", "add", "date_trunc", "equal_to_const" -> { |
| 553 | + case "abs", "add", "date_trunc", "equal_to_const", "round_to_4_via_case", "round_to_2", "round_to_3", "round_to_4" -> { |
425 | 554 | var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
|
426 | 555 | for (int i = 0; i < BLOCK_LENGTH; i++) {
|
427 | 556 | builder.appendLong(i * 100_000);
|
@@ -511,6 +640,26 @@ private static Page page(String operation) {
|
511 | 640 | };
|
512 | 641 | }
|
513 | 642 |
|
| 643 | + private static Literal b() { |
| 644 | + return lit(1L); |
| 645 | + } |
| 646 | + |
| 647 | + private static Literal kb() { |
| 648 | + return lit(ByteSizeUnit.KB.toBytes(1)); |
| 649 | + } |
| 650 | + |
| 651 | + private static Literal mb() { |
| 652 | + return lit(ByteSizeUnit.MB.toBytes(1)); |
| 653 | + } |
| 654 | + |
| 655 | + private static Literal gb() { |
| 656 | + return lit(ByteSizeUnit.GB.toBytes(1)); |
| 657 | + } |
| 658 | + |
| 659 | + private static Literal lit(long v) { |
| 660 | + return new Literal(Source.EMPTY, v, DataType.LONG); |
| 661 | + } |
| 662 | + |
514 | 663 | @Benchmark
|
515 | 664 | @OperationsPerInvocation(1024 * BLOCK_LENGTH)
|
516 | 665 | public void run() {
|
|
0 commit comments