Skip to content

Commit 05d46b6

Browse files
committed
Create inference operator for dense vector embedding.
1 parent 3153195 commit 05d46b6

File tree

3 files changed

+325
-0
lines changed

3 files changed

+325
-0
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.embedding;
9+
10+
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.FloatBlock;
12+
import org.elasticsearch.compute.data.Page;
13+
import org.elasticsearch.compute.operator.DriverContext;
14+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
15+
import org.elasticsearch.compute.operator.Operator;
16+
import org.elasticsearch.core.Releasables;
17+
import org.elasticsearch.threadpool.ThreadPool;
18+
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
19+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
20+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
21+
22+
import java.util.stream.IntStream;
23+
24+
/**
25+
* {@link DenseEmbeddingOperator} is an inference operator that compute vector embeddings from textual data .
26+
*/
27+
public class DenseEmbeddingOperator extends InferenceOperator {
28+
29+
// Default number of rows to include per inference request
30+
private static final int DEFAULT_BATCH_SIZE = 20;
31+
32+
// Encodes each input row into a string representation for the model
33+
private final ExpressionEvaluator inputEvaluator;
34+
35+
// Numbers of dimensions for the vector
36+
private final int dimensions;
37+
38+
// Batch size used to group rows into a single inference request (currently fixed)
39+
// TODO: make it configurable either in the command or as query pragmas
40+
private final int batchSize = DEFAULT_BATCH_SIZE;
41+
42+
public DenseEmbeddingOperator(
43+
DriverContext driverContext,
44+
InferenceRunner inferenceRunner,
45+
ThreadPool threadPool,
46+
int dimensions,
47+
String inferenceId,
48+
ExpressionEvaluator inputEvaluator
49+
) {
50+
super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId);
51+
this.dimensions = dimensions;
52+
this.inputEvaluator = inputEvaluator;
53+
}
54+
55+
@Override
56+
public void addInput(Page input) {
57+
try {
58+
Block inputBlock = inputEvaluator.eval(input);
59+
super.addInput(input.appendBlock(inputBlock));
60+
} catch (Exception e) {
61+
releasePageOnAnyThread(input);
62+
throw e;
63+
}
64+
}
65+
66+
@Override
67+
protected void doClose() {
68+
Releasables.close(inputEvaluator);
69+
}
70+
71+
@Override
72+
public String toString() {
73+
return "DenseEmbeddingOperator[inference_id=[" + inferenceId() + "]]";
74+
}
75+
76+
/**
77+
* Returns the request iterator responsible for batching and converting input rows into inference requests.
78+
*/
79+
@Override
80+
protected DenseEmbeddingRequestIterator requests(Page inputPage) {
81+
int inputBlockChannel = inputPage.getBlockCount() - 1;
82+
return new DenseEmbeddingRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), batchSize);
83+
}
84+
85+
/**
86+
* Returns the output builder responsible for collecting inference responses and building the output page.
87+
*/
88+
@Override
89+
protected DenseEmbeddingOperatorOutputBuilder outputBuilder(Page input) {
90+
FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount() * dimensions);
91+
return new DenseEmbeddingOperatorOutputBuilder(
92+
outputBlockBuilder,
93+
input.projectBlocks(IntStream.range(0, input.getBlockCount() - 1).toArray()),
94+
dimensions
95+
);
96+
}
97+
98+
/**
99+
* Factory for creating {@link DenseEmbeddingOperator} instances
100+
*/
101+
public record Factory(InferenceRunner inferenceRunner, int dimensions, String inferenceId,
102+
ExpressionEvaluator.Factory inputEvaluatorFactory) implements OperatorFactory {
103+
104+
@Override
105+
public String describe() {
106+
return "DenseEmbeddingOperator[inference_id=[" + inferenceId + "]]";
107+
}
108+
109+
@Override
110+
public Operator get(DriverContext driverContext) {
111+
return new DenseEmbeddingOperator(
112+
driverContext,
113+
inferenceRunner,
114+
inferenceRunner.threadPool(),
115+
dimensions,
116+
inferenceId,
117+
inputEvaluatorFactory().get(driverContext)
118+
);
119+
}
120+
}
121+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.embedding;
9+
10+
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.FloatBlock;
12+
import org.elasticsearch.compute.data.Page;
13+
import org.elasticsearch.core.Releasables;
14+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
15+
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
16+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
17+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
18+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
19+
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
20+
21+
import java.util.Iterator;
22+
import java.util.stream.IntStream;
23+
24+
/**
25+
* Builds the output page for the {@link DenseEmbeddingOperator}.
26+
*/
27+
public class DenseEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
28+
29+
30+
private final FloatBlock.Builder outputBlockBuilder;
31+
private final Page inputPage;
32+
private final int dimensions;
33+
34+
public DenseEmbeddingOperatorOutputBuilder(FloatBlock.Builder outputBlockBuilder, Page inputPage, int dimensions) {
35+
this.outputBlockBuilder = outputBlockBuilder;
36+
this.inputPage = inputPage;
37+
this.dimensions = dimensions;
38+
}
39+
40+
@Override
41+
public void close() {
42+
Releasables.close(outputBlockBuilder);
43+
releasePageOnAnyThread(inputPage);
44+
}
45+
46+
/**
47+
* Constructs a new output {@link Page} with dense embedding in the last column.
48+
*/
49+
@Override
50+
public Page buildOutput() {
51+
Block outputBlock = outputBlockBuilder.build();
52+
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
53+
return inputPage.shallowCopy().appendBlock(outputBlock);
54+
}
55+
56+
/**
57+
* Extracts the embedding results from the inference response and append them to the output block builder.
58+
* <p>
59+
* If the response is not of type {@link TextEmbeddingResults} an {@link IllegalStateException} is thrown.
60+
* </p>
61+
* <p>
62+
* The responses must be added in the same order as the corresponding inference requests were generated.
63+
* Failing to preserve order may lead to incorrect or misaligned output rows.
64+
* </p>
65+
*/
66+
@Override
67+
public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
68+
EmbeddingValueReader embeddingValueReader = EmbeddingValueReader.of(inferenceResponse, dimensions);
69+
while (embeddingValueReader.hasNext()) {
70+
writeEmbeddings(embeddingValueReader.next());
71+
}
72+
}
73+
74+
private void writeEmbeddings(float[] values) {
75+
outputBlockBuilder.beginPositionEntry();
76+
for (float value : values) {
77+
outputBlockBuilder.appendFloat(value);
78+
}
79+
outputBlockBuilder.endPositionEntry();
80+
}
81+
82+
private static class EmbeddingValueReader implements Iterator<float[]> {
83+
private final int dimensions;
84+
85+
private final Iterator<? extends EmbeddingResults.Embedding<?>> embeddingsIterator;
86+
87+
private EmbeddingValueReader(Iterator<? extends EmbeddingResults.Embedding<?>> embeddingsIterator, int dimensions) {
88+
this.dimensions = dimensions;
89+
this.embeddingsIterator = embeddingsIterator;
90+
}
91+
92+
public boolean hasNext() {
93+
return embeddingsIterator.hasNext();
94+
}
95+
96+
public float[] next() {
97+
EmbeddingResults.Embedding<?> embedding = embeddingsIterator.next();
98+
float[] values = switch(embedding) {
99+
case TextEmbeddingFloatResults.Embedding textEmbeddingFloat -> textEmbeddingFloat.values();
100+
case TextEmbeddingByteResults.Embedding textEmbeddingBytes -> toFloatArray(textEmbeddingBytes.values());
101+
default -> throw new IllegalStateException("Unsupported embedding type [" + embedding.getClass() + "]");
102+
};
103+
104+
assert values.length == dimensions : "Unexpected vector size: " + values.length ;
105+
106+
return values;
107+
}
108+
109+
private static float[] toFloatArray(byte[] bytes) {
110+
float[] floatValues = new float[bytes.length];
111+
IntStream.range(0, floatValues.length).forEach(i -> floatValues[i] = ((Byte) bytes[i]).floatValue());
112+
return floatValues;
113+
}
114+
115+
116+
public static EmbeddingValueReader of(InferenceAction.Response inferenceResponse, int dimensions) {
117+
TextEmbeddingResults<?> inferenceResults = InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class);
118+
return new EmbeddingValueReader(inferenceResults.embeddings().iterator(), dimensions);
119+
}
120+
}
121+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.embedding;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.common.lucene.BytesRefs;
12+
import org.elasticsearch.compute.data.BytesRefBlock;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
15+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
16+
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
import java.util.NoSuchElementException;
20+
21+
/**
22+
* Iterator over input data blocks to create batched inference requests for the dense vector text embedding task.
23+
*
24+
* <p>This iterator reads from a {@link BytesRefBlock} containing text to be embedded. It slices the input into batches
25+
* of configurable size and converts each batch into an {@link InferenceAction.Request} with the task type {@link TaskType#TEXT_EMBEDDING}.
26+
*/
27+
public class DenseEmbeddingRequestIterator implements BulkInferenceRequestIterator {
28+
private final BytesRefBlock inputBlock;
29+
private final String inferenceId;
30+
private final int batchSize;
31+
private int remainingPositions;
32+
33+
public DenseEmbeddingRequestIterator(BytesRefBlock inputBlock, String inferenceId, int batchSize) {
34+
this.inputBlock = inputBlock;
35+
this.inferenceId = inferenceId;
36+
this.batchSize = batchSize;
37+
this.remainingPositions = inputBlock.getPositionCount();
38+
}
39+
40+
@Override
41+
public boolean hasNext() {
42+
return remainingPositions > 0;
43+
}
44+
45+
@Override
46+
public InferenceAction.Request next() {
47+
if (hasNext() == false) {
48+
throw new NoSuchElementException();
49+
}
50+
51+
final int inputSize = Math.min(remainingPositions, batchSize);
52+
final List<String> inputs = new ArrayList<>(inputSize);
53+
BytesRef scratch = new BytesRef();
54+
55+
int startIndex = inputBlock.getPositionCount() - remainingPositions;
56+
for (int i = 0; i < inputSize; i++) {
57+
int pos = startIndex + i;
58+
if (inputBlock.isNull(pos)) {
59+
inputs.add("");
60+
} else {
61+
scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(pos), scratch);
62+
inputs.add(BytesRefs.toString(scratch));
63+
}
64+
}
65+
66+
remainingPositions -= inputSize;
67+
return inferenceRequest(inputs);
68+
}
69+
70+
@Override
71+
public int estimatedSize() {
72+
return inputBlock.getPositionCount();
73+
}
74+
75+
private InferenceAction.Request inferenceRequest(List<String> inputs) {
76+
return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(inputs).build();
77+
}
78+
79+
@Override
80+
public void close() {
81+
82+
}
83+
}

0 commit comments

Comments
 (0)