From 05d46b61c85e0f1bb4b8fda9a6e3b991ab79de73 Mon Sep 17 00:00:00 2001 From: afoucret Date: Mon, 30 Jun 2025 14:57:16 +0200 Subject: [PATCH 01/12] Create inference operator for dense vector embedding. --- .../embedding/DenseEmbeddingOperator.java | 121 ++++++++++++++++++ .../DenseEmbeddingOperatorOutputBuilder.java | 121 ++++++++++++++++++ .../DenseEmbeddingRequestIterator.java | 83 ++++++++++++ 3 files changed, 325 insertions(+) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingRequestIterator.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java new file mode 100644 index 0000000000000..e0575c8cd34fd --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.embedding; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; + +import java.util.stream.IntStream; + +/** + * {@link DenseEmbeddingOperator} is an inference operator that computes vector embeddings from textual data. + */ +public class DenseEmbeddingOperator extends InferenceOperator { + + // Default number of rows to include per inference request + private static final int DEFAULT_BATCH_SIZE = 20; + + // Encodes each input row into a string representation for the model + private final ExpressionEvaluator inputEvaluator; + + // Number of dimensions for the vector + private final int dimensions; + + // Batch size used to group rows into a single inference request (currently fixed) + // TODO: make it configurable either in the command or as query pragmas + private final int batchSize = DEFAULT_BATCH_SIZE; + + public DenseEmbeddingOperator( + DriverContext driverContext, + InferenceRunner inferenceRunner, + ThreadPool threadPool, + int dimensions, + String inferenceId, + ExpressionEvaluator inputEvaluator + ) { + super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId); + this.dimensions = dimensions; + this.inputEvaluator = inputEvaluator; + } + + @Override + public void addInput(Page input) { + try { + Block inputBlock = inputEvaluator.eval(input); 
super.addInput(input.appendBlock(inputBlock)); + } catch (Exception e) { + releasePageOnAnyThread(input); + throw e; + } + } + + @Override + protected void doClose() { + Releasables.close(inputEvaluator); + } + + @Override + public String toString() { + return "DenseEmbeddingOperator[inference_id=[" + inferenceId() + "]]"; + } + + /** + * Returns the request iterator responsible for batching and converting input rows into inference requests. + */ + @Override + protected DenseEmbeddingRequestIterator requests(Page inputPage) { + int inputBlockChannel = inputPage.getBlockCount() - 1; + return new DenseEmbeddingRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), batchSize); + } + + /** + * Returns the output builder responsible for collecting inference responses and building the output page. + */ + @Override + protected DenseEmbeddingOperatorOutputBuilder outputBuilder(Page input) { + FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount() * dimensions); + return new DenseEmbeddingOperatorOutputBuilder( + outputBlockBuilder, + input.projectBlocks(IntStream.range(0, input.getBlockCount() - 1).toArray()), + dimensions + ); + } + + /** + * Factory for creating {@link DenseEmbeddingOperator} instances + */ + public record Factory(InferenceRunner inferenceRunner, int dimensions, String inferenceId, + ExpressionEvaluator.Factory inputEvaluatorFactory) implements OperatorFactory { + + @Override + public String describe() { + return "DenseEmbeddingOperator[inference_id=[" + inferenceId + "]]"; + } + + @Override + public Operator get(DriverContext driverContext) { + return new DenseEmbeddingOperator( + driverContext, + inferenceRunner, + inferenceRunner.threadPool(), + dimensions, + inferenceId, + inputEvaluatorFactory().get(driverContext) + ); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java new file mode 100644 index 0000000000000..deee6d34c2448 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.embedding; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; + +import java.util.Iterator; +import java.util.stream.IntStream; + +/** + * Builds the output page for the {@link DenseEmbeddingOperator}. 
+ */ +public class DenseEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { + + + private final FloatBlock.Builder outputBlockBuilder; + private final Page inputPage; + private final int dimensions; + + public DenseEmbeddingOperatorOutputBuilder(FloatBlock.Builder outputBlockBuilder, Page inputPage, int dimensions) { + this.outputBlockBuilder = outputBlockBuilder; + this.inputPage = inputPage; + this.dimensions = dimensions; + } + + @Override + public void close() { + Releasables.close(outputBlockBuilder); + releasePageOnAnyThread(inputPage); + } + + /** + * Constructs a new output {@link Page} with the dense embeddings in the last column. + */ + @Override + public Page buildOutput() { + Block outputBlock = outputBlockBuilder.build(); + assert outputBlock.getPositionCount() == inputPage.getPositionCount(); + return inputPage.shallowCopy().appendBlock(outputBlock); + } + + /** + * Extracts the embedding results from the inference response and appends them to the output block builder. + *

+ * If the response is not of type {@link TextEmbeddingResults}, an {@link IllegalStateException} is thrown. + *

+ *

+ * The responses must be added in the same order as the corresponding inference requests were generated. + * Failing to preserve order may lead to incorrect or misaligned output rows. + *

+ */ + @Override + public void addInferenceResponse(InferenceAction.Response inferenceResponse) { + EmbeddingValueReader embeddingValueReader = EmbeddingValueReader.of(inferenceResponse, dimensions); + while (embeddingValueReader.hasNext()) { + writeEmbeddings(embeddingValueReader.next()); + } + } + + private void writeEmbeddings(float[] values) { + outputBlockBuilder.beginPositionEntry(); + for (float value : values) { + outputBlockBuilder.appendFloat(value); + } + outputBlockBuilder.endPositionEntry(); + } + + private static class EmbeddingValueReader implements Iterator { + private final int dimensions; + + private final Iterator> embeddingsIterator; + + private EmbeddingValueReader(Iterator> embeddingsIterator, int dimensions) { + this.dimensions = dimensions; + this.embeddingsIterator = embeddingsIterator; + } + + public boolean hasNext() { + return embeddingsIterator.hasNext(); + } + + public float[] next() { + EmbeddingResults.Embedding embedding = embeddingsIterator.next(); + float[] values = switch(embedding) { + case TextEmbeddingFloatResults.Embedding textEmbeddingFloat -> textEmbeddingFloat.values(); + case TextEmbeddingByteResults.Embedding textEmbeddingBytes -> toFloatArray(textEmbeddingBytes.values()); + default -> throw new IllegalStateException("Unsupported embedding type [" + embedding.getClass() + "]"); + }; + + assert values.length == dimensions : "Unexpected vector size: " + values.length ; + + return values; + } + + private static float[] toFloatArray(byte[] bytes) { + float[] floatValues = new float[bytes.length]; + IntStream.range(0, floatValues.length).forEach(i -> floatValues[i] = ((Byte) bytes[i]).floatValue()); + return floatValues; + } + + + public static EmbeddingValueReader of(InferenceAction.Response inferenceResponse, int dimensions) { + TextEmbeddingResults inferenceResults = InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class); + return new 
EmbeddingValueReader(inferenceResults.embeddings().iterator(), dimensions); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingRequestIterator.java new file mode 100644 index 0000000000000..069a97dbe3612 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingRequestIterator.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.embedding; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; + +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * Iterator over input data blocks to create batched inference requests for the dense vector text embedding task. + * + *

This iterator reads from a {@link BytesRefBlock} containing text to be embedded. It slices the input into batches + * of configurable size and converts each batch into an {@link InferenceAction.Request} with the task type {@link TaskType#TEXT_EMBEDDING}. + */ +public class DenseEmbeddingRequestIterator implements BulkInferenceRequestIterator { + private final BytesRefBlock inputBlock; + private final String inferenceId; + private final int batchSize; + private int remainingPositions; + + public DenseEmbeddingRequestIterator(BytesRefBlock inputBlock, String inferenceId, int batchSize) { + this.inputBlock = inputBlock; + this.inferenceId = inferenceId; + this.batchSize = batchSize; + this.remainingPositions = inputBlock.getPositionCount(); + } + + @Override + public boolean hasNext() { + return remainingPositions > 0; + } + + @Override + public InferenceAction.Request next() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + + final int inputSize = Math.min(remainingPositions, batchSize); + final List inputs = new ArrayList<>(inputSize); + BytesRef scratch = new BytesRef(); + + int startIndex = inputBlock.getPositionCount() - remainingPositions; + for (int i = 0; i < inputSize; i++) { + int pos = startIndex + i; + if (inputBlock.isNull(pos)) { + inputs.add(""); + } else { + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(pos), scratch); + inputs.add(BytesRefs.toString(scratch)); + } + } + + remainingPositions -= inputSize; + return inferenceRequest(inputs); + } + + @Override + public int estimatedSize() { + return inputBlock.getPositionCount(); + } + + private InferenceAction.Request inferenceRequest(List inputs) { + return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(inputs).build(); + } + + @Override + public void close() { + + } +} From 66e06241c9e7a52fd8cad6e06e22f96fba7a5ac8 Mon Sep 17 00:00:00 2001 From: afoucret Date: Tue, 1 Jul 2025 11:50:04 +0200 Subject: [PATCH 02/12] Add tests for the 
dense vector operator. --- .../embedding/DenseEmbeddingOperator.java | 12 +- .../DenseEmbeddingOperatorOutputBuilder.java | 11 +- ...enseEmbeddingOperatorRequestIterator.java} | 4 +- .../inference/InferenceOperatorTestCase.java | 2 +- ...seEmbeddingOperatorOutputBuilderTests.java | 160 ++++++++++++++++++ ...EmbeddingOperatorRequestIteratorTests.java | 65 +++++++ .../DenseEmbeddingOperatorTests.java | 160 ++++++++++++++++++ .../embedding/DenseEmbeddingUtils.java | 88 ++++++++++ 8 files changed, 490 insertions(+), 12 deletions(-) rename x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/{DenseEmbeddingRequestIterator.java => DenseEmbeddingOperatorRequestIterator.java} (93%) create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIteratorTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingUtils.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java index e0575c8cd34fd..5b006a07855fe 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java @@ -77,9 +77,9 @@ public String toString() { * Returns the request iterator responsible for batching and converting input rows into inference requests. 
*/ @Override - protected DenseEmbeddingRequestIterator requests(Page inputPage) { + protected DenseEmbeddingOperatorRequestIterator requests(Page inputPage) { int inputBlockChannel = inputPage.getBlockCount() - 1; - return new DenseEmbeddingRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), batchSize); + return new DenseEmbeddingOperatorRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), batchSize); } /** @@ -98,8 +98,12 @@ protected DenseEmbeddingOperatorOutputBuilder outputBuilder(Page input) { /** * Factory for creating {@link DenseEmbeddingOperator} instances */ - public record Factory(InferenceRunner inferenceRunner, int dimensions, String inferenceId, - ExpressionEvaluator.Factory inputEvaluatorFactory) implements OperatorFactory { + public record Factory( + InferenceRunner inferenceRunner, + int dimensions, + String inferenceId, + ExpressionEvaluator.Factory inputEvaluatorFactory + ) implements OperatorFactory { @Override public String describe() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java index deee6d34c2448..de262cb4155f6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java @@ -26,7 +26,6 @@ */ public class DenseEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { - private final FloatBlock.Builder outputBlockBuilder; private final Page inputPage; private final int dimensions; @@ -95,13 +94,13 @@ public boolean hasNext() { public float[] next() { EmbeddingResults.Embedding embedding = embeddingsIterator.next(); - float[] values = switch(embedding) { + float[] values = switch 
(embedding) { case TextEmbeddingFloatResults.Embedding textEmbeddingFloat -> textEmbeddingFloat.values(); case TextEmbeddingByteResults.Embedding textEmbeddingBytes -> toFloatArray(textEmbeddingBytes.values()); default -> throw new IllegalStateException("Unsupported embedding type [" + embedding.getClass() + "]"); }; - assert values.length == dimensions : "Unexpected vector size: " + values.length ; + assert values.length == dimensions : "Unexpected vector size: " + values.length; return values; } @@ -112,9 +111,11 @@ private static float[] toFloatArray(byte[] bytes) { return floatValues; } - public static EmbeddingValueReader of(InferenceAction.Response inferenceResponse, int dimensions) { - TextEmbeddingResults inferenceResults = InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class); + TextEmbeddingResults inferenceResults = InferenceOperator.OutputBuilder.inferenceResults( + inferenceResponse, + TextEmbeddingResults.class + ); return new EmbeddingValueReader(inferenceResults.embeddings().iterator(), dimensions); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIterator.java similarity index 93% rename from x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingRequestIterator.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIterator.java index 069a97dbe3612..e1f3a3c4ebec1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingRequestIterator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIterator.java @@ -24,13 +24,13 @@ *

This iterator reads from a {@link BytesRefBlock} containing text to be embedded. It slices the input into batches * of configurable size and converts each batch into an {@link InferenceAction.Request} with the task type {@link TaskType#TEXT_EMBEDDING}. */ -public class DenseEmbeddingRequestIterator implements BulkInferenceRequestIterator { +public class DenseEmbeddingOperatorRequestIterator implements BulkInferenceRequestIterator { private final BytesRefBlock inputBlock; private final String inferenceId; private final int batchSize; private int remainingPositions; - public DenseEmbeddingRequestIterator(BytesRefBlock inputBlock, String inferenceId, int batchSize) { + public DenseEmbeddingOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, int batchSize) { this.inputBlock = inputBlock; this.inferenceId = inferenceId; this.batchSize = batchSize; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java index c49e301968aa0..9a000778d4003 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java @@ -93,7 +93,7 @@ protected Page createPage(int positionOffset, int length) { if (randomInt() % 100 == 0) { builder.appendNull(); } else { - builder.appendBytesRef(new BytesRef(randomAlphaOfLength(10))); + builder.appendBytesRef(new BytesRef(randomAlphaOfLength(randomInt(10)))); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java new file mode 100644 index 0000000000000..c4e754dacffcc --- /dev/null +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.embedding; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.compute.test.RandomBlock; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; + +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class DenseEmbeddingOperatorOutputBuilderTests extends ComputeTestCase { + + private final static List DIMENSIONS = List.of(1, 32, 128, 512, 2048, 5096); + private final static List INPUT_SIZES = List.of(10, 100, 1_000, 10_000); + private final static List BATCH_SIZES = List.of(1, 10, 100, 1000); + private final static List>> EMBEDDING_TYPES = List.of( + TextEmbeddingBitResults.class, + TextEmbeddingByteResults.class, + TextEmbeddingFloatResults.class + ); + + private final static String TEST_PARAMS_FORMATING = "dims=%d, input_size=%d, batch_size=%d, embedding_type=%s"; + + private final int dimensions; + private final int inputPageSize; + private final int batchSize; + private final Class> embeddingType; + + 
@ParametersFactory(argumentFormatting = TEST_PARAMS_FORMATING) + public static Iterable parameters() { + List params = new ArrayList<>(); + params.add(new Object[] {}); + + for (List axis : List.of(DIMENSIONS, INPUT_SIZES, BATCH_SIZES, EMBEDDING_TYPES)) { + + List newParams = new ArrayList<>(); + for (Object[] combination : params) { + for (Object element : axis) { + Object[] newCombination = new Object[combination.length + 1]; + System.arraycopy(combination, 0, newCombination, 0, combination.length); + newCombination[newCombination.length - 1] = element; + newParams.add(newCombination); + } + } + params = newParams; + } + + return params; + } + + public DenseEmbeddingOperatorOutputBuilderTests( + int dimensions, + int inputPageSize, + int batchSize, + Class embeddingType + ) { + this.dimensions = dimensions; + this.inputPageSize = inputPageSize; + this.batchSize = batchSize; + this.embeddingType = embeddingType; + } + + public void testOutput() { + final Page inputPage = randomInputPage(inputPageSize, between(1, 20)); + try ( + DenseEmbeddingOperatorOutputBuilder outputBuilder = new DenseEmbeddingOperatorOutputBuilder( + blockFactory().newFloatBlockBuilder(inputPageSize * dimensions), + inputPage, + dimensions + ) + ) { + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos += batchSize) { + outputBuilder.addInferenceResponse( + DenseEmbeddingUtils.inferenceResponse( + currentPos, + Math.min(inputPage.getPositionCount() - currentPos, batchSize), + embeddingType, + dimensions + ) + ); + } + + final Page outputPage = outputBuilder.buildOutput(); + try { + assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1)); + assertOutputContent(outputPage.getBlock(inputPage.getBlockCount())); + } finally { + outputPage.releaseBlocks(); + } + + } finally { + inputPage.releaseBlocks(); + } + } + + private void assertOutputContent(FloatBlock block) { 
+ + for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) { + assertThat(block.isNull(currentPos), equalTo(false)); + assertThat(block.getValueCount(currentPos), equalTo(dimensions)); + float[] expectedEmbeddingValues = new float[dimensions]; + if (embeddingType.equals(TextEmbeddingByteResults.class) || embeddingType.equals(TextEmbeddingBitResults.class)) { + byte[] bytesValues = embeddingType.equals(TextEmbeddingBitResults.class) + ? DenseEmbeddingUtils.toBitArray(currentPos, dimensions) + : DenseEmbeddingUtils.toByteArray(currentPos, dimensions); + assertThat(bytesValues.length, equalTo(dimensions)); + for (int i = 0; i < bytesValues.length; i++) { + expectedEmbeddingValues[i] = bytesValues[i]; + } + } else if (embeddingType.equals(TextEmbeddingFloatResults.class)) { + expectedEmbeddingValues = DenseEmbeddingUtils.toFloatArray(currentPos, dimensions); + } + + for (int valueIndex = 0; valueIndex < block.getValueCount(currentPos); valueIndex++) { + assertThat(block.getFloat(block.getFirstValueIndex(currentPos) + valueIndex), equalTo(expectedEmbeddingValues[valueIndex])); + } + } + } + + private Page randomInputPage(int positionCount, int columnCount) { + final Block[] blocks = new Block[columnCount]; + try { + for (int i = 0; i < columnCount; i++) { + blocks[i] = RandomBlock.randomBlock( + blockFactory(), + RandomBlock.randomElementType(), + positionCount, + randomBoolean(), + 0, + 0, + randomInt(10), + randomInt(10) + ).block(); + } + + return new Page(blocks); + } catch (Exception e) { + Releasables.close(blocks); + throw (e); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIteratorTests.java new file mode 100644 index 0000000000000..7570dabb40cad --- /dev/null +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIteratorTests.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.embedding; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class DenseEmbeddingOperatorRequestIteratorTests extends ComputeTestCase { + + public void testIterateSmallInput() { + assertIterate(between(1, 100), randomIntBetween(1, 1_000)); + } + + public void testIterateLargeInput() { + assertIterate(between(10_000, 100_000), randomIntBetween(1, 1_000)); + } + + private void assertIterate(int size, int batchSize) { + final String inferenceId = randomIdentifier(); + + try ( + BytesRefBlock inputBlock = randomInputBlock(size); + DenseEmbeddingOperatorRequestIterator requestIterator = new DenseEmbeddingOperatorRequestIterator( + inputBlock, + inferenceId, + batchSize + ) + ) { + BytesRef scratch = new BytesRef(); + + for (int currentPos = 0; requestIterator.hasNext();) { + InferenceAction.Request request = requestIterator.next(); + + assertThat(request.getInferenceEntityId(), equalTo(inferenceId)); + List inputs = request.getInput(); + for (String input : inputs) { + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch); + assertThat(input, equalTo(scratch.utf8ToString())); + currentPos++; + } + } + } + } + + private BytesRefBlock randomInputBlock(int size) { + try (BytesRefBlock.Builder blockBuilder = 
blockFactory().newBytesRefBlockBuilder(size)) { + for (int i = 0; i < size; i++) { + blockBuilder.appendBytesRef(new BytesRef(randomAlphaOfLength(10))); + } + + return blockBuilder.build(); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java new file mode 100644 index 0000000000000..b2b105271f016 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.embedding; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase; +import org.hamcrest.Matcher; + +import java.util.ArrayList; +import java.util.List; + +import static 
org.elasticsearch.xpack.esql.inference.embedding.DenseEmbeddingUtils.toBitArray; +import static org.elasticsearch.xpack.esql.inference.embedding.DenseEmbeddingUtils.toByteArray; +import static org.elasticsearch.xpack.esql.inference.embedding.DenseEmbeddingUtils.toFloatArray; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class DenseEmbeddingOperatorTests extends InferenceOperatorTestCase> { + + private static final String SIMPLE_INFERENCE_ID = "test_dense_embedding"; + + private static final String TEST_PARAMS_FORMATING = "dims=%s, embedding_type=%s"; + + private final static List DIMENSIONS = List.of(1, 32, 128, 512, 2048, 5096); + private final static List>> EMBEDDING_TYPES = List.of( + TextEmbeddingBitResults.class, + TextEmbeddingByteResults.class, + TextEmbeddingFloatResults.class + ); + + @ParametersFactory(argumentFormatting = TEST_PARAMS_FORMATING) + public static Iterable parameters() { + List params = new ArrayList<>(); + params.add(new Object[] {}); + + for (List axis : List.of(DIMENSIONS, EMBEDDING_TYPES)) { + + List newParams = new ArrayList<>(); + for (Object[] combination : params) { + for (Object element : axis) { + Object[] newCombination = new Object[combination.length + 1]; + System.arraycopy(combination, 0, newCombination, 0, combination.length); + newCombination[newCombination.length - 1] = element; + newParams.add(newCombination); + } + } + params = newParams; + } + + return params; + } + + private final int dimensions; + private final Class> embeddingType; + + public DenseEmbeddingOperatorTests(int dimensions, Class> embeddingType) { + this.dimensions = dimensions; + this.embeddingType = embeddingType; + } + + @Override + protected Operator.OperatorFactory simple(SimpleOptions options) { + return new DenseEmbeddingOperator.Factory(mockedSimpleInferenceRunner(), dimensions, SIMPLE_INFERENCE_ID, evaluatorFactory(0)); + } + + @Override + protected void assertSimpleOutput(List inputPages, 
List resultPages) { + assertThat(inputPages, hasSize(resultPages.size())); + + for (int pageId = 0; pageId < inputPages.size(); pageId++) { + Page inputPage = inputPages.get(pageId); + Page resultPage = resultPages.get(pageId); + + assertThat(resultPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(resultPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1)); + + for (int channel = 0; channel < inputPage.getBlockCount(); channel++) { + Block inputBlock = inputPage.getBlock(channel); + Block resultBlock = resultPage.getBlock(channel); + + assertThat(resultBlock.getPositionCount(), equalTo(resultPage.getPositionCount())); + assertThat(resultBlock.elementType(), equalTo(inputBlock.elementType())); + + if (channel == 0) { + assertEmbeddingContent((BytesRefBlock) inputBlock, resultPage.getBlock(inputPage.getBlockCount())); + } + } + } + } + + private void assertEmbeddingContent(BytesRefBlock inputBlock, FloatBlock embeddingsBlock) { + BytesRef scratch = new BytesRef(); + for (int position = 0; position < inputBlock.getPositionCount(); position++) { + String textContent = ""; + if (inputBlock.isNull(position) == false) { + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(position), scratch); + textContent = BytesRefs.toString(scratch); + } + float[] expectedEmbedding = new float[dimensions]; + if (embeddingType.equals(TextEmbeddingFloatResults.class)) { + expectedEmbedding = toFloatArray(textContent.length(), dimensions); + } + if (embeddingType.equals(TextEmbeddingByteResults.class) || embeddingType.equals(TextEmbeddingBitResults.class)) { + byte[] bytesValues = embeddingType.equals(TextEmbeddingBitResults.class) + ? 
toBitArray(textContent.length(), dimensions) + : toByteArray(textContent.length(), dimensions); + assertThat(bytesValues.length, equalTo(dimensions)); + for (int i = 0; i < bytesValues.length; i++) { + expectedEmbedding[i] = bytesValues[i]; + } + } + + assertThat(embeddingsBlock.isNull(position), equalTo(false)); + assertThat(embeddingsBlock.getValueCount(position), equalTo(dimensions)); + for (int valueIndex = 0; valueIndex < dimensions; valueIndex++) { + assertThat( + embeddingsBlock.getFloat(embeddingsBlock.getFirstValueIndex(position) + valueIndex), + equalTo(expectedEmbedding[valueIndex]) + ); + } + } + } + + @Override + protected Matcher expectedDescriptionOfSimple() { + return expectedToStringOfSimple(); + } + + @Override + protected Matcher expectedToStringOfSimple() { + return equalTo("DenseEmbeddingOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "]]"); + } + + @Override + protected TextEmbeddingResults mockInferenceResult(InferenceAction.Request request) { + return DenseEmbeddingUtils.inferenceResults( + request.getInput().stream().mapToInt(String::length).toArray(), + embeddingType, + dimensions + ); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingUtils.java new file mode 100644 index 0000000000000..193d9975e7281 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingUtils.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.embedding; + +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.IntStream; + +public class DenseEmbeddingUtils { + + public static byte[] toBitArray(int position, int length) { + byte[] bits = new byte[length]; + for (int i = 0; i < 32 && i < length; i++) { + bits[length - 1 - i] = (byte) ((position >> i) & 1); + } + return bits; + } + + public static byte[] toByteArray(int position, int length) { + byte[] digits = new byte[length]; + for (int i = length - 1; i >= 0; i--) { + digits[i] = (byte) (position % 10); + position /= 10; + } + + return digits; + } + + public static float[] toFloatArray(int position, int length) { + float[] floats = new float[length]; + floats[0] = position; + return floats; + } + + public static InferenceAction.Response inferenceResponse( + int startPosition, + int size, + Class> embeddingType, + int dimensions + ) { + return new InferenceAction.Response( + inferenceResults(IntStream.range(startPosition, startPosition + size).toArray(), embeddingType, dimensions) + ); + } + + public static TextEmbeddingResults inferenceResults( + int[] values, + Class> embeddingType, + int dimensions + ) { + if (embeddingType.equals(TextEmbeddingBitResults.class)) { + List embeddings = new ArrayList<>(); + for (int i = 0; i < values.length; i++) { + embeddings.add(new TextEmbeddingByteResults.Embedding(toBitArray(values[i], dimensions))); + } + return new TextEmbeddingBitResults(embeddings); + } + + if (embeddingType.equals(TextEmbeddingByteResults.class)) { + List embeddings = new 
ArrayList<>(); + for (int i = 0; i < values.length; i++) { + embeddings.add(new TextEmbeddingByteResults.Embedding(toByteArray(values[i], dimensions))); + } + return new TextEmbeddingByteResults(embeddings); + } + + if (embeddingType.equals(TextEmbeddingFloatResults.class)) { + List embeddings = new ArrayList<>(); + for (int i = 0; i < values.length; i++) { + embeddings.add(new TextEmbeddingFloatResults.Embedding(toFloatArray(values[i], dimensions))); + } + return new TextEmbeddingFloatResults(embeddings); + } + + throw new AssertionError("Unexpected Embedding type [" + embeddingType + "]"); + } +} From 3d075fd859568b5c5646b7f36819afff2289f7ed Mon Sep 17 00:00:00 2001 From: afoucret Date: Tue, 1 Jul 2025 13:52:33 +0200 Subject: [PATCH 03/12] adding a new DENSE_VECTOR_EMBEDDING_FUNCTION capability --- .../elasticsearch/xpack/esql/action/EsqlCapabilities.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 9a8b71e8e5eea..3d4e081f0ec70 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1197,6 +1197,11 @@ public enum Cap { */ KNN_FUNCTION(Build.current().isSnapshot()), + /** + * Support for dense vector embedding function + */ + DENSE_VECTOR_EMBEDDING_FUNCTION(Build.current().isSnapshot()), + LIKE_WITH_LIST_OF_PATTERNS, /** From 3ae547ec76c75b5b1859af66530529cd37cc98c2 Mon Sep 17 00:00:00 2001 From: afoucret Date: Tue, 1 Jul 2025 13:55:43 +0200 Subject: [PATCH 04/12] lint. 
--- .../DenseEmbeddingOperatorOutputBuilderTests.java | 8 ++++---- .../inference/embedding/DenseEmbeddingOperatorTests.java | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java index c4e754dacffcc..ac6c6f9f99216 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java @@ -27,10 +27,10 @@ public class DenseEmbeddingOperatorOutputBuilderTests extends ComputeTestCase { - private final static List DIMENSIONS = List.of(1, 32, 128, 512, 2048, 5096); - private final static List INPUT_SIZES = List.of(10, 100, 1_000, 10_000); - private final static List BATCH_SIZES = List.of(1, 10, 100, 1000); - private final static List>> EMBEDDING_TYPES = List.of( + private static final List DIMENSIONS = List.of(1, 32, 128, 512, 2048, 5096); + private static final List INPUT_SIZES = List.of(10, 100, 1_000, 10_000); + private static final List BATCH_SIZES = List.of(1, 10, 100, 1000); + private static final List>> EMBEDDING_TYPES = List.of( TextEmbeddingBitResults.class, TextEmbeddingByteResults.class, TextEmbeddingFloatResults.class diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java index b2b105271f016..dc704b22e198f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java @@ -39,8 +39,8 @@ public class DenseEmbeddingOperatorTests extends InferenceOperatorTestCase DIMENSIONS = List.of(1, 32, 128, 512, 2048, 5096); - private final static List>> EMBEDDING_TYPES = List.of( + private static final List DIMENSIONS = List.of(1, 32, 128, 512, 2048, 5096); + private static final List>> EMBEDDING_TYPES = List.of( TextEmbeddingBitResults.class, TextEmbeddingByteResults.class, TextEmbeddingFloatResults.class From a36e3ae0fc9e654de110192f5cad8af76715a2d4 Mon Sep 17 00:00:00 2001 From: afoucret Date: Tue, 1 Jul 2025 17:54:20 +0200 Subject: [PATCH 05/12] Improve the way test are named. --- .../DenseEmbeddingOperatorOutputBuilderTests.java | 13 ++++++------- .../embedding/DenseEmbeddingOperatorTests.java | 10 ++++++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java index ac6c6f9f99216..13a31b3f1a74d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilderTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.inference.embedding; +import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.elasticsearch.compute.data.Block; @@ -36,14 +37,12 @@ public class DenseEmbeddingOperatorOutputBuilderTests extends ComputeTestCase { TextEmbeddingFloatResults.class ); - private final static String TEST_PARAMS_FORMATING = "dims=%d, input_size=%d, batch_size=%d, embedding_type=%s"; - 
private final int dimensions; private final int inputPageSize; private final int batchSize; private final Class> embeddingType; - @ParametersFactory(argumentFormatting = TEST_PARAMS_FORMATING) + @ParametersFactory public static Iterable parameters() { List params = new ArrayList<>(); params.add(new Object[] {}); @@ -66,10 +65,10 @@ public static Iterable parameters() { } public DenseEmbeddingOperatorOutputBuilderTests( - int dimensions, - int inputPageSize, - int batchSize, - Class embeddingType + @Name("dimensions") int dimensions, + @Name("inputPageSize") int inputPageSize, + @Name("batchSize") int batchSize, + @Name("embeddingType") Class embeddingType ) { this.dimensions = dimensions; this.inputPageSize = inputPageSize; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java index dc704b22e198f..2326be1d42c8b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.inference.embedding; +import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.apache.lucene.util.BytesRef; @@ -37,8 +38,6 @@ public class DenseEmbeddingOperatorTests extends InferenceOperatorTestCase DIMENSIONS = List.of(1, 32, 128, 512, 2048, 5096); private static final List>> EMBEDDING_TYPES = List.of( TextEmbeddingBitResults.class, @@ -46,7 +45,7 @@ public class DenseEmbeddingOperatorTests extends InferenceOperatorTestCase parameters() { List params = new ArrayList<>(); params.add(new Object[] {}); @@ -71,7 +70,10 @@ public static Iterable parameters() { private final int dimensions; 
private final Class> embeddingType; - public DenseEmbeddingOperatorTests(int dimensions, Class> embeddingType) { + public DenseEmbeddingOperatorTests( + @Name("dimensions") int dimensions, + @Name("embeddingType") Class> embeddingType + ) { this.dimensions = dimensions; this.embeddingType = embeddingType; } From 94fb91885e0b2ad53f53b6bd6493689f9a1c0506 Mon Sep 17 00:00:00 2001 From: afoucret Date: Tue, 1 Jul 2025 17:55:28 +0200 Subject: [PATCH 06/12] First implementation of inference function. --- .../functions/text_dense_vector_embedding.svg | 1 + .../text_dense_vector_embedding.json | 9 + .../functions/text_dense_vector_embedding.md | 5 + .../esql/core/expression/MapExpression.java | 14 +- .../esql/expression/ExpressionWritables.java | 9 + .../function/EsqlFunctionRegistry.java | 2 + .../DenseVectorEmbeddingFunction.java | 152 ++++++++++++++ .../function/inference/InferenceFunction.java | 192 ++++++++++++++++++ .../function/AbstractFunctionTestCase.java | 1 + .../DenseVectorEmbeddingErrorTests.java | 79 +++++++ .../DenseVectorEmbeddingFunctionTests.java | 103 ++++++++++ .../DenseVectorSerializationTests.java | 39 ++++ 12 files changed, 605 insertions(+), 1 deletion(-) create mode 100644 docs/reference/query-languages/esql/images/functions/text_dense_vector_embedding.svg create mode 100644 docs/reference/query-languages/esql/kibana/definition/functions/text_dense_vector_embedding.json create mode 100644 docs/reference/query-languages/esql/kibana/docs/functions/text_dense_vector_embedding.md create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingErrorTests.java create mode 100644 
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunctionTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorSerializationTests.java diff --git a/docs/reference/query-languages/esql/images/functions/text_dense_vector_embedding.svg b/docs/reference/query-languages/esql/images/functions/text_dense_vector_embedding.svg new file mode 100644 index 0000000000000..c628f0137ff91 --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/text_dense_vector_embedding.svg @@ -0,0 +1 @@ +TEXT_DENSE_VECTOR_EMBEDDING(inputText,options) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/text_dense_vector_embedding.json b/docs/reference/query-languages/esql/kibana/definition/functions/text_dense_vector_embedding.json new file mode 100644 index 0000000000000..c64df7894f836 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/text_dense_vector_embedding.json @@ -0,0 +1,9 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "text_dense_vector_embedding", + "description" : "Embed input text into a dense vector representation using an inference model.", + "signatures" : [ ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/text_dense_vector_embedding.md b/docs/reference/query-languages/esql/kibana/docs/functions/text_dense_vector_embedding.md new file mode 100644 index 0000000000000..7ee06e487fb0f --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/text_dense_vector_embedding.md @@ -0,0 +1,5 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +### TEXT DENSE VECTOR EMBEDDING +Embed input text into a dense vector representation using an inference model. + diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java index 24736ac3a2514..6682515e062c4 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java @@ -120,10 +120,18 @@ public Expression get(Object key) { return map.get(key); } else { // the key(literal) could be converted to BytesRef by ConvertStringToByteRef - return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(new BytesRef(key.toString())); + return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(getKeyAsBytesRef(key)); } } + public Expression getOrDefault(Object key, Expression defaultValue) { + return containsKey(key) ? 
get(key) : defaultValue; + } + + public boolean containsKey(Object key) { + return keyFoldedMap.containsKey(key) || keyFoldedMap.containsKey(getKeyAsBytesRef(key)); + } + @Override public boolean equals(Object obj) { if (this == obj) { @@ -142,4 +150,8 @@ public String toString() { String str = entryExpressions.stream().map(String::valueOf).collect(Collectors.joining(", ")); return "{ " + str + " }"; } + + private BytesRef getKeyAsBytesRef(Object key) { + return new BytesRef(key.toString()); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index 901f364a60041..e9c2eda5d0026 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables; +import org.elasticsearch.xpack.esql.expression.function.inference.DenseVectorEmbeddingFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble; @@ -119,6 +120,7 @@ public static List getNamedWriteables() { entries.addAll(fullText()); entries.addAll(unaryScalars()); entries.addAll(vector()); + entries.addAll(inference()); return entries; } @@ -264,4 +266,11 @@ private static List vector() { } return List.of(); } + + private static List inference() { + if (EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled()) { + return 
List.of(DenseVectorEmbeddingFunction.ENTRY); + } + return List.of(); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index ede0537d5d3d4..8feb6120b4145 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -52,6 +52,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Term; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; +import org.elasticsearch.xpack.esql.expression.function.inference.DenseVectorEmbeddingFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least; @@ -479,6 +480,7 @@ private static FunctionDefinition[][] snapshotFunctions() { def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"), def(Term.class, bi(Term::new), "term"), def(Knn.class, tri(Knn::new), "knn"), + def(DenseVectorEmbeddingFunction.class, bi(DenseVectorEmbeddingFunction::new), "text_dense_vector_embedding"), def(StGeohash.class, StGeohash::new, "st_geohash"), def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"), def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java new file mode 100644 index 
0000000000000..ac1e1aeb9ebc8 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java @@ -0,0 +1,152 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.MapParam; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.UUID; + +/** + * * A function that embeds input text into a dense vector representation using an inference model. 
+ */ +public class DenseVectorEmbeddingFunction extends InferenceFunction { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "TextDenseVectorEmbedding", + DenseVectorEmbeddingFunction::new + ); + + private final Expression inputText; + private final Attribute tmpAttribute; + + @FunctionInfo( + returnType = "dense_vector", + preview = true, + description = "Embed input text into a dense vector representation using an inference model.", + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) } + ) + public DenseVectorEmbeddingFunction( + Source source, + @Param(name = "inputText", type = { "keyword", "text" }, description = "Input text") Expression inputText, + @MapParam( + name = "options", + params = { @MapParam.MapParamEntry(name = "inference_id", type = "keyword", description = "Inference endpoint to use.") }, + optional = true + ) Expression options + ) { + this(source, inputText, options, new ReferenceAttribute(Source.EMPTY, ENTRY.name + "_" + UUID.randomUUID(), DataType.DOUBLE)); + } + + private DenseVectorEmbeddingFunction(Source source, Expression inputText, Expression options, Attribute tmpAttribute) { + super(source, List.of(inputText, tmpAttribute), options); + this.inputText = inputText; + this.tmpAttribute = tmpAttribute; + } + + public DenseVectorEmbeddingFunction(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Attribute.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(inputText); + out.writeNamedWriteable(options()); + out.writeNamedWriteable(tmpAttribute); + } + + @Override + public String functionName() { + super.functionName(); + return getWriteableName(); + } + + @Override + public DataType dataType() { + return 
DataType.DENSE_VECTOR; + } + + @Override + public DenseVectorEmbeddingFunction replaceChildren(List newChildren) { + return new DenseVectorEmbeddingFunction( + source(), + newChildren.get(0), + newChildren.size() > 1 ? newChildren.get(1) : null, + tmpAttribute + ); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, DenseVectorEmbeddingFunction::new, inputText, options(), tmpAttribute); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected Literal defaultInferenceId() { + return Literal.NULL; + } + + @Override + public List temporaryAttributes() { + return List.of(tmpAttribute); + } + + @Override + protected TypeResolution resolveParams() { + return TypeResolutions.isString(inputText, sourceText(), TypeResolutions.ParamOrdinal.FIRST); + } + + @Override + protected TypeResolutions.ParamOrdinal optionsParamsOrdinal() { + return TypeResolutions.ParamOrdinal.SECOND; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + DenseVectorEmbeddingFunction that = (DenseVectorEmbeddingFunction) o; + return Objects.equals(inputText, that.inputText) && Objects.equals(tmpAttribute, that.tmpAttribute); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), inputText, tmpAttribute); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java new file mode 100644 index 0000000000000..507f1e82f6bf1 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -0,0 +1,192 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Supplier; +import java.util.stream.Stream; + +/** + * Base class for ESQL functions that perform inference using an `inference_id` and optional parameters. + */ +public abstract class InferenceFunction extends Function implements OptionalArgument { + public static final String INFERENCE_ID_OPTION_NAME = "inference_id"; + + public static final List DEFAULT_OPTIONAL_ARGUMENTS_VALIDATORS = List.of( + new InferenceIdOptionalArgumentsValidator() + ); + + private final Expression inferenceId; + private final Expression options; + + @SuppressWarnings("this-escape") + protected InferenceFunction(Source source, List children, Expression options) { + super(source, Stream.concat(children.stream(), Stream.of(options)).toList()); + this.inferenceId = parseInferenceId(options, this::defaultInferenceId); + this.options = options; + } + + /** + * Returns the expression representing the {@code inference_id} used by the function. 
+ * + * @return the inference ID expression + */ + public Expression inferenceId() { + return inferenceId; + } + + /** + * Returns the expression representing the options passed to the function. + * + * @return the options expression + */ + public Expression options() { + return options; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + return resolveParams().and(resolveOptions()); + } + + /** + * Returns the default inference ID expression to use when no {@code inference_id} + * is specified in the options. + * + * @return the default inference ID expression + */ + protected abstract Expression defaultInferenceId(); + + /** + * When an inference function is resolved, it is replaced with temporary attributes that are computed in an ad-hoc inference command. + * These attributes need to be cleaned up once they are no longer used. + * + * @return the list of temporary attributes + */ + public abstract List temporaryAttributes(); + + /** + * Resolves the types of the core parameters passed to this function. + * + * @return the result of parameter type resolution + */ + protected abstract TypeResolution resolveParams(); + + /** + * Returns the parameter ordinal of the options parameter.
+ * + * @return the ordinal of the options parameter + */ + protected abstract TypeResolutions.ParamOrdinal optionsParamsOrdinal(); + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + InferenceFunction that = (InferenceFunction) o; + return Objects.equals(options, that.options); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), options); + } + + protected TypeResolution resolveOptions() { + TypeResolution resolution = TypeResolutions.isMapExpression(options(), sourceText(), optionsParamsOrdinal()); + if (resolution.unresolved()) { + return resolution; + } + + MapExpression options = (MapExpression) options(); + for (Map.Entry optionEntry : options.keyFoldedMap().entrySet()) { + for (OptionalArgumentsValidator validator : optionalArgumentsValidators()) { + if (validator.applyTo(optionEntry.getKey(), optionEntry.getValue())) { + TypeResolution optionResolution = validator.resolveOptionValue( + optionEntry.getKey(), + optionEntry.getValue(), + optionsParamsOrdinal() + ); + if (optionResolution.unresolved()) { + return optionResolution; + } + break; + } + } + } + + return TypeResolution.TYPE_RESOLVED; + } + + protected List optionalArgumentsValidators() { + return DEFAULT_OPTIONAL_ARGUMENTS_VALIDATORS; + } + + /** + * Extracts the {@code inference_id} expression from the options. + * Falls back to the supplied default inference ID if the option is missing. + * + * @param options the options map expression + * @param defaultInferenceIdSupplier the supplier for the default inference ID + * @return the resolved inference ID expression + */ + private static Expression parseInferenceId(Expression options, Supplier defaultInferenceIdSupplier) { + return readOption("inference_id", options, defaultInferenceIdSupplier); + } + + /** + * Reads an option value from a map expression with a fallback to a default value.
+ * + * @param optionName the name of the option to retrieve + * @param options the map expression containing options + * @param defaultValueSupplier the supplier of the default value + * @return the option value as an expression or the default if not present + */ + private static Expression readOption(String optionName, Expression options, Supplier defaultValueSupplier) { + if (options != null && options.dataType() != DataType.NULL && options instanceof MapExpression mapOptions) { + return mapOptions.getOrDefault(optionName, defaultValueSupplier.get()); + } + + return defaultValueSupplier.get(); + } + + public interface OptionalArgumentsValidator { + boolean applyTo(String optionName, Expression optionValue); + + TypeResolution resolveOptionValue(String optionName, Expression optionValue, TypeResolutions.ParamOrdinal paramOrdinal); + } + + public static class InferenceIdOptionalArgumentsValidator implements OptionalArgumentsValidator { + private InferenceIdOptionalArgumentsValidator() {} + + public boolean applyTo(String optionName, Expression optionValue) { + return optionName.equals(INFERENCE_ID_OPTION_NAME); + } + + public TypeResolution resolveOptionValue(String optionName, Expression optionValue, TypeResolutions.ParamOrdinal paramOrdinal) { + return TypeResolutions.isString(optionValue, optionName, paramOrdinal) + .and(TypeResolutions.isNotNull(optionValue, optionName, paramOrdinal)) + .and(TypeResolutions.isFoldable(optionValue, optionName, paramOrdinal)); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index 00f20b9376a6f..4e256a1563af2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -798,6 +798,7 @@ public static void testFunctionInfo() { Set returnTypes = Arrays.stream(description.returnType()) .filter(t -> DataType.UNDER_CONSTRUCTION.containsKey(DataType.fromNameOrAlias(t)) == false) .collect(Collectors.toCollection(TreeSet::new)); + assertEquals(returnFromSignature, returnTypes); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingErrorTests.java new file mode 100644 index 0000000000000..4158b2cc9380e --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingErrorTests.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; +import org.junit.Before; + +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Stream; + +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; +import static org.hamcrest.Matchers.equalTo; + +public class DenseVectorEmbeddingErrorTests extends ErrorsForCasesWithoutExamplesTestCase { + + @Before + public void checkCapability() { + assumeTrue("DENSE_VECTOR_EMBEDDING_FUNCTION is not enabled", EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled()); + } + + @Override + protected List cases() { + return paramsToSuppliers(DenseVectorEmbeddingFunctionTests.parameters()); + } + + @Override + protected Stream> testCandidates(List cases, Set> valid) { + // Don't test null, as it is not allowed but the expected message is not a type error - so we check it separately in VerifierTests + return super.testCandidates(cases, valid).filter(sig -> false == sig.contains(DataType.NULL)); + } + + @Override + protected Expression build(Source source, List args) { + return new DenseVectorEmbeddingFunction(source, args.get(0), args.get(1)); + } + + @Override + protected Matcher expectedTypeErrorMatcher(List> validPerPosition, List signature) { + return equalTo(errorMessageString(validPerPosition, signature, (v, p) -> "string")); + } + + private static 
String errorMessageString( + List> validPerPosition, + List signature, + AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier + ) { + for (int i = 0; i < signature.size(); i++) { + if (validPerPosition.get(i).contains(signature.get(i)) == false) { + // Map expressions have different error messages + if (i == signature.size() - 1) { + return format( + null, + "{} argument of [{}] must be a map expression, received []", + TypeResolutions.ParamOrdinal.fromIndex(i).name().toLowerCase(Locale.ROOT), + sourceForSignature(signature) + ); + } + break; + } + } + + return typeErrorMessage(true, validPerPosition, signature, positionalErrorMessageSupplier); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunctionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunctionTests.java new file mode 100644 index 0000000000000..c493cf7177043 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunctionTests.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.FieldExpression; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.FunctionName; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matchers; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED; +import static org.hamcrest.Matchers.equalTo; + +@FunctionName("text_dense_vector_embedding") +public class DenseVectorEmbeddingFunctionTests extends AbstractFunctionTestCase { + @Before + public void checkCapability() { + assumeTrue("DENSE_VECTOR_EMBEDDING_FUNCTION is not enabled", EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled()); + } + + public DenseVectorEmbeddingFunctionTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List> paramList = List.of(List.of()); + List>> paramSuppliers = List.of( + () -> 
Stream.concat( + textTestCases("inputText", new BytesRef("input text value")).stream(), + textTestCases("inputText", new FieldExpression("foo", List.of(new FieldExpression.FieldValue("foo")))).stream() + ).toList(), + DenseVectorEmbeddingFunctionTests::optionsTestCases + ); + + for (Supplier> paramSupplier : paramSuppliers) { + List> newParams = new ArrayList<>(); + for (List params : paramList) { + for (TestCaseSupplier.TypedData value : paramSupplier.get()) { + List combination = new ArrayList<>(params); + combination.add(value); + newParams.add(combination); + } + } + paramList = newParams; + } + + return parameterSuppliersFromTypedData(paramList.stream().map(args -> { + List dataTypes = args.stream().map(TestCaseSupplier.TypedData::type).toList(); + return new TestCaseSupplier( + "TextDenseVectorEmbedding[" + dataTypes + "]", + dataTypes, + () -> new TestCaseSupplier.TestCase(args, Matchers.blankOrNullString(), DENSE_VECTOR, equalTo(true)) + ); + }).toList()); + } + + @Override + protected Expression build(Source source, List args) { + return new DenseVectorEmbeddingFunction(source, args.get(0), args.get(1)); + } + + private static List textTestCases(String name, Object value) { + return DataType.stringTypes().stream().map(dataType -> new TestCaseSupplier.TypedData(value, dataType, name)).toList(); + } + + private static List optionsTestCases() { + Literal inferenceIdOptionName = Literal.keyword(EMPTY, "inference_id"); + return DataType.stringTypes() + .stream() + .map( + dataType -> new TestCaseSupplier.TypedData( + new MapExpression(EMPTY, List.of(inferenceIdOptionName, new Literal(EMPTY, new BytesRef("inferenceId"), dataType))), + UNSUPPORTED, + "options" + ) + ) + .toList(); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorSerializationTests.java new file 
mode 100644 index 0000000000000..57661cf9a3ae0 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorSerializationTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; + +public class DenseVectorSerializationTests extends AbstractExpressionSerializationTests { + + @Before + public void checkCapability() { + assumeTrue("DENSE_VECTOR_EMBEDDING_FUNCTION is not enabled", EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled()); + } + + protected DenseVectorEmbeddingFunction createTestInstance() { + Source source = randomSource(); + Expression inputText = randomChild(); + Expression options = randomChild(); + return new DenseVectorEmbeddingFunction(source, inputText, options); + } + + @Override + protected DenseVectorEmbeddingFunction mutateInstance(DenseVectorEmbeddingFunction instance) { + List newChildren = new ArrayList<>(instance.children()); + newChildren.set(randomInt(newChildren.size() - 1), randomChild()); + return instance.replaceChildren(newChildren); + } +} From ae208a64a370c555373bf5a23738f4b41bac3499 Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 2 Jul 2025 10:43:04 +0200 Subject: [PATCH 07/12] Logical & physical plan implementation. 
--- .../xpack/esql/plan/PlanWritables.java | 4 + .../embedding/DenseVectorEmbedding.java | 145 ++++++++++++++++++ .../embedding/DenseVectorEmbeddingExec.java | 113 ++++++++++++++ ...enseVectorEmbeddingSerializationTests.java | 59 +++++++ .../DenseVectorEmbeddingExecTests.java | 59 +++++++ 5 files changed, 380 insertions(+) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbeddingSerializationTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExecTests.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java index 2fe9f5182ae00..940d8adda546c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; +import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding; import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; @@ -55,6 +56,7 @@ import org.elasticsearch.xpack.esql.plan.physical.TopNExec; import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import 
org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.embedding.DenseVectorEmbeddingExec; import java.util.ArrayList; import java.util.List; @@ -72,6 +74,7 @@ public static List logical() { return List.of( Aggregate.ENTRY, Completion.ENTRY, + DenseVectorEmbedding.ENTRY, Dissect.ENTRY, Enrich.ENTRY, EsRelation.ENTRY, @@ -99,6 +102,7 @@ public static List physical() { return List.of( AggregateExec.ENTRY, CompletionExec.ENTRY, + DenseVectorEmbeddingExec.ENTRY, DissectExec.ENTRY, EnrichExec.ENTRY, EsQueryExec.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java new file mode 100644 index 0000000000000..11711be6b4784 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java @@ -0,0 +1,145 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plan.logical.inference.embedding; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.esql.capabilities.TelemetryAware; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes; + +public class DenseVectorEmbedding extends InferencePlan implements TelemetryAware { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + LogicalPlan.class, + "DenseVectorEmbedding", + DenseVectorEmbedding::new + ); + + private final Expression input; + private final Attribute targetField; + private List lazyOutput; + + public DenseVectorEmbedding(Source source, LogicalPlan child, Expression inferenceId, Expression input, Attribute targetField) { + super(source, child, inferenceId); + this.input = input; + this.targetField = targetField; + } + + public DenseVectorEmbedding(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(LogicalPlan.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Attribute.class) + ); + } + + @Override + public void writeTo(StreamOutput 
out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(child()); + out.writeNamedWriteable(inferenceId()); + out.writeNamedWriteable(input); + out.writeNamedWriteable(targetField); + } + + public Expression input() { + return input; + } + + public Attribute embeddingField() { + return targetField; + } + + @Override + public TaskType taskType() { + return TaskType.TEXT_EMBEDDING; + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public List output() { + if (lazyOutput == null) { + lazyOutput = mergeOutputAttributes(List.of(targetField), child().output()); + } + return lazyOutput; + } + + @Override + public List generatedAttributes() { + return List.of(targetField); + } + + @Override + public DenseVectorEmbedding withGeneratedNames(List newNames) { + checkNumberOfNewNames(newNames); + return new DenseVectorEmbedding(source(), child(), inferenceId(), input, this.renameTargetField(newNames.get(0))); + } + + private Attribute renameTargetField(String newName) { + if (newName.equals(targetField.name())) { + return targetField; + } + + return targetField.withName(newName).withId(new NameId()); + } + + @Override + public boolean expressionsResolved() { + return super.expressionsResolved() && input.resolved() && targetField.resolved(); + } + + @Override + public DenseVectorEmbedding withInferenceId(Expression newInferenceId) { + return new DenseVectorEmbedding(source(), child(), newInferenceId, input, targetField); + } + + @Override + public DenseVectorEmbedding replaceChild(LogicalPlan newChild) { + return new DenseVectorEmbedding(source(), newChild, inferenceId(), input, targetField); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, DenseVectorEmbedding::new, child(), inferenceId(), input, targetField); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == 
false) return false; + DenseVectorEmbedding that = (DenseVectorEmbedding) o; + return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), input, targetField); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java new file mode 100644 index 0000000000000..9933fc0eae8ce --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plan.physical.inference.embedding; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.AttributeSet; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.UnaryExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.InferenceExec; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes; + +public class DenseVectorEmbeddingExec extends InferenceExec { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + PhysicalPlan.class, + "DenseVectorEmbeddingExec", + DenseVectorEmbeddingExec::new + ); + + private final Expression input; + private final Attribute targetField; + private List lazyOutput; + + public DenseVectorEmbeddingExec(Source source, PhysicalPlan child, Expression inferenceId, Expression input, Attribute targetField) { + super(source, child, inferenceId); + this.input = input; + this.targetField = targetField; + } + + public DenseVectorEmbeddingExec(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(PhysicalPlan.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Attribute.class) + ); + } + + public Expression input() { + return input; + } + + public Attribute targetField() { + 
return targetField; + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeNamedWriteable(input); + out.writeNamedWriteable(targetField); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, DenseVectorEmbeddingExec::new, child(), inferenceId(), input, targetField); + } + + @Override + public UnaryExec replaceChild(PhysicalPlan newChild) { + return new DenseVectorEmbeddingExec(source(), newChild, inferenceId(), input, targetField); + } + + @Override + public List output() { + if (lazyOutput == null) { + lazyOutput = mergeOutputAttributes(List.of(targetField), child().output()); + } + return lazyOutput; + } + + @Override + protected AttributeSet computeReferences() { + return input.references(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + DenseVectorEmbeddingExec that = (DenseVectorEmbeddingExec) o; + return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), input, targetField); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbeddingSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbeddingSerializationTests.java new file mode 100644 index 0000000000000..8378a851a61fa --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbeddingSerializationTests.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plan.logical.inference.embedding; + +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; +import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.io.IOException; + +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +public class DenseVectorEmbeddingSerializationTests extends AbstractLogicalPlanSerializationTests { + + @Override + protected DenseVectorEmbedding createTestInstance() { + return new DenseVectorEmbedding(randomSource(), randomChild(0), randomInferenceId(), randomInput(), randomTargetField()); + } + + @Override + protected DenseVectorEmbedding mutateInstance(DenseVectorEmbedding instance) throws IOException { + LogicalPlan child = instance.child(); + Expression inferenceId = instance.inferenceId(); + Expression input = instance.input(); + Attribute targetField = instance.embeddingField(); + + switch (between(0, 3)) { + case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); + case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId); + case 2 -> input = randomValueOtherThan(input, this::randomInput); + case 3 -> targetField = randomValueOtherThan(targetField, this::randomTargetField); + } + return new DenseVectorEmbedding(instance.source(), child, inferenceId, input, targetField); + } + + private Literal randomInferenceId() { + return Literal.keyword(EMPTY, randomIdentifier()); + } + + private Expression randomInput() { + return randomBoolean() ? 
Literal.keyword(EMPTY, randomIdentifier()) : randomAttribute(); + } + + private Attribute randomTargetField() { + return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean()); + } + + private Attribute randomAttribute() { + return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean()); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExecTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExecTests.java new file mode 100644 index 0000000000000..00947d495d43d --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExecTests.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plan.physical.inference.embedding; + +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; +import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; + +import java.io.IOException; + +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +public class DenseVectorEmbeddingExecTests extends AbstractPhysicalPlanSerializationTests { + + @Override + protected DenseVectorEmbeddingExec createTestInstance() { + return new DenseVectorEmbeddingExec(randomSource(), randomChild(0), randomInferenceId(), randomInput(), randomTargetField()); + } + + @Override + protected DenseVectorEmbeddingExec mutateInstance(DenseVectorEmbeddingExec instance) throws IOException { + PhysicalPlan child = instance.child(); + Expression inferenceId = instance.inferenceId(); + Expression input = instance.input(); + Attribute targetField = instance.targetField(); + + switch (between(0, 3)) { + case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); + case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId); + case 2 -> input = randomValueOtherThan(input, this::randomInput); + case 3 -> targetField = randomValueOtherThan(targetField, this::randomTargetField); + } + return new DenseVectorEmbeddingExec(instance.source(), child, inferenceId, input, targetField); + } + + private Literal randomInferenceId() { + return Literal.keyword(EMPTY, randomIdentifier()); + } + + private Expression randomInput() { + return randomBoolean() ? 
Literal.keyword(EMPTY, randomIdentifier()) : randomAttribute(); + } + + private Attribute randomTargetField() { + return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean()); + } + + private Attribute randomAttribute() { + return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean()); + } +} From 9f672d3856d939c1485770d54017c2b92aef2f5e Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 2 Jul 2025 13:40:42 +0200 Subject: [PATCH 08/12] Checkstyle --- .../function/inference/DenseVectorEmbeddingFunction.java | 2 +- .../esql/expression/function/inference/InferenceFunction.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java index ac1e1aeb9ebc8..0f5ac162ac283 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java @@ -140,7 +140,7 @@ protected TypeResolutions.ParamOrdinal optionsParamsOrdinal() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - if (!super.equals(o)) return false; + if (super.equals(o) == false) return false; DenseVectorEmbeddingFunction that = (DenseVectorEmbeddingFunction) o; return Objects.equals(inputText, that.inputText) && Objects.equals(tmpAttribute, that.tmpAttribute); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java index 507f1e82f6bf1..4bbf32744bd7e 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -102,7 +102,7 @@ protected TypeResolution resolveType() { @Override public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; - if (!super.equals(o)) return false; + if (super.equals(o) == false) return false; InferenceFunction that = (InferenceFunction) o; return Objects.equals(options, that.options); } From 4408f90333e4f2d82224c18977cedbc8424b9191 Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 2 Jul 2025 13:57:15 +0200 Subject: [PATCH 09/12] Implemented DenseVectorEmbedding plan analysis. --- .../xpack/esql/analysis/Analyzer.java | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index f48c95397dcab..233da09879a9a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -97,6 +97,7 @@ import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; +import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig; import org.elasticsearch.xpack.esql.plan.logical.join.JoinType; @@ -138,6 +139,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; import static 
org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; @@ -516,6 +518,10 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) { return resolveEval(p, childrenOutput); } + if (plan instanceof DenseVectorEmbedding dve) { + return resolveDenseVectorEmbedding(dve, childrenOutput); + } + if (plan instanceof Enrich p) { return resolveEnrich(p, childrenOutput); } @@ -820,6 +826,34 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) { return changed ? new Fork(fork.source(), newSubPlans, newOutput) : fork; } + private LogicalPlan resolveDenseVectorEmbedding(DenseVectorEmbedding p, List childrenOutput) { + // Resolve the input expression + Expression input = p.input(); + if (input.resolved() == false) { + input = input.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput)); + } + + // Resolve the target field (similar to Completion) + Attribute targetField = p.embeddingField(); + if (targetField instanceof UnresolvedAttribute ua) { + targetField = new ReferenceAttribute(ua.source(), ua.name(), DENSE_VECTOR); + } + + // Create a new DenseVectorEmbedding with resolved expressions + // Only create a new instance if something changed to avoid unnecessary object creation + if (input != p.input() || targetField != p.embeddingField()) { + return new DenseVectorEmbedding( + p.source(), + p.child(), + p.inferenceId(), + input, + targetField + ); + } + + return p; + } + private LogicalPlan resolveRerank(Rerank rerank, List childrenOutput) { List newFields = new ArrayList<>(); boolean changed = false; From 7322c0c006f123ceb176a85583cba52abfea90df Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 2 Jul 2025 18:09:22 +0200 
Subject: [PATCH 10/12] Resolve dimensions of the embedding from the model config. --- .../xpack/esql/analysis/Analyzer.java | 10 +-- .../xpack/esql/inference/InferenceRunner.java | 2 +- .../esql/inference/ResolvedInference.java | 18 +---- .../plan/logical/inference/InferencePlan.java | 6 ++ .../embedding/DenseVectorEmbedding.java | 79 +++++++++++++++++-- .../embedding/DenseVectorEmbeddingExec.java | 27 +++++-- .../esql/analysis/AnalyzerTestUtils.java | 13 ++- .../esql/inference/InferenceRunnerTests.java | 30 ++++++- .../inference/ResolvedInferenceTests.java | 41 ---------- ...enseVectorEmbeddingSerializationTests.java | 24 ++++-- .../DenseVectorEmbeddingExecTests.java | 24 ++++-- 11 files changed, 183 insertions(+), 91 deletions(-) delete mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/ResolvedInferenceTests.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 233da09879a9a..31c7efe719c2a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -408,7 +408,7 @@ protected LogicalPlan rule(InferencePlan plan, AnalyzerContext context) { ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId); if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) { - return plan; + return plan.withModelConfigurations(resolvedInference.modelConfigurations()); } else if (resolvedInference != null) { String error = "cannot use inference endpoint [" + inferenceId @@ -842,13 +842,7 @@ private LogicalPlan resolveDenseVectorEmbedding(DenseVectorEmbedding p, List inferenceIds, ActionListener { - ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType()); + 
ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst()); inferenceResolutionBuilder.withResolvedInference(resolvedInference); countdownListener.onResponse(null); }, e -> { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/ResolvedInference.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/ResolvedInference.java index 455ed6488379a..0bf876d6ff8ed 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/ResolvedInference.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/ResolvedInference.java @@ -7,22 +7,12 @@ package org.elasticsearch.xpack.esql.inference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; -import java.io.IOException; +public record ResolvedInference(String inferenceId, ModelConfigurations modelConfigurations) { -public record ResolvedInference(String inferenceId, TaskType taskType) implements Writeable { - - public ResolvedInference(StreamInput in) throws IOException { - this(in.readString(), TaskType.valueOf(in.readString())); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(inferenceId); - out.writeString(taskType.name()); + public TaskType taskType() { + return modelConfigurations.getTaskType(); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java index 620e8726865d6..1ad945dffa4a8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.plan.logical.inference; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; @@ -69,4 +70,9 @@ public int hashCode() { public PlanType withInferenceResolutionError(String inferenceId, String error) { return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error)); } + + @SuppressWarnings("unchecked") + public PlanType withModelConfigurations(ModelConfigurations modelConfig) { + return (PlanType) this; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java index 11711be6b4784..c654bcff590be 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java @@ -10,13 +10,19 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.esql.capabilities.TelemetryAware; +import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable; import org.elasticsearch.xpack.esql.core.expression.Attribute; import 
org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; @@ -36,13 +42,26 @@ public class DenseVectorEmbedding extends InferencePlan im ); private final Expression input; + private final Expression dimensions; private final Attribute targetField; private List lazyOutput; public DenseVectorEmbedding(Source source, LogicalPlan child, Expression inferenceId, Expression input, Attribute targetField) { + this(source, child, inferenceId, new UnresolvedDimensions(inferenceId), input, targetField); + } + + DenseVectorEmbedding( + Source source, + LogicalPlan child, + Expression inferenceId, + Expression dimensions, + Expression input, + Attribute targetField + ) { super(source, child, inferenceId); this.input = input; this.targetField = targetField; + this.dimensions = dimensions; } public DenseVectorEmbedding(StreamInput in) throws IOException { @@ -51,6 +70,7 @@ public DenseVectorEmbedding(StreamInput in) throws IOException { in.readNamedWriteable(LogicalPlan.class), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), in.readNamedWriteable(Attribute.class) ); } @@ -60,6 +80,7 @@ public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); out.writeNamedWriteable(child()); out.writeNamedWriteable(inferenceId()); + out.writeNamedWriteable(dimensions); out.writeNamedWriteable(input); out.writeNamedWriteable(targetField); } @@ -77,6 +98,10 
@@ public TaskType taskType() { return TaskType.TEXT_EMBEDDING; } + public Expression dimensions() { + return dimensions; + } + @Override public String getWriteableName() { return ENTRY.name; @@ -98,7 +123,7 @@ public List generatedAttributes() { @Override public DenseVectorEmbedding withGeneratedNames(List newNames) { checkNumberOfNewNames(newNames); - return new DenseVectorEmbedding(source(), child(), inferenceId(), input, this.renameTargetField(newNames.get(0))); + return new DenseVectorEmbedding(source(), child(), inferenceId(), dimensions, input, this.renameTargetField(newNames.get(0))); } private Attribute renameTargetField(String newName) { @@ -111,22 +136,45 @@ private Attribute renameTargetField(String newName) { @Override public boolean expressionsResolved() { - return super.expressionsResolved() && input.resolved() && targetField.resolved(); + return super.expressionsResolved() && input.resolved() && targetField.resolved() && dimensions.resolved(); } @Override public DenseVectorEmbedding withInferenceId(Expression newInferenceId) { - return new DenseVectorEmbedding(source(), child(), newInferenceId, input, targetField); + return new DenseVectorEmbedding(source(), child(), newInferenceId, dimensions, input, targetField); + } + + public DenseVectorEmbedding withDimensions(Expression newDimensions) { + return new DenseVectorEmbedding(source(), child(), inferenceId(), newDimensions, input, targetField); + } + + public DenseVectorEmbedding withTargetField(Attribute targetField) { + return new DenseVectorEmbedding(source(), child(), inferenceId(), dimensions, input, targetField); + } + + @Override + public DenseVectorEmbedding withModelConfigurations(ModelConfigurations modelConfig) { + boolean hasChanged = false; + Expression newDimensions = dimensions; + + if (dimensions.resolved() == false + && modelConfig.getServiceSettings() != null + && modelConfig.getServiceSettings().dimensions() > 0) { + hasChanged = true; + newDimensions = new Literal(Source.EMPTY, 
modelConfig.getServiceSettings().dimensions(), DataType.INTEGER); + } + + return hasChanged ? withDimensions(newDimensions) : this; } @Override public DenseVectorEmbedding replaceChild(LogicalPlan newChild) { - return new DenseVectorEmbedding(source(), newChild, inferenceId(), input, targetField); + return new DenseVectorEmbedding(source(), newChild, inferenceId(), dimensions, input, targetField); } @Override protected NodeInfo info() { - return NodeInfo.create(this, DenseVectorEmbedding::new, child(), inferenceId(), input, targetField); + return NodeInfo.create(this, DenseVectorEmbedding::new, child(), inferenceId(), dimensions, input, targetField); } @Override @@ -135,11 +183,28 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; if (super.equals(o) == false) return false; DenseVectorEmbedding that = (DenseVectorEmbedding) o; - return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField); + return Objects.equals(input, that.input) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(targetField, that.targetField); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), input, targetField); + return Objects.hash(super.hashCode(), input, targetField, dimensions); + } + + private static class UnresolvedDimensions extends Literal implements Unresolvable { + + private final String inferenceId; + + private UnresolvedDimensions(Expression inferenceId) { + super(Source.EMPTY, null, DataType.NULL); + this.inferenceId = BytesRefs.toString(inferenceId.fold(FoldContext.small())); + } + + @Override + public String unresolvedMessage() { + return "Dimensions cannot be resolved for inference endpoint[" + inferenceId + "]"; + } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java index 9933fc0eae8ce..4eae30daeef8e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java @@ -35,12 +35,21 @@ public class DenseVectorEmbeddingExec extends InferenceExec { ); private final Expression input; + private final Expression dimensions; private final Attribute targetField; private List lazyOutput; - public DenseVectorEmbeddingExec(Source source, PhysicalPlan child, Expression inferenceId, Expression input, Attribute targetField) { + public DenseVectorEmbeddingExec( + Source source, + PhysicalPlan child, + Expression inferenceId, + Expression dimensions, + Expression input, + Attribute targetField + ) { super(source, child, inferenceId); this.input = input; + this.dimensions = dimensions; this.targetField = targetField; } @@ -50,6 +59,7 @@ public DenseVectorEmbeddingExec(StreamInput in) throws IOException { in.readNamedWriteable(PhysicalPlan.class), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), in.readNamedWriteable(Attribute.class) ); } @@ -70,18 +80,23 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + out.writeNamedWriteable(dimensions); out.writeNamedWriteable(input); out.writeNamedWriteable(targetField); } + public Expression dimensions() { + return dimensions; + } + @Override protected NodeInfo info() { - return NodeInfo.create(this, DenseVectorEmbeddingExec::new, child(), inferenceId(), input, targetField); + return NodeInfo.create(this, DenseVectorEmbeddingExec::new, child(), inferenceId(), dimensions, input, targetField); } @Override public UnaryExec 
replaceChild(PhysicalPlan newChild) { - return new DenseVectorEmbeddingExec(source(), newChild, inferenceId(), input, targetField); + return new DenseVectorEmbeddingExec(source(), newChild, inferenceId(), dimensions, input, targetField); } @Override @@ -103,11 +118,13 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; if (super.equals(o) == false) return false; DenseVectorEmbeddingExec that = (DenseVectorEmbeddingExec) o; - return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField); + return Objects.equals(input, that.input) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(targetField, that.targetField); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), input, targetField); + return Objects.hash(super.hashCode(), input, dimensions, targetField); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index 5e6c37545a396..fd5f607f53fb7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -39,6 +39,8 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public final class AnalyzerTestUtils { @@ -189,12 +191,19 @@ public static EnrichResolution defaultEnrichResolution() { public static InferenceResolution defaultInferenceResolution() { return InferenceResolution.builder() - .withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK)) - .withResolvedInference(new 
ResolvedInference("completion-inference-id", TaskType.COMPLETION)) + .withResolvedInference(mockedResolvedInference("reranking-inference-id", TaskType.RERANK)) + .withResolvedInference(mockedResolvedInference("completion-inference-id", TaskType.COMPLETION)) .withError("error-inference-id", "error with inference resolution") .build(); } + private static ResolvedInference mockedResolvedInference(String id, TaskType taskType) { + ResolvedInference resolvedInference = mock(ResolvedInference.class); + when(resolvedInference.inferenceId()).thenReturn(id); + when(resolvedInference.taskType()).thenReturn(taskType); + return resolvedInference; + } + public static void loadEnrichPolicyResolution( EnrichResolution enrich, String policyType, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java index ef7b3984bd532..7a5a510596878 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java @@ -26,11 +26,14 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; +import org.hamcrest.FeatureMatcher; +import org.hamcrest.Matcher; import org.junit.After; import org.junit.Before; import java.util.List; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -75,7 +78,10 @@ public void testResolveInferenceIds() throws Exception { assertBusy(() -> { InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get(); assertNotNull(inferenceResolution); - assertThat(inferenceResolution.resolvedInferences(), contains(new 
ResolvedInference("rerank-plan", TaskType.RERANK))); + assertThat( + inferenceResolution.resolvedInferences(), + contains(allOf(inferenceId(equalTo("rerank-plan")), taskType(equalTo(TaskType.RERANK)))) + ); assertThat(inferenceResolution.hasError(), equalTo(false)); }); } @@ -100,8 +106,8 @@ public void testResolveMultipleInferenceIds() throws Exception { assertThat( inferenceResolution.resolvedInferences(), contains( - new ResolvedInference("rerank-plan", TaskType.RERANK), - new ResolvedInference("completion-plan", TaskType.COMPLETION) + allOf(inferenceId(equalTo("rerank-plan")), taskType(equalTo(TaskType.RERANK))), + allOf(inferenceId(equalTo("completion-plan")), taskType(equalTo(TaskType.COMPLETION))) ) ); assertThat(inferenceResolution.hasError(), equalTo(false)); @@ -184,4 +190,22 @@ private static InferencePlan mockInferencePlan(String inferenceId) { when(plan.inferenceId()).thenReturn(Literal.keyword(Source.EMPTY, inferenceId)); return plan; } + + private FeatureMatcher inferenceId(Matcher matcher) { + return new FeatureMatcher<>(matcher, "inference id", "inferenceId") { + @Override + protected String featureValueOf(ResolvedInference resolution) { + return resolution.inferenceId(); + } + }; + } + + private FeatureMatcher taskType(Matcher matcher) { + return new FeatureMatcher<>(matcher, "task type", "taskType") { + @Override + protected TaskType featureValueOf(ResolvedInference resolution) { + return resolution.taskType(); + } + }; + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/ResolvedInferenceTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/ResolvedInferenceTests.java deleted file mode 100644 index b4dfd87224a3a..0000000000000 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/ResolvedInferenceTests.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.inference; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.AbstractWireTestCase; -import org.elasticsearch.test.ESTestCase; - -import java.io.IOException; - -public class ResolvedInferenceTests extends AbstractWireTestCase { - - @Override - protected ResolvedInference createTestInstance() { - return new ResolvedInference(randomIdentifier(), randomTaskType()); - } - - @Override - protected ResolvedInference mutateInstance(ResolvedInference instance) throws IOException { - if (randomBoolean()) { - return new ResolvedInference(randomValueOtherThan(instance.inferenceId(), ESTestCase::randomIdentifier), instance.taskType()); - } - - return new ResolvedInference(instance.inferenceId(), randomValueOtherThan(instance.taskType(), this::randomTaskType)); - } - - @Override - protected ResolvedInference copyInstance(ResolvedInference instance, TransportVersion version) throws IOException { - return copyInstance(instance, getNamedWriteableRegistry(), (out, v) -> v.writeTo(out), in -> new ResolvedInference(in), version); - } - - private TaskType randomTaskType() { - return randomFrom(TaskType.values()); - } -} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbeddingSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbeddingSerializationTests.java index 8378a851a61fa..f82a090d0801a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbeddingSerializationTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbeddingSerializationTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -22,23 +23,32 @@ public class DenseVectorEmbeddingSerializationTests extends AbstractLogicalPlanS @Override protected DenseVectorEmbedding createTestInstance() { - return new DenseVectorEmbedding(randomSource(), randomChild(0), randomInferenceId(), randomInput(), randomTargetField()); + return new DenseVectorEmbedding( + randomSource(), + randomChild(0), + randomInferenceId(), + randomDimensions(), + randomInput(), + randomTargetField() + ); } @Override protected DenseVectorEmbedding mutateInstance(DenseVectorEmbedding instance) throws IOException { LogicalPlan child = instance.child(); Expression inferenceId = instance.inferenceId(); + Expression dimensions = instance.dimensions(); Expression input = instance.input(); Attribute targetField = instance.embeddingField(); - switch (between(0, 3)) { + switch (between(0, 4)) { case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId); - case 2 -> input = randomValueOtherThan(input, this::randomInput); - case 3 -> targetField = randomValueOtherThan(targetField, this::randomTargetField); + case 2 -> dimensions = randomValueOtherThan(instance.dimensions(), this::randomDimensions); + case 3 -> input = randomValueOtherThan(input, this::randomInput); + case 4 -> targetField = 
randomValueOtherThan(targetField, this::randomTargetField); } - return new DenseVectorEmbedding(instance.source(), child, inferenceId, input, targetField); + return new DenseVectorEmbedding(instance.source(), child, inferenceId, dimensions, input, targetField); } private Literal randomInferenceId() { @@ -56,4 +66,8 @@ private Attribute randomTargetField() { private Attribute randomAttribute() { return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean()); } + + private Expression randomDimensions() { + return new Literal(EMPTY, randomInt(), DataType.INTEGER); + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExecTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExecTests.java index 00947d495d43d..c079860869920 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExecTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExecTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; @@ -22,23 +23,32 @@ public class DenseVectorEmbeddingExecTests extends AbstractPhysicalPlanSerializa @Override protected DenseVectorEmbeddingExec createTestInstance() { - return new DenseVectorEmbeddingExec(randomSource(), randomChild(0), randomInferenceId(), randomInput(), randomTargetField()); + return new 
DenseVectorEmbeddingExec( + randomSource(), + randomChild(0), + randomInferenceId(), + randomDimensions(), + randomInput(), + randomTargetField() + ); } @Override protected DenseVectorEmbeddingExec mutateInstance(DenseVectorEmbeddingExec instance) throws IOException { PhysicalPlan child = instance.child(); Expression inferenceId = instance.inferenceId(); + Expression dimensions = instance.dimensions(); Expression input = instance.input(); Attribute targetField = instance.targetField(); - switch (between(0, 3)) { + switch (between(0, 4)) { case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId); - case 2 -> input = randomValueOtherThan(input, this::randomInput); - case 3 -> targetField = randomValueOtherThan(targetField, this::randomTargetField); + case 2 -> dimensions = randomValueOtherThan(dimensions, this::randomDimensions); + case 3 -> input = randomValueOtherThan(input, this::randomInput); + case 4 -> targetField = randomValueOtherThan(targetField, this::randomTargetField); } - return new DenseVectorEmbeddingExec(instance.source(), child, inferenceId, input, targetField); + return new DenseVectorEmbeddingExec(instance.source(), child, inferenceId, dimensions, input, targetField); } private Literal randomInferenceId() { @@ -56,4 +66,8 @@ private Attribute randomTargetField() { private Attribute randomAttribute() { return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean()); } + + private Expression randomDimensions() { + return new Literal(EMPTY, randomInt(), DataType.INTEGER); + } } From 03c6110db4eed52452c665d67930d065231dc3f6 Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 2 Jul 2025 18:22:27 +0200 Subject: [PATCH 11/12] Physical plan planning. 
--- .../embedding/DenseEmbeddingOperator.java | 6 ++-- .../esql/planner/LocalExecutionPlanner.java | 29 +++++++++++++++++++ .../esql/planner/mapper/MapperUtils.java | 13 +++++++++ .../DenseEmbeddingOperatorTests.java | 2 +- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java index 5b006a07855fe..d3c1e471891a1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java @@ -43,8 +43,8 @@ public DenseEmbeddingOperator( DriverContext driverContext, InferenceRunner inferenceRunner, ThreadPool threadPool, - int dimensions, String inferenceId, + int dimensions, ExpressionEvaluator inputEvaluator ) { super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId); @@ -100,8 +100,8 @@ protected DenseEmbeddingOperatorOutputBuilder outputBuilder(Page input) { */ public record Factory( InferenceRunner inferenceRunner, - int dimensions, String inferenceId, + int dimensions, ExpressionEvaluator.Factory inputEvaluatorFactory ) implements OperatorFactory { @@ -116,8 +116,8 @@ public Operator get(DriverContext driverContext) { driverContext, inferenceRunner, inferenceRunner.threadPool(), - dimensions, inferenceId, + dimensions, inputEvaluatorFactory().get(driverContext) ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index a92d2f439a0ea..333eab10b217a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -88,6 +88,7 @@ import org.elasticsearch.xpack.esql.inference.InferenceRunner; import org.elasticsearch.xpack.esql.inference.XContentRowEncoder; import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator; +import org.elasticsearch.xpack.esql.inference.embedding.DenseEmbeddingOperator; import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator; import org.elasticsearch.xpack.esql.plan.logical.Fork; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; @@ -119,6 +120,7 @@ import org.elasticsearch.xpack.esql.plan.physical.TopNExec; import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.embedding.DenseVectorEmbeddingExec; import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.score.ScoreMapper; @@ -266,6 +268,8 @@ private PhysicalOperation plan(PhysicalPlan node, LocalExecutionPlannerContext c return planChangePoint(changePoint, context); } else if (node instanceof CompletionExec completion) { return planCompletion(completion, context); + } else if (node instanceof DenseVectorEmbeddingExec embedding) { + return planDenseVectorEmbedding(embedding, context); } else if (node instanceof SampleExec Sample) { return planSample(Sample, context); } @@ -319,6 +323,31 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti return source.with(new CompletionOperator.Factory(inferenceRunner, inferenceId, promptEvaluatorFactory), outputLayout); } + private PhysicalOperation planDenseVectorEmbedding(DenseVectorEmbeddingExec embedding, LocalExecutionPlannerContext context) { + PhysicalOperation source = plan(embedding.child(), context); + String 
inferenceId = BytesRefs.toString(embedding.inferenceId().fold(context.foldCtx())); + + int dimensions; + if (embedding.dimensions() instanceof Literal literal) { + Object val = literal.value() instanceof BytesRef br ? BytesRefs.toString(br) : literal.value(); + dimensions = stringToInt(val.toString()); + } else { + throw new EsqlIllegalArgumentException("dimensions only supported with literal values"); + } + + Layout outputLayout = source.layout.builder().append(embedding.targetField()).build(); + EvalOperator.ExpressionEvaluator.Factory inputEvaluatorFactory = EvalMapper.toEvaluator( + context.foldCtx(), + embedding.input(), + source.layout + ); + + return source.with( + new DenseEmbeddingOperator.Factory(inferenceRunner, inferenceId, dimensions, inputEvaluatorFactory), + outputLayout + ); + } + private PhysicalOperation planRrfScoreEvalExec(RrfScoreEvalExec rrf, LocalExecutionPlannerContext context) { PhysicalOperation source = plan(rrf.child(), context); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java index 4851de1616844..884f50b71a549 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; +import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; @@ -47,6 +48,7 @@ import 
org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec; import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.embedding.DenseVectorEmbeddingExec; import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders; import java.util.List; @@ -106,6 +108,17 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) { return new CompletionExec(completion.source(), child, completion.inferenceId(), completion.prompt(), completion.targetField()); } + if (p instanceof DenseVectorEmbedding embedding) { + return new DenseVectorEmbeddingExec( + embedding.source(), + child, + embedding.inferenceId(), + embedding.dimensions(), + embedding.input(), + embedding.embeddingField() + ); + } + if (p instanceof Enrich enrich) { return new EnrichExec( enrich.source(), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java index 2326be1d42c8b..d0eafc4154835 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorTests.java @@ -80,7 +80,7 @@ public DenseEmbeddingOperatorTests( @Override protected Operator.OperatorFactory simple(SimpleOptions options) { - return new DenseEmbeddingOperator.Factory(mockedSimpleInferenceRunner(), dimensions, SIMPLE_INFERENCE_ID, evaluatorFactory(0)); + return new DenseEmbeddingOperator.Factory(mockedSimpleInferenceRunner(), SIMPLE_INFERENCE_ID, dimensions, evaluatorFactory(0)); } @Override From b80dc31a6fe1f36f89149f5fae66759978776c29 Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 2 Jul 2025 19:13:29 +0200 
Subject: [PATCH 12/12] Modify PreAnalyzer so it will be easier to implement inference function pre-analysis. --- .../xpack/esql/analysis/PreAnalyzer.java | 34 +++++++++++++++---- .../xpack/esql/inference/InferenceRunner.java | 9 +---- .../xpack/esql/session/EsqlSession.java | 11 ++---- .../esql/analysis/AnalyzerTestUtils.java | 10 +++--- .../esql/inference/InferenceRunnerTests.java | 17 ++++------ 5 files changed, 43 insertions(+), 38 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java index 5b9f41876d6e1..f50175be653a5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.esql.analysis; +import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.index.IndexMode; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.plan.IndexPattern; import org.elasticsearch.xpack.esql.plan.logical.Enrich; @@ -21,6 +23,7 @@ import java.util.Set; import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; /** * This class is part of the planner. Acts somewhat like a linker, to find the indices and enrich policies referenced by the query. 
@@ -28,25 +31,25 @@ public class PreAnalyzer { public static class PreAnalysis { - public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList(), emptyList()); + public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptySet(), emptyList()); public final IndexMode indexMode; public final List indices; public final List enriches; - public final List> inferencePlans; + public final Set inferenceIds; public final List lookupIndices; public PreAnalysis( IndexMode indexMode, List indices, List enriches, - List> inferencePlans, + Set inferenceIds, List lookupIndices ) { this.indexMode = indexMode; this.indices = indices; this.enriches = enriches; - this.inferencePlans = inferencePlans; + this.inferenceIds = inferenceIds; this.lookupIndices = lookupIndices; } } @@ -64,7 +67,7 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) { List unresolvedEnriches = new ArrayList<>(); List lookupIndices = new ArrayList<>(); - List> unresolvedInferencePlans = new ArrayList<>(); + Set unresolvedInferenceIds = new HashSet<>(); Holder indexMode = new Holder<>(); plan.forEachUp(UnresolvedRelation.class, p -> { if (p.indexMode() == IndexMode.LOOKUP) { @@ -78,11 +81,28 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) { }); plan.forEachUp(Enrich.class, unresolvedEnriches::add); - plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add); // mark plan as preAnalyzed (if it were marked, there would be no analysis) plan.forEachUp(LogicalPlan::setPreAnalyzed); - return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices); + return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, inferenceIds(plan), lookupIndices); + } + + protected Set inferenceIds(LogicalPlan plan) { + Set inferenceIds = new HashSet<>(); + + List> inferencePlans = new ArrayList<>(); + plan.forEachUp(InferencePlan.class, 
inferencePlans::add); + inferencePlans.stream().map(this::inferenceId).forEach(inferenceIds::add); + + return inferenceIds; + } + + private String inferenceId(InferencePlan inferencePlan) { + if (inferencePlan.inferenceId() instanceof Literal literal) { + return BytesRefs.toString(literal.value()); + } + + throw new IllegalStateException("inferenceId is not a literal"); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java index 3b7f08dbea761..c1696257edd00 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java @@ -18,9 +18,7 @@ import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; -import java.util.List; import java.util.Set; -import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -39,12 +37,7 @@ public ThreadPool threadPool() { return threadPool; } - public void resolveInferenceIds(List> plans, ActionListener listener) { - resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener); - - } - - private void resolveInferenceIds(Set inferenceIds, ActionListener listener) { + public void resolveInferenceIds(Set inferenceIds, ActionListener listener) { if (inferenceIds.isEmpty()) { listener.onResponse(InferenceResolution.EMPTY); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index 4ff65f59bbd72..920201897199b 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -85,7 +85,6 @@ import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation; import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; -import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin; import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes; import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin; @@ -372,7 +371,7 @@ public void analyzedPlan( l -> enrichPolicyResolver.resolvePolicies(unresolvedPolicies, executionInfo, l) ) .andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l)) - .andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferencePlans, preAnalysisResult, l)); + .andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferenceIds, preAnalysisResult, l)); // first resolve the lookup indices, then the main indices for (var index : preAnalysis.lookupIndices) { listener = listener.andThen((l, preAnalysisResult) -> { preAnalyzeLookupIndex(index, preAnalysisResult, l); }); @@ -588,12 +587,8 @@ private static void resolveFieldNames(LogicalPlan parsed, EnrichResolution enric } } - private void resolveInferences( - List> inferencePlans, - PreAnalysisResult preAnalysisResult, - ActionListener l - ) { - inferenceRunner.resolveInferenceIds(inferencePlans, l.map(preAnalysisResult::withInferenceResolution)); + private void resolveInferences(Set inferenceIds, PreAnalysisResult preAnalysisResult, ActionListener l) { + inferenceRunner.resolveInferenceIds(inferenceIds, l.map(preAnalysisResult::withInferenceResolution)); } static PreAnalysisResult fieldNames(LogicalPlan parsed, Set enrichPolicyMatchFields, PreAnalysisResult result) { diff --git 
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index fd5f607f53fb7..a016efa137956 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.analysis; import org.elasticsearch.index.IndexMode; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.enrich.EnrichPolicy; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -197,11 +198,10 @@ public static InferenceResolution defaultInferenceResolution() { .build(); } - private static ResolvedInference mockedResolvedInference(String id, TaskType taskType) { - ResolvedInference resolvedInference = mock(ResolvedInference.class); - when(resolvedInference.inferenceId()).thenReturn(id); - when(resolvedInference.taskType()).thenReturn(taskType); - return resolvedInference; + private static ResolvedInference mockedResolvedInference(String inferenceId, TaskType taskType) { + ModelConfigurations modelConfigurations = mock(ModelConfigurations.class); + when(modelConfigurations.getTaskType()).thenReturn(taskType); + return new ResolvedInference(inferenceId, modelConfigurations); } public static void loadEnrichPolicyResolution( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java index 7a5a510596878..b63a4e11e60c0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java @@ -32,6 +32,7 @@ import 
org.junit.Before; import java.util.List; +import java.util.Set; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.contains; @@ -68,10 +69,10 @@ public void shutdownThreadPool() { public void testResolveInferenceIds() throws Exception { InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); - List> inferencePlans = List.of(mockInferencePlan("rerank-plan")); + Set inferenceIds = Set.of("rerank-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + inferenceRunner.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { throw new RuntimeException(e); })); @@ -88,14 +89,10 @@ public void testResolveInferenceIds() throws Exception { public void testResolveMultipleInferenceIds() throws Exception { InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); - List> inferencePlans = List.of( - mockInferencePlan("rerank-plan"), - mockInferencePlan("rerank-plan"), - mockInferencePlan("completion-plan") - ); + Set inferenceIds = Set.of("rerank-plan", "completion-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + inferenceRunner.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { throw new RuntimeException(e); })); @@ -116,11 +113,11 @@ public void testResolveMultipleInferenceIds() throws Exception { public void testResolveMissingInferenceIds() throws Exception { InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); - List> inferencePlans = List.of(mockInferencePlan("missing-plan")); + Set inferenceIds = Set.of("missing-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceRunner.resolveInferenceIds(inferencePlans, 
ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + inferenceRunner.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { throw new RuntimeException(e); }));