Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.inference;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

/**
* Helper class that reads text strings from a {@link BytesRefBlock}.
* This class is used by inference operators to extract text content from block data.
*/
public class InputTextReader implements Releasable {
private final BytesRefBlock textBlock;
private final StringBuilder strBuilder = new StringBuilder();
private BytesRef readBuffer = new BytesRef();

public InputTextReader(BytesRefBlock textBlock) {
this.textBlock = textBlock;
}

/**
* Reads the text string at the given position.
* Multiple values at the position are concatenated with newlines.
*
* @param pos the position index in the block
* @return the text string at the position, or null if the position contains a null value
*/
public String readText(int pos) {
return readText(pos, Integer.MAX_VALUE);
}

/**
* Reads the text string at the given position.
*
* @param pos the position index in the block
* @param limit the maximum number of value to read from the position
* @return the text string at the position, or null if the position contains a null value
*/
public String readText(int pos, int limit) {
if (textBlock.isNull(pos)) {
return null;
}

strBuilder.setLength(0);
int maxPos = Math.min(limit, textBlock.getValueCount(pos));
for (int valueIndex = 0; valueIndex < maxPos; valueIndex++) {
readBuffer = textBlock.getBytesRef(textBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
strBuilder.append(readBuffer.utf8ToString());
if (valueIndex != maxPos - 1) {
strBuilder.append("\n");
}
}

return strBuilder.toString();
}

/**
* Returns the total number of positions (text entries) in the block.
*/
public int estimatedSize() {
return textBlock.getPositionCount();
}

@Override
public void close() {
textBlock.allowPassingToDifferentDriver();
Releasables.close(textBlock);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

package org.elasticsearch.xpack.esql.inference.completion;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.esql.inference.InputTextReader;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;

import java.util.List;
Expand All @@ -24,7 +23,7 @@
*/
public class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator {

private final PromptReader promptReader;
private final InputTextReader textReader;
private final String inferenceId;
private final int size;
private int currentPos = 0;
Expand All @@ -36,7 +35,7 @@ public class CompletionOperatorRequestIterator implements BulkInferenceRequestIt
* @param inferenceId The ID of the inference model to invoke.
*/
public CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
this.promptReader = new PromptReader(promptBlock);
this.textReader = new InputTextReader(promptBlock);
this.size = promptBlock.getPositionCount();
this.inferenceId = inferenceId;
}
Expand All @@ -52,7 +51,7 @@ public InferenceAction.Request next() {
throw new NoSuchElementException();
}

return inferenceRequest(promptReader.readPrompt(currentPos++));
return inferenceRequest(textReader.readText(currentPos++));
}

/**
Expand All @@ -68,60 +67,11 @@ private InferenceAction.Request inferenceRequest(String prompt) {

@Override
public int estimatedSize() {
return promptReader.estimatedSize();
return textReader.estimatedSize();
}

@Override
public void close() {
Releasables.close(promptReader);
}

/**
* Helper class that reads prompts from a {@link BytesRefBlock}.
*/
private static class PromptReader implements Releasable {
private final BytesRefBlock promptBlock;
private final StringBuilder strBuilder = new StringBuilder();
private BytesRef readBuffer = new BytesRef();

private PromptReader(BytesRefBlock promptBlock) {
this.promptBlock = promptBlock;
}

/**
* Reads the prompt string at the given position..
*
* @param pos the position index in the block
*/
public String readPrompt(int pos) {
if (promptBlock.isNull(pos)) {
return null;
}

strBuilder.setLength(0);

for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
strBuilder.append(readBuffer.utf8ToString());
if (valueIndex != promptBlock.getValueCount(pos) - 1) {
strBuilder.append("\n");
}
}

return strBuilder.toString();
}

/**
* Returns the total number of positions (prompts) in the block.
*/
public int estimatedSize() {
return promptBlock.getPositionCount();
}

@Override
public void close() {
promptBlock.allowPassingToDifferentDriver();
Releasables.close(promptBlock);
}
Releasables.close(textReader);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.inference.textembedding;

import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
import org.elasticsearch.xpack.esql.inference.InferenceService;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;

/**
* {@link TextEmbeddingOperator} is an {@link InferenceOperator} that performs text embedding inference.
* It evaluates a text expression for each input row, constructs text embedding inference requests,
* and emits the dense vector embeddings as output.
*/
public class TextEmbeddingOperator extends InferenceOperator {

private final ExpressionEvaluator textEvaluator;

public TextEmbeddingOperator(
DriverContext driverContext,
BulkInferenceRunner bulkInferenceRunner,
String inferenceId,
ExpressionEvaluator textEvaluator,
int maxOutstandingPages
) {
super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages);
this.textEvaluator = textEvaluator;
}

@Override
protected void doClose() {
Releasables.close(textEvaluator);
}

@Override
public String toString() {
return "TextEmbeddingOperator[inference_id=[" + inferenceId() + "]]";
}

/**
* Constructs the text embedding inference requests iterator for the given input page by evaluating the text expression.
*
* @param inputPage The input data page.
*/
@Override
protected BulkInferenceRequestIterator requests(Page inputPage) {
return new TextEmbeddingOperatorRequestIterator((BytesRefBlock) textEvaluator.eval(inputPage), inferenceId());
}

/**
* Creates a new {@link TextEmbeddingOperatorOutputBuilder} to collect and emit the text embedding results.
*
* @param input The input page for which results will be constructed.
*/
@Override
protected TextEmbeddingOperatorOutputBuilder outputBuilder(Page input) {
FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount());
return new TextEmbeddingOperatorOutputBuilder(outputBlockBuilder, input);
}

/**
* Factory for creating {@link TextEmbeddingOperator} instances.
*/
public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory textEvaluatorFactory)
implements
OperatorFactory {
@Override
public String describe() {
return "TextEmbeddingOperator[inference_id=[" + inferenceId + "]]";
}

@Override
public Operator get(DriverContext driverContext) {
return new TextEmbeddingOperator(
driverContext,
inferenceService.bulkInferenceRunner(),
inferenceId,
textEvaluatorFactory.get(driverContext),
BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests()
);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.inference.textembedding;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;

/**
* {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting
* {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings.
*/
public class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
private final Page inputPage;
private final FloatBlock.Builder outputBlockBuilder;

public TextEmbeddingOperatorOutputBuilder(FloatBlock.Builder outputBlockBuilder, Page inputPage) {
this.inputPage = inputPage;
this.outputBlockBuilder = outputBlockBuilder;
}

@Override
public void close() {
Releasables.close(outputBlockBuilder);
}

/**
* Adds an inference response to the output builder.
*
* <p>
* If the response is null or not of type {@link TextEmbeddingResults} an {@link IllegalStateException} is thrown.
* Else, the embedding vector is added to the output block as a multi-value position.
* </p>
*
* <p>
* The responses must be added in the same order as the corresponding inference requests were generated.
* Failing to preserve order may lead to incorrect or misaligned output rows.
* </p>
*/
@Override
public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
if (inferenceResponse == null) {
outputBlockBuilder.appendNull();
return;
}

TextEmbeddingResults<?> embeddingResults = inferenceResults(inferenceResponse);

var embeddings = embeddingResults.embeddings();
if (embeddings.isEmpty()) {
outputBlockBuilder.appendNull();
return;
}

float[] embeddingArray = getEmbeddingAsFloatArray(embeddingResults);

outputBlockBuilder.beginPositionEntry();
for (float component : embeddingArray) {
outputBlockBuilder.appendFloat(component);
}
outputBlockBuilder.endPositionEntry();
}

/**
* Builds the final output page by appending the embedding output block to the input page.
*/
@Override
public Page buildOutput() {
Block outputBlock = outputBlockBuilder.build();
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
return inputPage.appendBlock(outputBlock);
}

private TextEmbeddingResults<?> inferenceResults(InferenceAction.Response inferenceResponse) {
return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class);
}

/**
* Extracts the embedding as a float array from the embedding result.
*/
private float[] getEmbeddingAsFloatArray(TextEmbeddingResults<?> embedding) {
return switch (embedding.embeddings().get(0)) {
case TextEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values();
case TextEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values());
default -> throw new IllegalArgumentException(
"Unsupported embedding type: "
+ embedding.embeddings().get(0).getClass().getName()
+ ". Expected TextEmbeddingFloatResults.Embedding or TextEmbeddingByteResults.Embedding."
);
};
}

private float[] toFloatArray(byte[] values) {
float[] floatArray = new float[values.length];
for (int i = 0; i < values.length; i++) {
floatArray[i] = ((Byte) values[i]).floatValue();
}
return floatArray;
}
}
Loading