2424import org .elasticsearch .compute .data .Page ;
2525import org .elasticsearch .compute .operator .AsyncOperator ;
2626import org .elasticsearch .compute .operator .DriverContext ;
27- import org .elasticsearch .compute .operator .EvalOperator ;
27+ import org .elasticsearch .compute .operator .EvalOperator . ExpressionEvaluator ;
2828import org .elasticsearch .compute .operator .Operator ;
2929import org .elasticsearch .compute .operator .SourceOperator ;
3030import org .elasticsearch .compute .test .AbstractBlockSourceOperator ;
3131import org .elasticsearch .compute .test .OperatorTestCase ;
3232import org .elasticsearch .compute .test .RandomBlock ;
3333import org .elasticsearch .core .Releasables ;
34- import org .elasticsearch .core .Tuple ;
3534import org .elasticsearch .threadpool .FixedExecutorBuilder ;
3635import org .elasticsearch .threadpool .TestThreadPool ;
3736import org .elasticsearch .threadpool .ThreadPool ;
4342
4443import java .io .IOException ;
4544import java .util .ArrayList ;
46- import java .util .LinkedHashMap ;
4745import java .util .List ;
4846import java .util .Map ;
4947import java .util .function .BiFunction ;
5048import java .util .function .Consumer ;
49+ import java .util .function .Function ;
5150import java .util .stream .Collectors ;
5251import java .util .stream .IntStream ;
52+ import java .util .stream .Stream ;
5353
5454import static org .hamcrest .Matchers .equalTo ;
5555import static org .hamcrest .Matchers .greaterThanOrEqualTo ;
@@ -65,40 +65,16 @@ public class RerankOperatorTests extends OperatorTestCase {
6565 private static final String SIMPLE_INFERENCE_ID = "test_reranker" ;
6666 private static final String SIMPLE_QUERY = "query text" ;
6767 private ThreadPool threadPool ;
68- private Map < String , ElementType > inputChannelElementTypes ;
69- private Map <String , EvalOperator . ExpressionEvaluator .Factory > rerankFieldsEvaluatorFactories ;
68+ private List < ElementType > inputChannelElementTypes ;
69+ private Map <String , ExpressionEvaluator .Factory > rerankFieldsEvaluatorFactories ;
7070 private int scoreChannel ;
7171
7272 @ Before
7373 private void initChannels () {
7474 int channelCount = randomIntBetween (2 , 10 );
7575 scoreChannel = randomIntBetween (0 , channelCount - 1 );
76- inputChannelElementTypes = IntStream .range (0 , channelCount ).sorted ().mapToObj (i -> {
77- return i == scoreChannel
78- ? Map .entry ("_score" , ElementType .DOUBLE )
79- : Map .entry (randomIdentifier (), randomFrom (ElementType .FLOAT , ElementType .DOUBLE , ElementType .LONG ));
80- }).collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue , (e1 , e2 ) -> e1 , LinkedHashMap ::new ));
81-
82- rerankFieldsEvaluatorFactories = randomMap (
83- 1 ,
84- 20 ,
85- () -> new Tuple <>(randomIdentifier (), context -> new EvalOperator .ExpressionEvaluator () {
86- private int channel = randomIntBetween (0 , channelCount - 1 );
87-
88- @ Override
89- public Block eval (Page page ) {
90- Block b = page .getBlock (channel );
91- b .incRef ();
92- ;
93- return b ;
94- }
95-
96- @ Override
97- public void close () {
98-
99- }
100- })
101- );
76+ inputChannelElementTypes = IntStream .range (0 , channelCount ).sorted ().mapToObj (this ::randomElementType ).collect (Collectors .toList ());
77+ rerankFieldsEvaluatorFactories = randomFieldEvaluators ().collect (Collectors .toMap ((e ) -> randomIdentifier (), Function .identity ()));
10278 }
10379
10480 @ Before
@@ -184,16 +160,24 @@ protected int remaining() {
184160
185161 @ Override
186162 protected Page createPage (int positionOffset , int length ) {
163+ Block [] blocks = new Block [inputChannelElementTypes .size ()];
187164 try {
188165 currentPosition += length ;
189- ElementType [] elementTypes = inputChannelElementTypes .values ().toArray (ElementType []::new );
190- Block [] blocks = new Block [inputChannelElementTypes .size ()];
191- for (int b = 0 ; b < elementTypes .length ; b ++) {
192- blocks [b ] = RandomBlock .randomBlock (blockFactory , elementTypes [b ], length , randomBoolean (), 0 , 10 , 0 , 10 ).block ();
166+ for (int b = 0 ; b < inputChannelElementTypes .size (); b ++) {
167+ blocks [b ] = RandomBlock .randomBlock (
168+ blockFactory ,
169+ inputChannelElementTypes .get (b ),
170+ length ,
171+ randomBoolean (),
172+ 0 ,
173+ 10 ,
174+ 0 ,
175+ 10
176+ ).block ();
193177 }
194178 return new Page (blocks );
195179 } catch (Exception e ) {
196- Releasables .closeExpectNoException ();
180+ Releasables .closeExpectNoException (blocks );
197181 throw (e );
198182 }
199183 }
@@ -255,7 +239,40 @@ protected void assertSimpleOutput(List<Page> inputPages, List<Page> resultPages)
255239 }
256240 }
257241
258- void assertExpectedScore (DoubleBlock scoreBlockResult ) {
242+ private int inputChannelCount () {
243+ return inputChannelElementTypes .size ();
244+ }
245+
246+ private int randomInputChannel () {
247+ return randomIntBetween (0 , inputChannelCount () - 1 );
248+ }
249+
250+ private ElementType randomElementType (int channel ) {
251+ return channel == scoreChannel ? ElementType .DOUBLE : randomFrom (ElementType .FLOAT , ElementType .DOUBLE , ElementType .LONG );
252+ }
253+
254+ private Stream <ExpressionEvaluator .Factory > randomFieldEvaluators () {
255+ return Stream .generate (() -> randomFieldEvaluator (randomInputChannel ())).limit (randomIntBetween (0 , 20 ));
256+ }
257+
258+ private static ExpressionEvaluator .Factory randomFieldEvaluator (int channel ) {
259+ return context -> new ExpressionEvaluator () {
260+ @ Override
261+ public Block eval (Page page ) {
262+ Block b = page .getBlock (channel );
263+ b .incRef ();
264+ ;
265+ return b ;
266+ }
267+
268+ @ Override
269+ public void close () {
270+
271+ }
272+ };
273+ }
274+
275+ private void assertExpectedScore (DoubleBlock scoreBlockResult ) {
259276 assertRandomPositions (scoreBlockResult , (pos ) -> {
260277 if (pos % 10 == 0 ) {
261278 assertThat (scoreBlockResult .isNull (pos ), equalTo (true ));
@@ -291,13 +308,13 @@ <V extends Block, U> void assertBlockContentEquals(
291308 });
292309 }
293310
294- void assertRandomPositions (Block block , Consumer <Integer > consumer ) {
311+ private void assertRandomPositions (Block block , Consumer <Integer > consumer ) {
295312 for (Integer pos : randomList (0 , 100 , () -> randomIntBetween (0 , block .getPositionCount () - 1 ))) {
296313 consumer .accept (pos );
297314 }
298315 }
299316
300- <V extends Block , U > void assertByteRefsBlockContentEquals (Block input , Block result , BytesRef readBuffer ) {
317+ private <V extends Block , U > void assertByteRefsBlockContentEquals (Block input , Block result , BytesRef readBuffer ) {
301318 assertBlockContentEquals (input , result , (BytesRefBlock b , Integer pos ) -> b .getBytesRef (pos , readBuffer ), BytesRefBlock .class );
302319 }
303320}
0 commit comments