Skip to content

Commit 2068e91

Browse files
committed
Text embedding inference operator.
1 parent 678438b commit 2068e91

File tree

2 files changed

+233
-0
lines changed

2 files changed

+233
-0
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.textembedding;
9+
10+
import org.elasticsearch.compute.data.BytesRefBlock;
11+
import org.elasticsearch.compute.data.FloatBlock;
12+
import org.elasticsearch.compute.data.Page;
13+
import org.elasticsearch.compute.operator.DriverContext;
14+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
15+
import org.elasticsearch.compute.operator.Operator;
16+
import org.elasticsearch.core.Releasables;
17+
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
18+
import org.elasticsearch.xpack.esql.inference.InferenceService;
19+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
20+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
21+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
22+
23+
/**
24+
* {@link TextEmbeddingOperator} is an {@link InferenceOperator} that performs text embedding inference.
25+
* It evaluates a text expression for each input row, constructs text embedding inference requests,
26+
* and emits the dense vector embeddings as output.
27+
*/
28+
public class TextEmbeddingOperator extends InferenceOperator {
29+
30+
private final ExpressionEvaluator textEvaluator;
31+
32+
public TextEmbeddingOperator(
33+
DriverContext driverContext,
34+
BulkInferenceRunner bulkInferenceRunner,
35+
String inferenceId,
36+
ExpressionEvaluator textEvaluator,
37+
int maxOutstandingPages
38+
) {
39+
super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages);
40+
this.textEvaluator = textEvaluator;
41+
}
42+
43+
@Override
44+
protected void doClose() {
45+
Releasables.close(textEvaluator);
46+
}
47+
48+
@Override
49+
public String toString() {
50+
return "TextEmbeddingOperator[inference_id=[" + inferenceId() + "]]";
51+
}
52+
53+
/**
54+
* Constructs the text embedding inference requests iterator for the given input page by evaluating the text expression.
55+
*
56+
* @param inputPage The input data page.
57+
*/
58+
@Override
59+
protected BulkInferenceRequestIterator requests(Page inputPage) {
60+
return new TextEmbeddingOperatorRequestIterator((BytesRefBlock) textEvaluator.eval(inputPage), inferenceId());
61+
}
62+
63+
/**
64+
* Creates a new {@link TextEmbeddingOperatorOutputBuilder} to collect and emit the text embedding results.
65+
*
66+
* @param input The input page for which results will be constructed.
67+
*/
68+
@Override
69+
protected TextEmbeddingOperatorOutputBuilder outputBuilder(Page input) {
70+
FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount());
71+
return new TextEmbeddingOperatorOutputBuilder(outputBlockBuilder, input);
72+
}
73+
74+
/**
75+
* Factory for creating {@link TextEmbeddingOperator} instances.
76+
*/
77+
public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory textEvaluatorFactory)
78+
implements
79+
OperatorFactory {
80+
@Override
81+
public String describe() {
82+
return "TextEmbeddingOperator[inference_id=[" + inferenceId + "]]";
83+
}
84+
85+
@Override
86+
public Operator get(DriverContext driverContext) {
87+
return new TextEmbeddingOperator(
88+
driverContext,
89+
inferenceService.bulkInferenceRunner(),
90+
inferenceId,
91+
textEvaluatorFactory.get(driverContext),
92+
BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests()
93+
);
94+
}
95+
}
96+
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.textembedding;
9+
10+
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.BytesRefBlock;
12+
import org.elasticsearch.compute.data.FloatBlock;
13+
import org.elasticsearch.compute.data.Page;
14+
import org.elasticsearch.compute.operator.Operator;
15+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
16+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
17+
import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase;
18+
import org.hamcrest.Matcher;
19+
import org.junit.Before;
20+
21+
import java.util.List;
22+
23+
import static org.hamcrest.Matchers.equalTo;
24+
import static org.hamcrest.Matchers.hasSize;
25+
26+
public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase<TextEmbeddingFloatResults> {
27+
private static final String SIMPLE_INFERENCE_ID = "test_text_embedding";
28+
private static final int EMBEDDING_DIMENSION = 384; // Common embedding dimension
29+
30+
private int inputChannel;
31+
32+
@Before
33+
public void initTextEmbeddingChannels() {
34+
inputChannel = between(0, inputsCount - 1);
35+
}
36+
37+
@Override
38+
protected Operator.OperatorFactory simple(SimpleOptions options) {
39+
return new TextEmbeddingOperator.Factory(mockedInferenceService(), SIMPLE_INFERENCE_ID, evaluatorFactory(inputChannel));
40+
}
41+
42+
@Override
43+
protected void assertSimpleOutput(List<Page> input, List<Page> results) {
44+
assertThat(results, hasSize(input.size()));
45+
46+
for (int curPage = 0; curPage < input.size(); curPage++) {
47+
Page inputPage = input.get(curPage);
48+
Page resultPage = results.get(curPage);
49+
50+
assertEquals(inputPage.getPositionCount(), resultPage.getPositionCount());
51+
assertEquals(inputPage.getBlockCount() + 1, resultPage.getBlockCount());
52+
53+
for (int channel = 0; channel < inputPage.getBlockCount(); channel++) {
54+
Block inputBlock = inputPage.getBlock(channel);
55+
Block resultBlock = resultPage.getBlock(channel);
56+
assertBlockContentEquals(inputBlock, resultBlock);
57+
}
58+
59+
assertTextEmbeddingResults(inputPage, resultPage);
60+
}
61+
}
62+
63+
private void assertTextEmbeddingResults(Page inputPage, Page resultPage) {
64+
BytesRefBlock inputBlock = resultPage.getBlock(inputChannel);
65+
FloatBlock resultBlock = (FloatBlock) resultPage.getBlock(inputPage.getBlockCount());
66+
67+
BlockStringReader blockReader = new InferenceOperatorTestCase.BlockStringReader();
68+
69+
for (int curPos = 0; curPos < inputPage.getPositionCount(); curPos++) {
70+
if (inputBlock.isNull(curPos)) {
71+
assertThat(resultBlock.isNull(curPos), equalTo(true));
72+
} else {
73+
// Verify that we have an embedding vector at this position
74+
assertThat(resultBlock.isNull(curPos), equalTo(false));
75+
assertThat(resultBlock.getValueCount(curPos), equalTo(EMBEDDING_DIMENSION));
76+
77+
// Get the input text to verify our mock embedding generation
78+
String inputText = blockReader.readString(inputBlock, curPos);
79+
80+
// Verify the embedding values match our mock generation pattern
81+
int firstValueIndex = resultBlock.getFirstValueIndex(curPos);
82+
for (int i = 0; i < EMBEDDING_DIMENSION; i++) {
83+
float expectedValue = generateMockEmbeddingValue(inputText, i);
84+
float actualValue = resultBlock.getFloat(firstValueIndex + i);
85+
assertThat(actualValue, equalTo(expectedValue));
86+
}
87+
}
88+
}
89+
}
90+
91+
@Override
92+
protected TextEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) {
93+
// For text embedding, we expect one input text per request
94+
String inputText = request.getInput().get(0);
95+
96+
// Generate a deterministic mock embedding based on the input text
97+
float[] mockEmbedding = generateMockEmbedding(inputText, EMBEDDING_DIMENSION);
98+
99+
var embeddingResult = new TextEmbeddingFloatResults.Embedding(mockEmbedding);
100+
return new TextEmbeddingFloatResults(List.of(embeddingResult));
101+
}
102+
103+
@Override
104+
protected Matcher<String> expectedDescriptionOfSimple() {
105+
return expectedToStringOfSimple();
106+
}
107+
108+
@Override
109+
protected Matcher<String> expectedToStringOfSimple() {
110+
return equalTo("TextEmbeddingOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "]]");
111+
}
112+
113+
/**
114+
* Generates a deterministic mock embedding vector based on the input text.
115+
* This ensures our tests are repeatable and verifiable.
116+
*/
117+
private float[] generateMockEmbedding(String inputText, int dimension) {
118+
float[] embedding = new float[dimension];
119+
int textHash = inputText.hashCode();
120+
121+
for (int i = 0; i < dimension; i++) {
122+
embedding[i] = generateMockEmbeddingValue(inputText, i);
123+
}
124+
125+
return embedding;
126+
}
127+
128+
/**
129+
* Generates a single embedding value for a specific dimension based on input text.
130+
* Uses a deterministic function so tests are repeatable.
131+
*/
132+
private float generateMockEmbeddingValue(String inputText, int dimension) {
133+
// Create a deterministic value based on input text and dimension
134+
int hash = (inputText.hashCode() + dimension * 31) % 10000;
135+
return hash / 10000.0f; // Normalize to [0, 1) range
136+
}
137+
}

0 commit comments

Comments
 (0)