Skip to content

Commit 678438b

Browse files
committed
Add text embedding output builder.
1 parent b8c5f11 commit 678438b

File tree

3 files changed

+351
-1
lines changed

3 files changed

+351
-1
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ public String toString() {
155155
return Strings.toString(this);
156156
}
157157

158-
float[] toFloatArray() {
158+
public float[] toFloatArray() {
159159
float[] floatArray = new float[values.length];
160160
for (int i = 0; i < values.length; i++) {
161161
floatArray[i] = ((Byte) values[i]).floatValue();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.textembedding;
9+
10+
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.FloatBlock;
12+
import org.elasticsearch.compute.data.Page;
13+
import org.elasticsearch.core.Releasables;
14+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
15+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
16+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
17+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
18+
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
19+
20+
/**
21+
* {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting
22+
* {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings.
23+
*/
24+
public class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
25+
private final Page inputPage;
26+
private final FloatBlock.Builder outputBlockBuilder;
27+
28+
public TextEmbeddingOperatorOutputBuilder(FloatBlock.Builder outputBlockBuilder, Page inputPage) {
29+
this.inputPage = inputPage;
30+
this.outputBlockBuilder = outputBlockBuilder;
31+
}
32+
33+
@Override
34+
public void close() {
35+
Releasables.close(outputBlockBuilder);
36+
}
37+
38+
/**
39+
* Adds an inference response to the output builder.
40+
*
41+
* <p>
42+
* If the response is null or not of type {@link TextEmbeddingResults} an {@link IllegalStateException} is thrown.
43+
* Else, the embedding vector is added to the output block as a multi-value position.
44+
* </p>
45+
*
46+
* <p>
47+
* The responses must be added in the same order as the corresponding inference requests were generated.
48+
* Failing to preserve order may lead to incorrect or misaligned output rows.
49+
* </p>
50+
*/
51+
@Override
52+
public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
53+
if (inferenceResponse == null) {
54+
outputBlockBuilder.appendNull();
55+
return;
56+
}
57+
58+
TextEmbeddingResults<?> embeddingResults = inferenceResults(inferenceResponse);
59+
60+
var embeddings = embeddingResults.embeddings();
61+
if (embeddings.isEmpty()) {
62+
outputBlockBuilder.appendNull();
63+
return;
64+
}
65+
66+
float[] embeddingArray = getEmbeddingAsFloatArray(embeddingResults);
67+
68+
outputBlockBuilder.beginPositionEntry();
69+
for (float component : embeddingArray) {
70+
outputBlockBuilder.appendFloat(component);
71+
}
72+
outputBlockBuilder.endPositionEntry();
73+
}
74+
75+
/**
76+
* Builds the final output page by appending the embedding output block to the input page.
77+
*/
78+
@Override
79+
public Page buildOutput() {
80+
Block outputBlock = outputBlockBuilder.build();
81+
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
82+
return inputPage.appendBlock(outputBlock);
83+
}
84+
85+
private TextEmbeddingResults<?> inferenceResults(InferenceAction.Response inferenceResponse) {
86+
return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class);
87+
}
88+
89+
/**
90+
* Extracts the embedding as a float array from the embedding result.
91+
*/
92+
private float[] getEmbeddingAsFloatArray(TextEmbeddingResults<?> embedding) {
93+
return switch (embedding.embeddings().get(0)) {
94+
case TextEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values();
95+
case TextEmbeddingByteResults.Embedding byteEmbedding -> byteEmbedding.toFloatArray();
96+
default -> throw new IllegalArgumentException(
97+
"Unsupported embedding type: "
98+
+ embedding.embeddings().get(0).getClass().getName()
99+
+ ". Expected TextEmbeddingFloatResults.Embedding or TextEmbeddingByteResults.Embedding."
100+
);
101+
};
102+
}
103+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.textembedding;
9+
10+
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.ElementType;
12+
import org.elasticsearch.compute.data.FloatBlock;
13+
import org.elasticsearch.compute.data.Page;
14+
import org.elasticsearch.compute.test.ComputeTestCase;
15+
import org.elasticsearch.compute.test.RandomBlock;
16+
import org.elasticsearch.core.Releasables;
17+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
18+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
19+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
20+
21+
import java.util.List;
22+
23+
import static org.hamcrest.Matchers.equalTo;
24+
25+
public class TextEmbeddingOperatorOutputBuilderTests extends ComputeTestCase {
26+
27+
public void testBuildSmallOutputWithFloatEmbeddings() throws Exception {
28+
assertBuildOutputWithFloatEmbeddings(between(1, 100));
29+
}
30+
31+
public void testBuildLargeOutputWithFloatEmbeddings() throws Exception {
32+
assertBuildOutputWithFloatEmbeddings(between(1_000, 10_000));
33+
}
34+
35+
public void testBuildSmallOutputWithByteEmbeddings() throws Exception {
36+
assertBuildOutputWithByteEmbeddings(between(1, 100));
37+
}
38+
39+
public void testBuildLargeOutputWithByteEmbeddings() throws Exception {
40+
assertBuildOutputWithByteEmbeddings(between(1_000, 10_000));
41+
}
42+
43+
public void testHandleNullResponses() throws Exception {
44+
final int size = between(10, 100);
45+
final Page inputPage = randomInputPage(size, between(1, 5));
46+
47+
try (
48+
TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder(
49+
blockFactory().newFloatBlockBuilder(size),
50+
inputPage
51+
)
52+
) {
53+
// Add some null responses
54+
for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) {
55+
if (randomBoolean()) {
56+
outputBuilder.addInferenceResponse(null);
57+
} else {
58+
float[] embedding = randomFloatEmbedding(randomIntBetween(50, 200));
59+
outputBuilder.addInferenceResponse(createFloatEmbeddingResponse(embedding));
60+
}
61+
}
62+
63+
final Page outputPage = outputBuilder.buildOutput();
64+
assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount()));
65+
assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1));
66+
67+
FloatBlock outputBlock = (FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1);
68+
assertThat(outputBlock.getPositionCount(), equalTo(size));
69+
70+
outputPage.releaseBlocks();
71+
}
72+
73+
allBreakersEmpty();
74+
}
75+
76+
public void testHandleEmptyEmbeddings() throws Exception {
77+
final int size = between(5, 50);
78+
final Page inputPage = randomInputPage(size, between(1, 3));
79+
80+
try (
81+
TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder(
82+
blockFactory().newFloatBlockBuilder(size),
83+
inputPage
84+
)
85+
) {
86+
// Add responses with empty embeddings
87+
for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) {
88+
outputBuilder.addInferenceResponse(createEmptyFloatEmbeddingResponse());
89+
}
90+
91+
final Page outputPage = outputBuilder.buildOutput();
92+
FloatBlock outputBlock = (FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1);
93+
94+
// All positions should be null due to empty embeddings
95+
for (int pos = 0; pos < outputBlock.getPositionCount(); pos++) {
96+
assertThat(outputBlock.isNull(pos), equalTo(true));
97+
}
98+
99+
outputPage.releaseBlocks();
100+
}
101+
102+
allBreakersEmpty();
103+
}
104+
105+
private void assertBuildOutputWithFloatEmbeddings(int size) throws Exception {
106+
final Page inputPage = randomInputPage(size, between(1, 10));
107+
final int embeddingDim = randomIntBetween(50, 1536); // Common embedding dimensions
108+
final float[][] expectedEmbeddings = new float[size][];
109+
110+
try (
111+
TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder(
112+
blockFactory().newFloatBlockBuilder(size),
113+
inputPage
114+
)
115+
) {
116+
for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) {
117+
float[] embedding = randomFloatEmbedding(embeddingDim);
118+
expectedEmbeddings[currentPos] = embedding;
119+
outputBuilder.addInferenceResponse(createFloatEmbeddingResponse(embedding));
120+
}
121+
122+
final Page outputPage = outputBuilder.buildOutput();
123+
assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount()));
124+
assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1));
125+
126+
assertFloatEmbeddingContent((FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1), expectedEmbeddings);
127+
128+
outputPage.releaseBlocks();
129+
}
130+
131+
allBreakersEmpty();
132+
}
133+
134+
private void assertBuildOutputWithByteEmbeddings(int size) throws Exception {
135+
final Page inputPage = randomInputPage(size, between(1, 10));
136+
final int embeddingDim = randomIntBetween(50, 1536);
137+
final byte[][] expectedByteEmbeddings = new byte[size][];
138+
139+
try (
140+
TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder(
141+
blockFactory().newFloatBlockBuilder(size),
142+
inputPage
143+
)
144+
) {
145+
for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) {
146+
byte[] embedding = randomByteEmbedding(embeddingDim);
147+
expectedByteEmbeddings[currentPos] = embedding;
148+
outputBuilder.addInferenceResponse(createByteEmbeddingResponse(embedding));
149+
}
150+
151+
final Page outputPage = outputBuilder.buildOutput();
152+
assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount()));
153+
assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1));
154+
155+
assertByteEmbeddingContent((FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1), expectedByteEmbeddings);
156+
157+
outputPage.releaseBlocks();
158+
}
159+
160+
allBreakersEmpty();
161+
}
162+
163+
private void assertFloatEmbeddingContent(FloatBlock block, float[][] expectedEmbeddings) {
164+
for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) {
165+
assertThat(block.isNull(currentPos), equalTo(false));
166+
assertThat(block.getValueCount(currentPos), equalTo(expectedEmbeddings[currentPos].length));
167+
168+
int firstValueIndex = block.getFirstValueIndex(currentPos);
169+
for (int i = 0; i < expectedEmbeddings[currentPos].length; i++) {
170+
float actualValue = block.getFloat(firstValueIndex + i);
171+
float expectedValue = expectedEmbeddings[currentPos][i];
172+
assertThat(actualValue, equalTo(expectedValue));
173+
}
174+
}
175+
}
176+
177+
private void assertByteEmbeddingContent(FloatBlock block, byte[][] expectedByteEmbeddings) {
178+
for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) {
179+
assertThat(block.isNull(currentPos), equalTo(false));
180+
assertThat(block.getValueCount(currentPos), equalTo(expectedByteEmbeddings[currentPos].length));
181+
182+
int firstValueIndex = block.getFirstValueIndex(currentPos);
183+
for (int i = 0; i < expectedByteEmbeddings[currentPos].length; i++) {
184+
float actualValue = block.getFloat(firstValueIndex + i);
185+
// Convert byte to float the same way as TextEmbeddingByteResults.Embedding.toFloatArray()
186+
float expectedValue = expectedByteEmbeddings[currentPos][i];
187+
assertThat(actualValue, equalTo(expectedValue));
188+
}
189+
}
190+
}
191+
192+
private float[] randomFloatEmbedding(int dimension) {
193+
float[] embedding = new float[dimension];
194+
for (int i = 0; i < dimension; i++) {
195+
embedding[i] = randomFloat();
196+
}
197+
return embedding;
198+
}
199+
200+
private byte[] randomByteEmbedding(int dimension) {
201+
byte[] embedding = new byte[dimension];
202+
for (int i = 0; i < dimension; i++) {
203+
embedding[i] = randomByte();
204+
}
205+
return embedding;
206+
}
207+
208+
private static InferenceAction.Response createFloatEmbeddingResponse(float[] embedding) {
209+
var embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding);
210+
var textEmbeddingResults = new TextEmbeddingFloatResults(List.of(embeddingResult));
211+
return new InferenceAction.Response(textEmbeddingResults);
212+
}
213+
214+
private static InferenceAction.Response createByteEmbeddingResponse(byte[] embedding) {
215+
var embeddingResult = new TextEmbeddingByteResults.Embedding(embedding);
216+
var textEmbeddingResults = new TextEmbeddingByteResults(List.of(embeddingResult));
217+
return new InferenceAction.Response(textEmbeddingResults);
218+
}
219+
220+
private static InferenceAction.Response createEmptyFloatEmbeddingResponse() {
221+
var textEmbeddingResults = new TextEmbeddingFloatResults(List.of());
222+
return new InferenceAction.Response(textEmbeddingResults);
223+
}
224+
225+
private Page randomInputPage(int positionCount, int columnCount) {
226+
final Block[] blocks = new Block[columnCount];
227+
try {
228+
for (int i = 0; i < columnCount; i++) {
229+
blocks[i] = RandomBlock.randomBlock(
230+
blockFactory(),
231+
RandomBlock.randomElementExcluding(List.of(ElementType.AGGREGATE_METRIC_DOUBLE)),
232+
positionCount,
233+
randomBoolean(),
234+
0,
235+
0,
236+
randomInt(10),
237+
randomInt(10)
238+
).block();
239+
}
240+
241+
return new Page(blocks);
242+
} catch (Exception e) {
243+
Releasables.close(blocks);
244+
throw (e);
245+
}
246+
}
247+
}

0 commit comments

Comments
 (0)