Skip to content

Commit 433ab3a

Browse files
authored
[ES|QL] Text embedding inference operator (#135062)
* Move CompletionOperatorRequestIterator.PromptReader to InputTextReader so it can be reused. * Implementing the TextEmbeddingInferenceOperator. * Inference operator request iterator and output builder are package private now. * Fix comment
1 parent 0363570 commit 433ab3a

File tree

12 files changed

+1160
-64
lines changed

12 files changed

+1160
-64
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.compute.data.BytesRefBlock;
12+
import org.elasticsearch.core.Releasable;
13+
import org.elasticsearch.core.Releasables;
14+
15+
/**
16+
* Helper class that reads text strings from a {@link BytesRefBlock}.
17+
* This class is used by inference operators to extract text content from block data.
18+
*/
19+
public class InputTextReader implements Releasable {
20+
private final BytesRefBlock textBlock;
21+
private final StringBuilder strBuilder = new StringBuilder();
22+
private BytesRef readBuffer = new BytesRef();
23+
24+
public InputTextReader(BytesRefBlock textBlock) {
25+
this.textBlock = textBlock;
26+
}
27+
28+
/**
29+
* Reads the text string at the given position.
30+
* Multiple values at the position are concatenated with newlines.
31+
*
32+
* @param pos the position index in the block
33+
* @return the text string at the position, or null if the position contains a null value
34+
*/
35+
public String readText(int pos) {
36+
return readText(pos, Integer.MAX_VALUE);
37+
}
38+
39+
/**
40+
* Reads the text string at the given position.
41+
*
42+
* @param pos the position index in the block
43+
* @param limit the maximum number of value to read from the position
44+
* @return the text string at the position, or null if the position contains a null value
45+
*/
46+
public String readText(int pos, int limit) {
47+
if (textBlock.isNull(pos)) {
48+
return null;
49+
}
50+
51+
strBuilder.setLength(0);
52+
int maxPos = Math.min(limit, textBlock.getValueCount(pos));
53+
for (int valueIndex = 0; valueIndex < maxPos; valueIndex++) {
54+
readBuffer = textBlock.getBytesRef(textBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
55+
strBuilder.append(readBuffer.utf8ToString());
56+
if (valueIndex != maxPos - 1) {
57+
strBuilder.append("\n");
58+
}
59+
}
60+
61+
return strBuilder.toString();
62+
}
63+
64+
/**
65+
* Returns the total number of positions (text entries) in the block.
66+
*/
67+
public int estimatedSize() {
68+
return textBlock.getPositionCount();
69+
}
70+
71+
@Override
72+
public void close() {
73+
textBlock.allowPassingToDifferentDriver();
74+
Releasables.close(textBlock);
75+
}
76+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
* {@link CompletionOperatorOutputBuilder} builds the output page for {@link CompletionOperator} by converting {@link ChatCompletionResults}
2121
* into a {@link BytesRefBlock}.
2222
*/
23-
public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
23+
class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
2424
private final Page inputPage;
2525
private final BytesRefBlock.Builder outputBlockBuilder;
2626
private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
2727

28-
public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder, Page inputPage) {
28+
CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder, Page inputPage) {
2929
this.inputPage = inputPage;
3030
this.outputBlockBuilder = outputBlockBuilder;
3131
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java

Lines changed: 8 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77

88
package org.elasticsearch.xpack.esql.inference.completion;
99

10-
import org.apache.lucene.util.BytesRef;
1110
import org.elasticsearch.compute.data.BytesRefBlock;
12-
import org.elasticsearch.core.Releasable;
1311
import org.elasticsearch.core.Releasables;
1412
import org.elasticsearch.inference.TaskType;
1513
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
14+
import org.elasticsearch.xpack.esql.inference.InputTextReader;
1615
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
1716

1817
import java.util.List;
@@ -22,9 +21,9 @@
2221
* This iterator reads prompts from a {@link BytesRefBlock} and converts them into individual {@link InferenceAction.Request} instances
2322
* of type {@link TaskType#COMPLETION}.
2423
*/
25-
public class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator {
24+
class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator {
2625

27-
private final PromptReader promptReader;
26+
private final InputTextReader textReader;
2827
private final String inferenceId;
2928
private final int size;
3029
private int currentPos = 0;
@@ -35,8 +34,8 @@ public class CompletionOperatorRequestIterator implements BulkInferenceRequestIt
3534
* @param promptBlock The input block containing prompts.
3635
* @param inferenceId The ID of the inference model to invoke.
3736
*/
38-
public CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
39-
this.promptReader = new PromptReader(promptBlock);
37+
CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
38+
this.textReader = new InputTextReader(promptBlock);
4039
this.size = promptBlock.getPositionCount();
4140
this.inferenceId = inferenceId;
4241
}
@@ -52,7 +51,7 @@ public InferenceAction.Request next() {
5251
throw new NoSuchElementException();
5352
}
5453

55-
return inferenceRequest(promptReader.readPrompt(currentPos++));
54+
return inferenceRequest(textReader.readText(currentPos++));
5655
}
5756

5857
/**
@@ -68,60 +67,11 @@ private InferenceAction.Request inferenceRequest(String prompt) {
6867

6968
@Override
7069
public int estimatedSize() {
71-
return promptReader.estimatedSize();
70+
return textReader.estimatedSize();
7271
}
7372

7473
@Override
7574
public void close() {
76-
Releasables.close(promptReader);
77-
}
78-
79-
/**
80-
* Helper class that reads prompts from a {@link BytesRefBlock}.
81-
*/
82-
private static class PromptReader implements Releasable {
83-
private final BytesRefBlock promptBlock;
84-
private final StringBuilder strBuilder = new StringBuilder();
85-
private BytesRef readBuffer = new BytesRef();
86-
87-
private PromptReader(BytesRefBlock promptBlock) {
88-
this.promptBlock = promptBlock;
89-
}
90-
91-
/**
92-
* Reads the prompt string at the given position..
93-
*
94-
* @param pos the position index in the block
95-
*/
96-
public String readPrompt(int pos) {
97-
if (promptBlock.isNull(pos)) {
98-
return null;
99-
}
100-
101-
strBuilder.setLength(0);
102-
103-
for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
104-
readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
105-
strBuilder.append(readBuffer.utf8ToString());
106-
if (valueIndex != promptBlock.getValueCount(pos) - 1) {
107-
strBuilder.append("\n");
108-
}
109-
}
110-
111-
return strBuilder.toString();
112-
}
113-
114-
/**
115-
* Returns the total number of positions (prompts) in the block.
116-
*/
117-
public int estimatedSize() {
118-
return promptBlock.getPositionCount();
119-
}
120-
121-
@Override
122-
public void close() {
123-
promptBlock.allowPassingToDifferentDriver();
124-
Releasables.close(promptBlock);
125-
}
75+
Releasables.close(textReader);
12676
}
12777
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
* * reranked relevance scores into the specified score channel of the input page.
2525
*/
2626

27-
public class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
27+
class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
2828

2929
private final Page inputPage;
3030
private final DoubleBlock.Builder scoreBlockBuilder;
3131
private final int scoreChannel;
3232

33-
public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page inputPage, int scoreChannel) {
33+
RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page inputPage, int scoreChannel) {
3434
this.inputPage = inputPage;
3535
this.scoreBlockBuilder = scoreBlockBuilder;
3636
this.scoreChannel = scoreChannel;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525
* <p>This iterator reads from a {@link BytesRefBlock} containing input documents or items to be reranked. It slices the input into batches
2626
* of configurable size and converts each batch into an {@link InferenceAction.Request} with the task type {@link TaskType#RERANK}.
2727
*/
28-
public class RerankOperatorRequestIterator implements BulkInferenceRequestIterator {
28+
class RerankOperatorRequestIterator implements BulkInferenceRequestIterator {
2929
private final BytesRefBlock inputBlock;
3030
private final String inferenceId;
3131
private final String queryText;
3232
private final int batchSize;
3333
private int remainingPositions;
3434

35-
public RerankOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, String queryText, int batchSize) {
35+
RerankOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, String queryText, int batchSize) {
3636
this.inputBlock = inputBlock;
3737
this.inferenceId = inferenceId;
3838
this.queryText = queryText;
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.textembedding;
9+
10+
import org.elasticsearch.compute.data.BytesRefBlock;
11+
import org.elasticsearch.compute.data.FloatBlock;
12+
import org.elasticsearch.compute.data.Page;
13+
import org.elasticsearch.compute.operator.DriverContext;
14+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
15+
import org.elasticsearch.compute.operator.Operator;
16+
import org.elasticsearch.core.Releasables;
17+
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
18+
import org.elasticsearch.xpack.esql.inference.InferenceService;
19+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
20+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
21+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
22+
23+
/**
24+
* {@link TextEmbeddingOperator} is an {@link InferenceOperator} that performs text embedding inference.
25+
* It evaluates a text expression for each input row, constructs text embedding inference requests,
26+
* and emits the dense vector embeddings as output.
27+
*/
28+
public class TextEmbeddingOperator extends InferenceOperator {
29+
30+
private final ExpressionEvaluator textEvaluator;
31+
32+
public TextEmbeddingOperator(
33+
DriverContext driverContext,
34+
BulkInferenceRunner bulkInferenceRunner,
35+
String inferenceId,
36+
ExpressionEvaluator textEvaluator,
37+
int maxOutstandingPages
38+
) {
39+
super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages);
40+
this.textEvaluator = textEvaluator;
41+
}
42+
43+
@Override
44+
protected void doClose() {
45+
Releasables.close(textEvaluator);
46+
}
47+
48+
@Override
49+
public String toString() {
50+
return "TextEmbeddingOperator[inference_id=[" + inferenceId() + "]]";
51+
}
52+
53+
/**
54+
* Constructs the text embedding inference requests iterator for the given input page by evaluating the text expression.
55+
*
56+
* @param inputPage The input data page.
57+
*/
58+
@Override
59+
protected BulkInferenceRequestIterator requests(Page inputPage) {
60+
return new TextEmbeddingOperatorRequestIterator((BytesRefBlock) textEvaluator.eval(inputPage), inferenceId());
61+
}
62+
63+
/**
64+
* Creates a new {@link TextEmbeddingOperatorOutputBuilder} to collect and emit the text embedding results.
65+
*
66+
* @param input The input page for which results will be constructed.
67+
*/
68+
@Override
69+
protected TextEmbeddingOperatorOutputBuilder outputBuilder(Page input) {
70+
FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount());
71+
return new TextEmbeddingOperatorOutputBuilder(outputBlockBuilder, input);
72+
}
73+
74+
/**
75+
* Factory for creating {@link TextEmbeddingOperator} instances.
76+
*/
77+
public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory textEvaluatorFactory)
78+
implements
79+
OperatorFactory {
80+
@Override
81+
public String describe() {
82+
return "TextEmbeddingOperator[inference_id=[" + inferenceId + "]]";
83+
}
84+
85+
@Override
86+
public Operator get(DriverContext driverContext) {
87+
return new TextEmbeddingOperator(
88+
driverContext,
89+
inferenceService.bulkInferenceRunner(),
90+
inferenceId,
91+
textEvaluatorFactory.get(driverContext),
92+
BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests()
93+
);
94+
}
95+
}
96+
}

0 commit comments

Comments
 (0)