|
13 | 13 | import org.elasticsearch.common.breaker.NoopCircuitBreaker;
|
14 | 14 | import org.elasticsearch.common.logging.LogConfigurator;
|
15 | 15 | import org.elasticsearch.common.settings.Settings;
|
| 16 | +import org.elasticsearch.common.unit.ByteSizeUnit; |
16 | 17 | import org.elasticsearch.common.util.BigArrays;
|
17 | 18 | import org.elasticsearch.compute.data.Block;
|
18 | 19 | import org.elasticsearch.compute.data.BlockFactory;
|
|
44 | 45 | import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
|
45 | 46 | import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc;
|
46 | 47 | import org.elasticsearch.xpack.esql.expression.function.scalar.math.Abs;
|
| 48 | +import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo; |
47 | 49 | import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
|
48 | 50 | import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
|
49 | 51 | import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike;
|
50 | 52 | import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower;
|
51 | 53 | import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
|
52 | 54 | import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
|
53 | 55 | import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
|
| 56 | +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; |
54 | 57 | import org.elasticsearch.xpack.esql.planner.Layout;
|
55 | 58 | import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
|
56 | 59 | import org.elasticsearch.xpack.esql.session.Configuration;
|
@@ -128,6 +131,10 @@ static void selfTest() {
|
128 | 131 | "long_equal_to_int",
|
129 | 132 | "mv_min",
|
130 | 133 | "mv_min_ascending",
|
| 134 | + "round_to_4_via_case", |
| 135 | + "round_to_2", |
| 136 | + "round_to_3", |
| 137 | + "round_to_4", |
131 | 138 | "rlike",
|
132 | 139 | "to_lower",
|
133 | 140 | "to_lower_ords",
|
@@ -240,6 +247,65 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
|
240 | 247 | RLike rlike = new RLike(Source.EMPTY, keywordField, new RLikePattern(".ar"));
|
241 | 248 | yield EvalMapper.toEvaluator(FOLD_CONTEXT, rlike, layout(keywordField)).get(driverContext);
|
242 | 249 | }
|
| 250 | + case "round_to_4_via_case" -> { |
| 251 | + FieldAttribute f = longField(); |
| 252 | + |
| 253 | + Expression ltkb = new LessThan(Source.EMPTY, f, kb()); |
| 254 | + Expression ltmb = new LessThan(Source.EMPTY, f, mb()); |
| 255 | + Expression ltgb = new LessThan(Source.EMPTY, f, gb()); |
| 256 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 257 | + FOLD_CONTEXT, |
| 258 | + new Case(Source.EMPTY, ltkb, List.of(b(), ltmb, kb(), ltgb, mb(), gb())), |
| 259 | + layout(f) |
| 260 | + ).get(driverContext); |
| 261 | + String desc = "CaseLazyEvaluator"; |
| 262 | + if (evaluator.toString().contains(desc) == false) { |
| 263 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 264 | + } |
| 265 | + yield evaluator; |
| 266 | + } |
| 267 | + case "round_to_2" -> { |
| 268 | + FieldAttribute f = longField(); |
| 269 | + |
| 270 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 271 | + FOLD_CONTEXT, |
| 272 | + new RoundTo(Source.EMPTY, f, List.of(b(), kb())), |
| 273 | + layout(f) |
| 274 | + ).get(driverContext); |
| 275 | + String desc = "RoundToLong2"; |
| 276 | + if (evaluator.toString().contains(desc) == false) { |
| 277 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 278 | + } |
| 279 | + yield evaluator; |
| 280 | + } |
| 281 | + case "round_to_3" -> { |
| 282 | + FieldAttribute f = longField(); |
| 283 | + |
| 284 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 285 | + FOLD_CONTEXT, |
| 286 | + new RoundTo(Source.EMPTY, f, List.of(b(), kb(), mb())), |
| 287 | + layout(f) |
| 288 | + ).get(driverContext); |
| 289 | + String desc = "RoundToLong3"; |
| 290 | + if (evaluator.toString().contains(desc) == false) { |
| 291 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 292 | + } |
| 293 | + yield evaluator; |
| 294 | + } |
| 295 | + case "round_to_4" -> { |
| 296 | + FieldAttribute f = longField(); |
| 297 | + |
| 298 | + EvalOperator.ExpressionEvaluator evaluator = EvalMapper.toEvaluator( |
| 299 | + FOLD_CONTEXT, |
| 300 | + new RoundTo(Source.EMPTY, f, List.of(b(), kb(), mb(), gb())), |
| 301 | + layout(f) |
| 302 | + ).get(driverContext); |
| 303 | + String desc = "RoundToLong4"; |
| 304 | + if (evaluator.toString().contains(desc) == false) { |
| 305 | + throw new IllegalArgumentException("Evaluator was [" + evaluator + "] but expected one containing [" + desc + "]"); |
| 306 | + } |
| 307 | + yield evaluator; |
| 308 | + } |
243 | 309 | case "to_lower", "to_lower_ords" -> {
|
244 | 310 | FieldAttribute keywordField = keywordField();
|
245 | 311 | ToLower toLower = new ToLower(Source.EMPTY, keywordField, configuration());
|
@@ -419,6 +485,69 @@ private static void checkExpected(String operation, Page actual) {
|
419 | 485 | }
|
420 | 486 | }
|
421 | 487 | }
|
| 488 | + case "round_to_4_via_case", "round_to_4" -> { |
| 489 | + long b = 1; |
| 490 | + long kb = ByteSizeUnit.KB.toBytes(1); |
| 491 | + long mb = ByteSizeUnit.MB.toBytes(1); |
| 492 | + long gb = ByteSizeUnit.GB.toBytes(1); |
| 493 | + |
| 494 | + LongVector f = actual.<LongBlock>getBlock(0).asVector(); |
| 495 | + LongVector result = actual.<LongBlock>getBlock(1).asVector(); |
| 496 | + for (int i = 0; i < BLOCK_LENGTH; i++) { |
| 497 | + long expected = f.getLong(i); |
| 498 | + if (expected < kb) { |
| 499 | + expected = b; |
| 500 | + } else if (expected < mb) { |
| 501 | + expected = kb; |
| 502 | + } else if (expected < gb) { |
| 503 | + expected = mb; |
| 504 | + } else { |
| 505 | + expected = gb; |
| 506 | + } |
| 507 | + if (result.getLong(i) != expected) { |
| 508 | + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); |
| 509 | + } |
| 510 | + } |
| 511 | + } |
| 512 | + case "round_to_3" -> { |
| 513 | + long b = 1; |
| 514 | + long kb = ByteSizeUnit.KB.toBytes(1); |
| 515 | + long mb = ByteSizeUnit.MB.toBytes(1); |
| 516 | + |
| 517 | + LongVector f = actual.<LongBlock>getBlock(0).asVector(); |
| 518 | + LongVector result = actual.<LongBlock>getBlock(1).asVector(); |
| 519 | + for (int i = 0; i < BLOCK_LENGTH; i++) { |
| 520 | + long expected = f.getLong(i); |
| 521 | + if (expected < kb) { |
| 522 | + expected = b; |
| 523 | + } else if (expected < mb) { |
| 524 | + expected = kb; |
| 525 | + } else { |
| 526 | + expected = mb; |
| 527 | + } |
| 528 | + if (result.getLong(i) != expected) { |
| 529 | + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); |
| 530 | + } |
| 531 | + } |
| 532 | + } |
| 533 | + case "round_to_2" -> { |
| 534 | + long b = 1; |
| 535 | + long kb = ByteSizeUnit.KB.toBytes(1); |
| 536 | + |
| 537 | + LongVector f = actual.<LongBlock>getBlock(0).asVector(); |
| 538 | + LongVector result = actual.<LongBlock>getBlock(1).asVector(); |
| 539 | + for (int i = 0; i < BLOCK_LENGTH; i++) { |
| 540 | + long expected = f.getLong(i); |
| 541 | + if (expected < kb) { |
| 542 | + expected = b; |
| 543 | + } else { |
| 544 | + expected = kb; |
| 545 | + } |
| 546 | + if (result.getLong(i) != expected) { |
| 547 | + throw new AssertionError("[" + operation + "] expected [" + expected + "] but was [" + result.getLong(i) + "]"); |
| 548 | + } |
| 549 | + } |
| 550 | + } |
422 | 551 | case "to_lower" -> checkBytes(operation, actual, false, new BytesRef[] { new BytesRef("foo"), new BytesRef("bar") });
|
423 | 552 | case "to_lower_ords" -> checkBytes(operation, actual, true, new BytesRef[] { new BytesRef("foo"), new BytesRef("bar") });
|
424 | 553 | case "to_upper" -> checkBytes(operation, actual, false, new BytesRef[] { new BytesRef("FOO"), new BytesRef("BAR") });
|
@@ -450,7 +579,7 @@ private static void checkBytes(String operation, Page actual, boolean expectOrds
|
450 | 579 |
|
451 | 580 | private static Page page(String operation) {
|
452 | 581 | return switch (operation) {
|
453 |
| - case "abs", "add", "date_trunc", "equal_to_const" -> { |
| 582 | + case "abs", "add", "date_trunc", "equal_to_const", "round_to_4_via_case", "round_to_2", "round_to_3", "round_to_4" -> { |
454 | 583 | var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
|
455 | 584 | for (int i = 0; i < BLOCK_LENGTH; i++) {
|
456 | 585 | builder.appendLong(i * 100_000);
|
@@ -540,6 +669,26 @@ private static Page page(String operation) {
|
540 | 669 | };
|
541 | 670 | }
|
542 | 671 |
|
| 672 | + private static Literal b() { |
| 673 | + return lit(1L); |
| 674 | + } |
| 675 | + |
| 676 | + private static Literal kb() { |
| 677 | + return lit(ByteSizeUnit.KB.toBytes(1)); |
| 678 | + } |
| 679 | + |
| 680 | + private static Literal mb() { |
| 681 | + return lit(ByteSizeUnit.MB.toBytes(1)); |
| 682 | + } |
| 683 | + |
| 684 | + private static Literal gb() { |
| 685 | + return lit(ByteSizeUnit.GB.toBytes(1)); |
| 686 | + } |
| 687 | + |
| 688 | + private static Literal lit(long v) { |
| 689 | + return new Literal(Source.EMPTY, v, DataType.LONG); |
| 690 | + } |
| 691 | + |
543 | 692 | @Benchmark
|
544 | 693 | @OperationsPerInvocation(1024 * BLOCK_LENGTH)
|
545 | 694 | public void run() {
|
|
0 commit comments