1212import org .elasticsearch .common .util .BigArrays ;
1313import org .elasticsearch .common .util .BitArray ;
1414import org .elasticsearch .compute .aggregation .GroupingAggregatorFunction ;
15- import org .elasticsearch .compute .aggregation .Warnings ;
1615import org .elasticsearch .compute .ann .Fixed ;
1716import org .elasticsearch .compute .data .Block ;
1817import org .elasticsearch .compute .data .BlockFactory ;
2120import org .elasticsearch .compute .data .IntBlock ;
2221import org .elasticsearch .compute .data .IntVector ;
2322import org .elasticsearch .compute .data .Page ;
24- import org .elasticsearch .compute .operator .DriverContext ;
25- import org .elasticsearch .compute .operator .EvalOperator ;
23+ import org .elasticsearch .core .Releasable ;
2624import org .elasticsearch .core .Releasables ;
2725import org .elasticsearch .xpack .ml .aggs .categorization .TokenListCategorizer ;
2826import org .elasticsearch .xpack .ml .job .categorization .CategorizationAnalyzer ;
@@ -34,17 +32,18 @@ public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
3432
3533 CategorizeRawBlockHash (
3634 BlockFactory blockFactory ,
35+ int channel ,
3736 boolean outputPartial ,
38- TokenListCategorizer . CloseableTokenListCategorizer categorizer ,
39- CategorizeEvaluator evaluator
37+ CategorizationAnalyzer analyzer ,
38+ TokenListCategorizer . CloseableTokenListCategorizer categorizer
4039 ) {
41- super (blockFactory , outputPartial , categorizer );
42- this .evaluator = evaluator ;
40+ super (blockFactory , channel , outputPartial , categorizer );
41+ this .evaluator = new CategorizeEvaluator ( analyzer , categorizer , blockFactory ) ;
4342 }
4443
4544 @ Override
4645 public void add (Page page , GroupingAggregatorFunction .AddInput addInput ) {
47- IntBlock result = (IntBlock ) evaluator .eval (page );
46+ IntBlock result = (IntBlock ) evaluator .eval (page . getBlock ( channel ()) );
4847 addInput .add (0 , result );
4948 }
5049
@@ -66,18 +65,14 @@ public void close() {
6665 }
6766
6867 /**
69- * NOCOMMIT: Super-duper copy-pasted.
68+ * NOCOMMIT: Super-duper copy-pasted from the actually generated evaluator; needs cleanup .
7069 */
71- public static final class CategorizeEvaluator implements EvalOperator .ExpressionEvaluator {
72- private final Warnings warnings ;
73-
74- private final EvalOperator .ExpressionEvaluator v ;
75-
70+ public static final class CategorizeEvaluator implements Releasable {
7671 private final CategorizationAnalyzer analyzer ;
7772
7873 private final TokenListCategorizer .CloseableTokenListCategorizer categorizer ;
7974
80- private final DriverContext driverContext ;
75+ private final BlockFactory blockFactory ;
8176
8277 static int process (
8378 BytesRef v ,
@@ -93,31 +88,25 @@ static int process(
9388 }
9489
9590 public CategorizeEvaluator (
96- EvalOperator .ExpressionEvaluator v ,
9791 CategorizationAnalyzer analyzer ,
9892 TokenListCategorizer .CloseableTokenListCategorizer categorizer ,
99- DriverContext driverContext
93+ BlockFactory blockFactory
10094 ) {
101- this .v = v ;
10295 this .analyzer = analyzer ;
10396 this .categorizer = categorizer ;
104- this .driverContext = driverContext ;
105- this .warnings = Warnings .createWarnings (driverContext .warningsMode (), -1 , -1 , "" );
97+ this .blockFactory = blockFactory ;
10698 }
10799
108- @ Override
109- public Block eval (Page page ) {
110- try (BytesRefBlock vBlock = (BytesRefBlock ) v .eval (page )) {
111- BytesRefVector vVector = vBlock .asVector ();
112- if (vVector == null ) {
113- return eval (page .getPositionCount (), vBlock );
114- }
115- return eval (page .getPositionCount (), vVector ).asBlock ();
100+ public Block eval (BytesRefBlock vBlock ) {
101+ BytesRefVector vVector = vBlock .asVector ();
102+ if (vVector == null ) {
103+ return eval (vBlock .getPositionCount (), vBlock );
116104 }
105+ return eval (vBlock .getPositionCount (), vVector ).asBlock ();
117106 }
118107
119108 public IntBlock eval (int positionCount , BytesRefBlock vBlock ) {
120- try (IntBlock .Builder result = driverContext . blockFactory () .newIntBlockBuilder (positionCount )) {
109+ try (IntBlock .Builder result = blockFactory .newIntBlockBuilder (positionCount )) {
121110 BytesRef vScratch = new BytesRef ();
122111 position : for (int p = 0 ; p < positionCount ; p ++) {
123112 if (vBlock .isNull (p )) {
@@ -126,7 +115,7 @@ public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
126115 }
127116 if (vBlock .getValueCount (p ) != 1 ) {
128117 if (vBlock .getValueCount (p ) > 1 ) {
129- warnings . registerException ( new IllegalArgumentException ( "single-value function encountered multi-value" ));
118+ // TODO: handle multi-values
130119 }
131120 result .appendNull ();
132121 continue position ;
@@ -138,7 +127,7 @@ public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
138127 }
139128
140129 public IntVector eval (int positionCount , BytesRefVector vVector ) {
141- try (IntVector .FixedBuilder result = driverContext . blockFactory () .newIntVectorFixedBuilder (positionCount )) {
130+ try (IntVector .FixedBuilder result = blockFactory .newIntVectorFixedBuilder (positionCount )) {
142131 BytesRef vScratch = new BytesRef ();
143132 position : for (int p = 0 ; p < positionCount ; p ++) {
144133 result .appendInt (p , process (vVector .getBytesRef (p , vScratch ), this .analyzer , this .categorizer ));
@@ -149,12 +138,12 @@ public IntVector eval(int positionCount, BytesRefVector vVector) {
149138
150139 @ Override
151140 public String toString () {
152- return "CategorizeEvaluator[" + "v=" + v + "] " ;
141+ return "CategorizeEvaluator" ;
153142 }
154143
155144 @ Override
156145 public void close () {
157- Releasables .closeExpectNoException (v , analyzer , categorizer );
146+ Releasables .closeExpectNoException (analyzer , categorizer );
158147 }
159148 }
160149}
0 commit comments