Skip to content

Commit d3a47a2

Browse files
committed
Adding more tests for RerankOperatorOutputBuilder and RerankOperatorRequestIterator
1 parent d3f7a0e commit d3a47a2

File tree

4 files changed

+181
-4
lines changed

4 files changed

+181
-4
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public void close() {
3636

3737
@Override
3838
public Page buildOutput() {
39-
int blockCount = Integer.max(inputPage.getBlockCount() - 1, scoreChannel + 1);
39+
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
4040
Block[] blocks = new Block[blockCount];
4141

4242
try {
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.bulk;
9+
10+
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.DoubleBlock;
12+
import org.elasticsearch.compute.data.Page;
13+
import org.elasticsearch.compute.test.ComputeTestCase;
14+
import org.elasticsearch.compute.test.RandomBlock;
15+
import org.elasticsearch.core.Releasables;
16+
import org.elasticsearch.logging.LogManager;
17+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
18+
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
19+
import org.elasticsearch.xpack.esql.inference.rerank.RerankOperatorOutputBuilder;
20+
21+
import java.util.ArrayList;
22+
import java.util.List;
23+
24+
import static org.hamcrest.Matchers.equalTo;
25+
26+
public class RerankOperatorOutputBuilderTests extends ComputeTestCase {
27+
28+
public void testBuildSmallOutput() {
29+
assertBuildOutput(between(1, 100));
30+
}
31+
32+
public void testBuildLargeOutput() {
33+
assertBuildOutput(between(10_000, 100_000));
34+
}
35+
36+
private void assertBuildOutput(int size) {
37+
final Page inputPage = randomInputPage(size, between(1, 20));
38+
final int scoreChannel = randomIntBetween(0, inputPage.getBlockCount());
39+
try (
40+
RerankOperatorOutputBuilder outputBuilder = new RerankOperatorOutputBuilder(
41+
blockFactory().newDoubleBlockBuilder(size),
42+
inputPage,
43+
scoreChannel
44+
)
45+
) {
46+
int batchSize = randomIntBetween(1, size);
47+
for (int currentPos = 0; currentPos < inputPage.getPositionCount();) {
48+
List<RankedDocsResults.RankedDoc> rankedDocs = new ArrayList<>();
49+
for (int rankedDocIndex = 0; rankedDocIndex < batchSize && currentPos < inputPage.getPositionCount(); rankedDocIndex++) {
50+
rankedDocs.add(new RankedDocsResults.RankedDoc(rankedDocIndex, relevanceScore(currentPos), randomIdentifier()));
51+
currentPos++;
52+
}
53+
54+
outputBuilder.addInferenceResponse(new InferenceAction.Response(new RankedDocsResults(rankedDocs)));
55+
}
56+
57+
final Page outputPage = outputBuilder.buildOutput();
58+
try {
59+
assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount()));
60+
LogManager.getLogger(RerankOperatorOutputBuilderTests.class)
61+
.info(
62+
"{} , {}, {}, {}",
63+
scoreChannel,
64+
inputPage.getBlockCount(),
65+
outputPage.getBlockCount(),
66+
Math.max(scoreChannel + 1, inputPage.getBlockCount())
67+
);
68+
assertThat(outputPage.getBlockCount(), equalTo(Integer.max(scoreChannel + 1, inputPage.getBlockCount())));
69+
assertOutputContent(outputPage.getBlock(scoreChannel));
70+
} finally {
71+
outputPage.releaseBlocks();
72+
}
73+
74+
} finally {
75+
inputPage.releaseBlocks();
76+
}
77+
}
78+
79+
private float relevanceScore(int position) {
80+
return (float) 1 / (1 + position);
81+
}
82+
83+
private void assertOutputContent(DoubleBlock block) {
84+
for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) {
85+
assertThat(block.isNull(currentPos), equalTo(false));
86+
assertThat(block.getValueCount(currentPos), equalTo(1));
87+
assertThat(block.getDouble(block.getFirstValueIndex(currentPos)), equalTo((double) relevanceScore(currentPos)));
88+
}
89+
}
90+
91+
private Page randomInputPage(int positionCount, int columnCount) {
92+
final Block[] blocks = new Block[columnCount];
93+
try {
94+
for (int i = 0; i < columnCount; i++) {
95+
blocks[i] = RandomBlock.randomBlock(
96+
blockFactory(),
97+
RandomBlock.randomElementType(),
98+
positionCount,
99+
randomBoolean(),
100+
0,
101+
0,
102+
randomInt(10),
103+
randomInt(10)
104+
).block();
105+
}
106+
107+
return new Page(blocks);
108+
} catch (Exception e) {
109+
Releasables.close(blocks);
110+
throw (e);
111+
}
112+
}
113+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.bulk;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.compute.data.BytesRefBlock;
12+
import org.elasticsearch.compute.test.ComputeTestCase;
13+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
14+
import org.elasticsearch.xpack.esql.inference.rerank.RerankOperatorRequestIterator;
15+
16+
import java.util.List;
17+
18+
import static org.hamcrest.Matchers.equalTo;
19+
20+
public class RerankOperatorRequestIteratorTests extends ComputeTestCase {
21+
22+
public void testIterateSmallInput() {
23+
assertIterate(between(1, 100), randomIntBetween(1, 1_000));
24+
}
25+
26+
public void testIterateLargeInput() {
27+
assertIterate(between(10_000, 100_000), randomIntBetween(1, 1_000));
28+
}
29+
30+
private void assertIterate(int size, int batchSize) {
31+
final BytesRefBlock inputBlock = randomInputBlock(size);
32+
final String inferenceId = randomIdentifier();
33+
final String queryText = randomIdentifier();
34+
35+
try (
36+
RerankOperatorRequestIterator requestIterator = new RerankOperatorRequestIterator(inputBlock, inferenceId, queryText, batchSize)
37+
) {
38+
BytesRef scratch = new BytesRef();
39+
40+
for (int currentPos = 0; requestIterator.hasNext();) {
41+
InferenceAction.Request request = requestIterator.next();
42+
43+
assertThat(request.getInferenceEntityId(), equalTo(inferenceId));
44+
assertThat(request.getQuery(), equalTo(queryText));
45+
List<String> inputs = request.getInput();
46+
for (String input : inputs) {
47+
scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch);
48+
assertThat(input, equalTo(scratch.utf8ToString()));
49+
currentPos++;
50+
}
51+
}
52+
}
53+
}
54+
55+
private BytesRefBlock randomInputBlock(int size) {
56+
try (BytesRefBlock.Builder blockBuilder = blockFactory().newBytesRefBlockBuilder(size)) {
57+
for (int i = 0; i < size; i++) {
58+
blockBuilder.appendBytesRef(new BytesRef(randomAlphaOfLength(10)));
59+
}
60+
61+
return blockBuilder.build();
62+
}
63+
}
64+
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public void testBuildLargeOutput() {
3232
}
3333

3434
private void assertBuildOutput(int size) {
35-
Page inputPage = randomInputPage(size, between(1, 20));
35+
final Page inputPage = randomInputPage(size, between(1, 20));
3636
try (
3737
CompletionOperatorOutputBuilder outputBuilder = new CompletionOperatorOutputBuilder(
3838
blockFactory().newBytesRefBlockBuilder(size),
@@ -44,7 +44,7 @@ private void assertBuildOutput(int size) {
4444
outputBuilder.addInferenceResponse(new InferenceAction.Response(new ChatCompletionResults(results)));
4545
}
4646

47-
Page outputPage = outputBuilder.buildOutput();
47+
final Page outputPage = outputBuilder.buildOutput();
4848
assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount()));
4949
assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1));
5050
assertOutputContent(outputPage.getBlock(outputPage.getBlockCount() - 1));
@@ -68,7 +68,7 @@ private void assertOutputContent(BytesRefBlock block) {
6868
}
6969

7070
private Page randomInputPage(int positionCount, int columnCount) {
71-
Block[] blocks = new Block[columnCount];
71+
final Block[] blocks = new Block[columnCount];
7272
try {
7373
for (int i = 0; i < columnCount; i++) {
7474
blocks[i] = RandomBlock.randomBlock(

0 commit comments

Comments
 (0)