Skip to content

Commit 54ce1eb

Browse files
committed
Small test refactoring
1 parent b0129a3 commit 54ce1eb

File tree

1 file changed

+56
-39
lines changed

1 file changed

+56
-39
lines changed

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,13 @@
2424
import org.elasticsearch.compute.data.Page;
2525
import org.elasticsearch.compute.operator.AsyncOperator;
2626
import org.elasticsearch.compute.operator.DriverContext;
27-
import org.elasticsearch.compute.operator.EvalOperator;
27+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
2828
import org.elasticsearch.compute.operator.Operator;
2929
import org.elasticsearch.compute.operator.SourceOperator;
3030
import org.elasticsearch.compute.test.AbstractBlockSourceOperator;
3131
import org.elasticsearch.compute.test.OperatorTestCase;
3232
import org.elasticsearch.compute.test.RandomBlock;
3333
import org.elasticsearch.core.Releasables;
34-
import org.elasticsearch.core.Tuple;
3534
import org.elasticsearch.threadpool.FixedExecutorBuilder;
3635
import org.elasticsearch.threadpool.TestThreadPool;
3736
import org.elasticsearch.threadpool.ThreadPool;
@@ -43,13 +42,14 @@
4342

4443
import java.io.IOException;
4544
import java.util.ArrayList;
46-
import java.util.LinkedHashMap;
4745
import java.util.List;
4846
import java.util.Map;
4947
import java.util.function.BiFunction;
5048
import java.util.function.Consumer;
49+
import java.util.function.Function;
5150
import java.util.stream.Collectors;
5251
import java.util.stream.IntStream;
52+
import java.util.stream.Stream;
5353

5454
import static org.hamcrest.Matchers.equalTo;
5555
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -65,40 +65,16 @@ public class RerankOperatorTests extends OperatorTestCase {
6565
private static final String SIMPLE_INFERENCE_ID = "test_reranker";
6666
private static final String SIMPLE_QUERY = "query text";
6767
private ThreadPool threadPool;
68-
private Map<String, ElementType> inputChannelElementTypes;
69-
private Map<String, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorFactories;
68+
private List<ElementType> inputChannelElementTypes;
69+
private Map<String, ExpressionEvaluator.Factory> rerankFieldsEvaluatorFactories;
7070
private int scoreChannel;
7171

7272
@Before
7373
private void initChannels() {
7474
int channelCount = randomIntBetween(2, 10);
7575
scoreChannel = randomIntBetween(0, channelCount - 1);
76-
inputChannelElementTypes = IntStream.range(0, channelCount).sorted().mapToObj(i -> {
77-
return i == scoreChannel
78-
? Map.entry("_score", ElementType.DOUBLE)
79-
: Map.entry(randomIdentifier(), randomFrom(ElementType.FLOAT, ElementType.DOUBLE, ElementType.LONG));
80-
}).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new));
81-
82-
rerankFieldsEvaluatorFactories = randomMap(
83-
1,
84-
20,
85-
() -> new Tuple<>(randomIdentifier(), context -> new EvalOperator.ExpressionEvaluator() {
86-
private int channel = randomIntBetween(0, channelCount - 1);
87-
88-
@Override
89-
public Block eval(Page page) {
90-
Block b = page.getBlock(channel);
91-
b.incRef();
92-
;
93-
return b;
94-
}
95-
96-
@Override
97-
public void close() {
98-
99-
}
100-
})
101-
);
76+
inputChannelElementTypes = IntStream.range(0, channelCount).sorted().mapToObj(this::randomElementType).collect(Collectors.toList());
77+
rerankFieldsEvaluatorFactories = randomFieldEvaluators().collect(Collectors.toMap((e) -> randomIdentifier(), Function.identity()));
10278
}
10379

10480
@Before
@@ -184,16 +160,24 @@ protected int remaining() {
184160

185161
@Override
186162
protected Page createPage(int positionOffset, int length) {
163+
Block[] blocks = new Block[inputChannelElementTypes.size()];
187164
try {
188165
currentPosition += length;
189-
ElementType[] elementTypes = inputChannelElementTypes.values().toArray(ElementType[]::new);
190-
Block[] blocks = new Block[inputChannelElementTypes.size()];
191-
for (int b = 0; b < elementTypes.length; b++) {
192-
blocks[b] = RandomBlock.randomBlock(blockFactory, elementTypes[b], length, randomBoolean(), 0, 10, 0, 10).block();
166+
for (int b = 0; b < inputChannelElementTypes.size(); b++) {
167+
blocks[b] = RandomBlock.randomBlock(
168+
blockFactory,
169+
inputChannelElementTypes.get(b),
170+
length,
171+
randomBoolean(),
172+
0,
173+
10,
174+
0,
175+
10
176+
).block();
193177
}
194178
return new Page(blocks);
195179
} catch (Exception e) {
196-
Releasables.closeExpectNoException();
180+
Releasables.closeExpectNoException(blocks);
197181
throw (e);
198182
}
199183
}
@@ -255,7 +239,40 @@ protected void assertSimpleOutput(List<Page> inputPages, List<Page> resultPages)
255239
}
256240
}
257241

258-
void assertExpectedScore(DoubleBlock scoreBlockResult) {
242+
private int inputChannelCount() {
243+
return inputChannelElementTypes.size();
244+
}
245+
246+
private int randomInputChannel() {
247+
return randomIntBetween(0, inputChannelCount() - 1);
248+
}
249+
250+
private ElementType randomElementType(int channel) {
251+
return channel == scoreChannel ? ElementType.DOUBLE : randomFrom(ElementType.FLOAT, ElementType.DOUBLE, ElementType.LONG);
252+
}
253+
254+
private Stream<ExpressionEvaluator.Factory> randomFieldEvaluators() {
255+
return Stream.generate(() -> randomFieldEvaluator(randomInputChannel())).limit(randomIntBetween(0, 20));
256+
}
257+
258+
private static ExpressionEvaluator.Factory randomFieldEvaluator(int channel) {
259+
return context -> new ExpressionEvaluator() {
260+
@Override
261+
public Block eval(Page page) {
262+
Block b = page.getBlock(channel);
263+
b.incRef();
264+
;
265+
return b;
266+
}
267+
268+
@Override
269+
public void close() {
270+
271+
}
272+
};
273+
}
274+
275+
private void assertExpectedScore(DoubleBlock scoreBlockResult) {
259276
assertRandomPositions(scoreBlockResult, (pos) -> {
260277
if (pos % 10 == 0) {
261278
assertThat(scoreBlockResult.isNull(pos), equalTo(true));
@@ -291,13 +308,13 @@ <V extends Block, U> void assertBlockContentEquals(
291308
});
292309
}
293310

294-
void assertRandomPositions(Block block, Consumer<Integer> consumer) {
311+
private void assertRandomPositions(Block block, Consumer<Integer> consumer) {
295312
for (Integer pos : randomList(0, 100, () -> randomIntBetween(0, block.getPositionCount() - 1))) {
296313
consumer.accept(pos);
297314
}
298315
}
299316

300-
<V extends Block, U> void assertByteRefsBlockContentEquals(Block input, Block result, BytesRef readBuffer) {
317+
private <V extends Block, U> void assertByteRefsBlockContentEquals(Block input, Block result, BytesRef readBuffer) {
301318
assertBlockContentEquals(input, result, (BytesRefBlock b, Integer pos) -> b.getBytesRef(pos, readBuffer), BytesRefBlock.class);
302319
}
303320
}

0 commit comments

Comments
 (0)