|
58 | 58 | import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC; |
59 | 59 | import static org.elasticsearch.compute.gen.Types.INT_ARRAY_BLOCK; |
60 | 60 | import static org.elasticsearch.compute.gen.Types.INT_BIG_ARRAY_BLOCK; |
| 61 | +import static org.elasticsearch.compute.gen.Types.INT_BLOCK; |
61 | 62 | import static org.elasticsearch.compute.gen.Types.INT_VECTOR; |
62 | 63 | import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC; |
63 | 64 | import static org.elasticsearch.compute.gen.Types.LIST_INTEGER; |
@@ -609,77 +610,98 @@ private MethodSpec addIntermediateInput(TypeName groupsType) { |
609 | 610 | .collect(joining(" && ")) |
610 | 611 | ); |
611 | 612 | } |
612 | | - if (intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) { |
613 | | - builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); |
614 | | - } |
615 | | - builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); |
616 | | - { |
617 | | - if (groupsIsBlock) { |
618 | | - builder.beginControlFlow("if (groups.isNull(groupPosition))"); |
619 | | - builder.addStatement("continue"); |
620 | | - builder.endControlFlow(); |
621 | | - builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)"); |
622 | | - builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)"); |
623 | | - builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)"); |
624 | | - builder.addStatement("int groupId = groups.getInt(g)"); |
625 | | - } else { |
626 | | - builder.addStatement("int groupId = groups.getInt(groupPosition)"); |
| 613 | + var bulkCombineIntermediateMethod = optionalStaticMethod( |
| 614 | + declarationType, |
| 615 | + requireVoidType(), |
| 616 | + requireName("combineIntermediate"), |
| 617 | + requireArgs( |
| 618 | + Stream.concat( |
| 619 | + // aggState, positionOffset, groupIds |
| 620 | + Stream.of(aggState.declaredType(), TypeName.INT, groupsIsBlock ? INT_BLOCK : INT_VECTOR), |
| 621 | + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType) |
| 622 | + ).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) |
| 623 | + ) |
| 624 | + ); |
| 625 | + if (bulkCombineIntermediateMethod != null) { |
| 626 | + var states = intermediateState.stream() |
| 627 | + .map(AggregatorImplementer.IntermediateStateDesc::name) |
| 628 | + .collect(Collectors.joining(", ")); |
| 629 | + builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType); |
| 630 | + } else { |
| 631 | + if (intermediateState.stream() |
| 632 | + .map(AggregatorImplementer.IntermediateStateDesc::elementType) |
| 633 | + .anyMatch(n -> n.equals("BYTES_REF"))) { |
| 634 | + builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); |
627 | 635 | } |
628 | | - |
629 | | - if (aggState.declaredType().isPrimitive()) { |
630 | | - if (warnExceptions.isEmpty()) { |
631 | | - assert intermediateState.size() == 2; |
632 | | - assert intermediateState.get(1).name().equals("seen"); |
633 | | - builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); |
| 636 | + builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); |
| 637 | + { |
| 638 | + if (groupsIsBlock) { |
| 639 | + builder.beginControlFlow("if (groups.isNull(groupPosition))"); |
| 640 | + builder.addStatement("continue"); |
| 641 | + builder.endControlFlow(); |
| 642 | + builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)"); |
| 643 | + builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)"); |
| 644 | + builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)"); |
| 645 | + builder.addStatement("int groupId = groups.getInt(g)"); |
634 | 646 | } else { |
635 | | - assert intermediateState.size() == 3; |
636 | | - assert intermediateState.get(1).name().equals("seen"); |
637 | | - assert intermediateState.get(2).name().equals("failed"); |
638 | | - builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))"); |
639 | | - { |
640 | | - builder.addStatement("state.setFailed(groupId)"); |
641 | | - } |
642 | | - builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); |
| 647 | + builder.addStatement("int groupId = groups.getInt(groupPosition)"); |
643 | 648 | } |
644 | 649 |
|
645 | | - warningsBlock(builder, () -> { |
646 | | - var name = intermediateState.get(0).name(); |
647 | | - var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType()); |
648 | | - builder.addStatement( |
649 | | - "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", |
| 650 | + if (aggState.declaredType().isPrimitive()) { |
| 651 | + if (warnExceptions.isEmpty()) { |
| 652 | + assert intermediateState.size() == 2; |
| 653 | + assert intermediateState.get(1).name().equals("seen"); |
| 654 | + builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); |
| 655 | + } else { |
| 656 | + assert intermediateState.size() == 3; |
| 657 | + assert intermediateState.get(1).name().equals("seen"); |
| 658 | + assert intermediateState.get(2).name().equals("failed"); |
| 659 | + builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))"); |
| 660 | + { |
| 661 | + builder.addStatement("state.setFailed(groupId)"); |
| 662 | + } |
| 663 | + builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); |
| 664 | + } |
| 665 | + |
| 666 | + warningsBlock(builder, () -> { |
| 667 | + var name = intermediateState.get(0).name(); |
| 668 | + var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType()); |
| 669 | + builder.addStatement( |
| 670 | + "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", |
| 671 | + declarationType, |
| 672 | + name, |
| 673 | + vectorAccessor |
| 674 | + ); |
| 675 | + }); |
| 676 | + builder.endControlFlow(); |
| 677 | + } else { |
| 678 | + var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); |
| 679 | + requireStaticMethod( |
650 | 680 | declarationType, |
651 | | - name, |
652 | | - vectorAccessor |
| 681 | + requireVoidType(), |
| 682 | + requireName("combineIntermediate"), |
| 683 | + requireArgs( |
| 684 | + Stream.of( |
| 685 | + Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId |
| 686 | + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), |
| 687 | + Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position |
| 688 | + ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) |
| 689 | + ) |
653 | 690 | ); |
654 | | - }); |
655 | | - builder.endControlFlow(); |
656 | | - } else { |
657 | | - var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); |
658 | | - requireStaticMethod( |
659 | | - declarationType, |
660 | | - requireVoidType(), |
661 | | - requireName("combineIntermediate"), |
662 | | - requireArgs( |
663 | | - Stream.of( |
664 | | - Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId |
665 | | - intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), |
666 | | - Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position |
667 | | - ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) |
668 | | - ) |
669 | | - ); |
670 | 691 |
|
671 | | - builder.addStatement( |
672 | | - "$T.combineIntermediate(state, groupId, " |
673 | | - + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) |
674 | | - + (stateHasBlock ? ", groupPosition + positionOffset" : "") |
675 | | - + ")", |
676 | | - declarationType |
677 | | - ); |
678 | | - } |
679 | | - if (groupsIsBlock) { |
| 692 | + builder.addStatement( |
| 693 | + "$T.combineIntermediate(state, groupId, " |
| 694 | + + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) |
| 695 | + + (stateHasBlock ? ", groupPosition + positionOffset" : "") |
| 696 | + + ")", |
| 697 | + declarationType |
| 698 | + ); |
| 699 | + } |
| 700 | + if (groupsIsBlock) { |
| 701 | + builder.endControlFlow(); |
| 702 | + } |
680 | 703 | builder.endControlFlow(); |
681 | 704 | } |
682 | | - builder.endControlFlow(); |
683 | 705 | } |
684 | 706 | return builder.build(); |
685 | 707 | } |
|
0 commit comments