Skip to content

Commit f4be9cf

Browse files
committed
Add tests for the dense vector operator.
1 parent f91f573 commit f4be9cf

File tree

8 files changed

+490
-12
lines changed

8 files changed

+490
-12
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ public String toString() {
7777
* Returns the request iterator responsible for batching and converting input rows into inference requests.
7878
*/
7979
@Override
80-
protected DenseEmbeddingRequestIterator requests(Page inputPage) {
80+
protected DenseEmbeddingOperatorRequestIterator requests(Page inputPage) {
8181
int inputBlockChannel = inputPage.getBlockCount() - 1;
82-
return new DenseEmbeddingRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), batchSize);
82+
return new DenseEmbeddingOperatorRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), batchSize);
8383
}
8484

8585
/**
@@ -98,8 +98,12 @@ protected DenseEmbeddingOperatorOutputBuilder outputBuilder(Page input) {
9898
/**
9999
* Factory for creating {@link DenseEmbeddingOperator} instances
100100
*/
101-
public record Factory(InferenceRunner inferenceRunner, int dimensions, String inferenceId,
102-
ExpressionEvaluator.Factory inputEvaluatorFactory) implements OperatorFactory {
101+
public record Factory(
102+
InferenceRunner inferenceRunner,
103+
int dimensions,
104+
String inferenceId,
105+
ExpressionEvaluator.Factory inputEvaluatorFactory
106+
) implements OperatorFactory {
103107

104108
@Override
105109
public String describe() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
*/
2727
public class DenseEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
2828

29-
3029
private final FloatBlock.Builder outputBlockBuilder;
3130
private final Page inputPage;
3231
private final int dimensions;
@@ -95,13 +94,13 @@ public boolean hasNext() {
9594

9695
public float[] next() {
9796
EmbeddingResults.Embedding<?> embedding = embeddingsIterator.next();
98-
float[] values = switch(embedding) {
97+
float[] values = switch (embedding) {
9998
case TextEmbeddingFloatResults.Embedding textEmbeddingFloat -> textEmbeddingFloat.values();
10099
case TextEmbeddingByteResults.Embedding textEmbeddingBytes -> toFloatArray(textEmbeddingBytes.values());
101100
default -> throw new IllegalStateException("Unsupported embedding type [" + embedding.getClass() + "]");
102101
};
103102

104-
assert values.length == dimensions : "Unexpected vector size: " + values.length ;
103+
assert values.length == dimensions : "Unexpected vector size: " + values.length;
105104

106105
return values;
107106
}
@@ -112,9 +111,11 @@ private static float[] toFloatArray(byte[] bytes) {
112111
return floatValues;
113112
}
114113

115-
116114
public static EmbeddingValueReader of(InferenceAction.Response inferenceResponse, int dimensions) {
117-
TextEmbeddingResults<?> inferenceResults = InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class);
115+
TextEmbeddingResults<?> inferenceResults = InferenceOperator.OutputBuilder.inferenceResults(
116+
inferenceResponse,
117+
TextEmbeddingResults.class
118+
);
118119
return new EmbeddingValueReader(inferenceResults.embeddings().iterator(), dimensions);
119120
}
120121
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
* <p>This iterator reads from a {@link BytesRefBlock} containing text to be embedded. It slices the input into batches
2525
* of configurable size and converts each batch into an {@link InferenceAction.Request} with the task type {@link TaskType#TEXT_EMBEDDING}.
2626
*/
27-
public class DenseEmbeddingRequestIterator implements BulkInferenceRequestIterator {
27+
public class DenseEmbeddingOperatorRequestIterator implements BulkInferenceRequestIterator {
2828
private final BytesRefBlock inputBlock;
2929
private final String inferenceId;
3030
private final int batchSize;
3131
private int remainingPositions;
3232

33-
public DenseEmbeddingRequestIterator(BytesRefBlock inputBlock, String inferenceId, int batchSize) {
33+
public DenseEmbeddingOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, int batchSize) {
3434
this.inputBlock = inputBlock;
3535
this.inferenceId = inferenceId;
3636
this.batchSize = batchSize;

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ protected Page createPage(int positionOffset, int length) {
9393
if (randomInt() % 100 == 0) {
9494
builder.appendNull();
9595
} else {
96-
builder.appendBytesRef(new BytesRef(randomAlphaOfLength(10)));
96+
builder.appendBytesRef(new BytesRef(randomAlphaOfLength(randomInt(10))));
9797
}
9898

9999
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.embedding;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
11+
12+
import org.elasticsearch.compute.data.Block;
13+
import org.elasticsearch.compute.data.FloatBlock;
14+
import org.elasticsearch.compute.data.Page;
15+
import org.elasticsearch.compute.test.ComputeTestCase;
16+
import org.elasticsearch.compute.test.RandomBlock;
17+
import org.elasticsearch.core.Releasables;
18+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
19+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
20+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
21+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
22+
23+
import java.util.ArrayList;
24+
import java.util.List;
25+
26+
import static org.hamcrest.Matchers.equalTo;
27+
28+
public class DenseEmbeddingOperatorOutputBuilderTests extends ComputeTestCase {
29+
30+
private final static List<Integer> DIMENSIONS = List.of(1, 32, 128, 512, 2048, 5096);
31+
private final static List<Integer> INPUT_SIZES = List.of(10, 100, 1_000, 10_000);
32+
private final static List<Integer> BATCH_SIZES = List.of(1, 10, 100, 1000);
33+
private final static List<Class<? extends TextEmbeddingResults<?>>> EMBEDDING_TYPES = List.of(
34+
TextEmbeddingBitResults.class,
35+
TextEmbeddingByteResults.class,
36+
TextEmbeddingFloatResults.class
37+
);
38+
39+
private final static String TEST_PARAMS_FORMATING = "dims=%d, input_size=%d, batch_size=%d, embedding_type=%s";
40+
41+
private final int dimensions;
42+
private final int inputPageSize;
43+
private final int batchSize;
44+
private final Class<? extends TextEmbeddingResults<?>> embeddingType;
45+
46+
@ParametersFactory(argumentFormatting = TEST_PARAMS_FORMATING)
47+
public static Iterable<Object[]> parameters() {
48+
List<Object[]> params = new ArrayList<>();
49+
params.add(new Object[] {});
50+
51+
for (List<?> axis : List.of(DIMENSIONS, INPUT_SIZES, BATCH_SIZES, EMBEDDING_TYPES)) {
52+
53+
List<Object[]> newParams = new ArrayList<>();
54+
for (Object[] combination : params) {
55+
for (Object element : axis) {
56+
Object[] newCombination = new Object[combination.length + 1];
57+
System.arraycopy(combination, 0, newCombination, 0, combination.length);
58+
newCombination[newCombination.length - 1] = element;
59+
newParams.add(newCombination);
60+
}
61+
}
62+
params = newParams;
63+
}
64+
65+
return params;
66+
}
67+
68+
public DenseEmbeddingOperatorOutputBuilderTests(
69+
int dimensions,
70+
int inputPageSize,
71+
int batchSize,
72+
Class<? extends TextEmbeddingBitResults> embeddingType
73+
) {
74+
this.dimensions = dimensions;
75+
this.inputPageSize = inputPageSize;
76+
this.batchSize = batchSize;
77+
this.embeddingType = embeddingType;
78+
}
79+
80+
public void testOutput() {
81+
final Page inputPage = randomInputPage(inputPageSize, between(1, 20));
82+
try (
83+
DenseEmbeddingOperatorOutputBuilder outputBuilder = new DenseEmbeddingOperatorOutputBuilder(
84+
blockFactory().newFloatBlockBuilder(inputPageSize * dimensions),
85+
inputPage,
86+
dimensions
87+
)
88+
) {
89+
for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos += batchSize) {
90+
outputBuilder.addInferenceResponse(
91+
DenseEmbeddingUtils.inferenceResponse(
92+
currentPos,
93+
Math.min(inputPage.getPositionCount() - currentPos, batchSize),
94+
embeddingType,
95+
dimensions
96+
)
97+
);
98+
}
99+
100+
final Page outputPage = outputBuilder.buildOutput();
101+
try {
102+
assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount()));
103+
assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1));
104+
assertOutputContent(outputPage.getBlock(inputPage.getBlockCount()));
105+
} finally {
106+
outputPage.releaseBlocks();
107+
}
108+
109+
} finally {
110+
inputPage.releaseBlocks();
111+
}
112+
}
113+
114+
private void assertOutputContent(FloatBlock block) {
115+
116+
for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) {
117+
assertThat(block.isNull(currentPos), equalTo(false));
118+
assertThat(block.getValueCount(currentPos), equalTo(dimensions));
119+
float[] expectedEmbeddingValues = new float[dimensions];
120+
if (embeddingType.equals(TextEmbeddingByteResults.class) || embeddingType.equals(TextEmbeddingBitResults.class)) {
121+
byte[] bytesValues = embeddingType.equals(TextEmbeddingBitResults.class)
122+
? DenseEmbeddingUtils.toBitArray(currentPos, dimensions)
123+
: DenseEmbeddingUtils.toByteArray(currentPos, dimensions);
124+
assertThat(bytesValues.length, equalTo(dimensions));
125+
for (int i = 0; i < bytesValues.length; i++) {
126+
expectedEmbeddingValues[i] = bytesValues[i];
127+
}
128+
} else if (embeddingType.equals(TextEmbeddingFloatResults.class)) {
129+
expectedEmbeddingValues = DenseEmbeddingUtils.toFloatArray(currentPos, dimensions);
130+
}
131+
132+
for (int valueIndex = 0; valueIndex < block.getValueCount(currentPos); valueIndex++) {
133+
assertThat(block.getFloat(block.getFirstValueIndex(currentPos) + valueIndex), equalTo(expectedEmbeddingValues[valueIndex]));
134+
}
135+
}
136+
}
137+
138+
private Page randomInputPage(int positionCount, int columnCount) {
139+
final Block[] blocks = new Block[columnCount];
140+
try {
141+
for (int i = 0; i < columnCount; i++) {
142+
blocks[i] = RandomBlock.randomBlock(
143+
blockFactory(),
144+
RandomBlock.randomElementType(),
145+
positionCount,
146+
randomBoolean(),
147+
0,
148+
0,
149+
randomInt(10),
150+
randomInt(10)
151+
).block();
152+
}
153+
154+
return new Page(blocks);
155+
} catch (Exception e) {
156+
Releasables.close(blocks);
157+
throw (e);
158+
}
159+
}
160+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.embedding;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.compute.data.BytesRefBlock;
12+
import org.elasticsearch.compute.test.ComputeTestCase;
13+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
14+
15+
import java.util.List;
16+
17+
import static org.hamcrest.Matchers.equalTo;
18+
19+
public class DenseEmbeddingOperatorRequestIteratorTests extends ComputeTestCase {
20+
21+
public void testIterateSmallInput() {
22+
assertIterate(between(1, 100), randomIntBetween(1, 1_000));
23+
}
24+
25+
public void testIterateLargeInput() {
26+
assertIterate(between(10_000, 100_000), randomIntBetween(1, 1_000));
27+
}
28+
29+
private void assertIterate(int size, int batchSize) {
30+
final String inferenceId = randomIdentifier();
31+
32+
try (
33+
BytesRefBlock inputBlock = randomInputBlock(size);
34+
DenseEmbeddingOperatorRequestIterator requestIterator = new DenseEmbeddingOperatorRequestIterator(
35+
inputBlock,
36+
inferenceId,
37+
batchSize
38+
)
39+
) {
40+
BytesRef scratch = new BytesRef();
41+
42+
for (int currentPos = 0; requestIterator.hasNext();) {
43+
InferenceAction.Request request = requestIterator.next();
44+
45+
assertThat(request.getInferenceEntityId(), equalTo(inferenceId));
46+
List<String> inputs = request.getInput();
47+
for (String input : inputs) {
48+
scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch);
49+
assertThat(input, equalTo(scratch.utf8ToString()));
50+
currentPos++;
51+
}
52+
}
53+
}
54+
}
55+
56+
private BytesRefBlock randomInputBlock(int size) {
57+
try (BytesRefBlock.Builder blockBuilder = blockFactory().newBytesRefBlockBuilder(size)) {
58+
for (int i = 0; i < size; i++) {
59+
blockBuilder.appendBytesRef(new BytesRef(randomAlphaOfLength(10)));
60+
}
61+
62+
return blockBuilder.build();
63+
}
64+
}
65+
}

0 commit comments

Comments
 (0)