diff --git a/muted-tests.yml b/muted-tests.yml index 7cb1b19316e89..975098af76fc1 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -300,9 +300,6 @@ tests: - class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT method: testSearchWithRandomDisconnects issue: https://github.com/elastic/elasticsearch/issues/122707 -- class: org.elasticsearch.xpack.esql.inference.RerankOperatorTests - method: testSimpleCircuitBreaking - issue: https://github.com/elastic/elasticsearch/issues/124337 - class: org.elasticsearch.index.engine.ThreadPoolMergeSchedulerTests method: testSchedulerCloseWaitsForRunningMerge issue: https://github.com/elastic/elasticsearch/issues/125236 @@ -384,9 +381,6 @@ tests: - class: org.elasticsearch.packaging.test.DockerTests method: test024InstallPluginFromArchiveUsingConfigFile issue: https://github.com/elastic/elasticsearch/issues/126936 -- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT - method: test {rerank.Reranker before a limit ASYNC} - issue: https://github.com/elastic/elasticsearch/issues/127051 - class: org.elasticsearch.packaging.test.DockerTests method: test026InstallBundledRepositoryPlugins issue: https://github.com/elastic/elasticsearch/issues/127081 @@ -399,9 +393,6 @@ tests: - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=ml/data_frame_analytics_cat_apis/Test cat data frame analytics all jobs with header} issue: https://github.com/elastic/elasticsearch/issues/127625 -- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT - method: test {rerank.Reranker using another sort order ASYNC} - issue: https://github.com/elastic/elasticsearch/issues/127638 - class: org.elasticsearch.xpack.search.CrossClusterAsyncSearchIT method: testCancellationViaTimeoutWithAllowPartialResultsSetToFalse issue: https://github.com/elastic/elasticsearch/issues/127096 diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java 
b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java index f2fe394b2deac..69df40899a0a8 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java @@ -67,14 +67,11 @@ import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled; import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues; import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.availableDatasetsForEs; -import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasInferenceEndpoint; -import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasRerankInferenceEndpoint; -import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoint; -import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint; -import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoint; -import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint; +import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoints; +import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoints; import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs; import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources; +import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.COMPLETION; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_COMMAND; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS; @@ -138,12 +135,8 @@ protected EsqlSpecTestCase( @Before public void setup() throws IOException { - if (supportsInferenceTestService() && 
clusterHasInferenceEndpoint(client()) == false) { - createInferenceEndpoint(client()); - } - - if (supportsInferenceTestService() && clusterHasRerankInferenceEndpoint(client()) == false) { - createRerankInferenceEndpoint(client()); + if (supportsInferenceTestService()) { + createInferenceEndpoints(adminClient()); } boolean supportsLookup = supportsIndexModeLookup(); @@ -164,8 +157,8 @@ public static void wipeTestData() throws IOException { } } - deleteInferenceEndpoint(client()); - deleteRerankInferenceEndpoint(client()); + deleteInferenceEndpoints(adminClient()); + } public boolean logResults() { @@ -254,7 +247,7 @@ protected boolean supportsInferenceTestService() { } protected boolean requiresInferenceEndpoint() { - return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName()) + return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName(), COMPLETION.capabilityName()) .anyMatch(testCase.requiredCapabilities::contains); } @@ -372,6 +365,11 @@ private Object valueMapper(CsvTestUtils.Type type, Object value) { return new BigDecimal(s).round(new MathContext(7, RoundingMode.DOWN)).doubleValue(); } } + if (type == CsvTestUtils.Type.TEXT || type == CsvTestUtils.Type.KEYWORD || type == CsvTestUtils.Type.SEMANTIC_TEXT) { + if (value instanceof String s) { + value = s.replaceAll("\\\\n", "\n"); + } + } return value.toString(); } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java index 92fe597362bb0..15b33ff4d1b73 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import 
org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.test.rest.ESRestTestCase; @@ -317,7 +318,7 @@ public static Set availableDatasetsForEs( boolean supportsIndexModeLookup, boolean supportsSourceFieldMapping ) throws IOException { - boolean inferenceEnabled = clusterHasInferenceEndpoint(client); + boolean inferenceEnabled = clusterHasSparseEmbeddingInferenceEndpoint(client); Set testDataSets = new HashSet<>(); @@ -379,77 +380,90 @@ private static void loadDataSetIntoEs( } } + public static void createInferenceEndpoints(RestClient client) throws IOException { + if (clusterHasSparseEmbeddingInferenceEndpoint(client) == false) { + createSparseEmbeddingInferenceEndpoint(client); + } + + if (clusterHasRerankInferenceEndpoint(client) == false) { + createRerankInferenceEndpoint(client); + } + + if (clusterHasCompletionInferenceEndpoint(client) == false) { + createCompletionInferenceEndpoint(client); + } + } + + public static void deleteInferenceEndpoints(RestClient client) throws IOException { + deleteSparseEmbeddingInferenceEndpoint(client); + deleteRerankInferenceEndpoint(client); + deleteCompletionInferenceEndpoint(client); + } + /** The semantic_text mapping type require an inference endpoint that needs to be setup before creating the index. 
*/ - public static void createInferenceEndpoint(RestClient client) throws IOException { - Request request = new Request("PUT", "_inference/sparse_embedding/test_sparse_inference"); - request.setJsonEntity(""" + public static void createSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException { + createInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference", """ { "service": "test_service", - "service_settings": { - "model": "my_model", - "api_key": "abc64" - }, - "task_settings": { - } + "service_settings": { "model": "my_model", "api_key": "abc64" }, + "task_settings": { } } """); - client.performRequest(request); } - public static void deleteInferenceEndpoint(RestClient client) throws IOException { - try { - client.performRequest(new Request("DELETE", "_inference/test_sparse_inference")); - } catch (ResponseException e) { - // 404 here means the endpoint was not created - if (e.getResponse().getStatusLine().getStatusCode() != 404) { - throw e; - } - } + public static void deleteSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException { + deleteInferenceEndpoint(client, "test_sparse_inference"); } - public static boolean clusterHasInferenceEndpoint(RestClient client) throws IOException { - Request request = new Request("GET", "_inference/sparse_embedding/test_sparse_inference"); - try { - client.performRequest(request); - } catch (ResponseException e) { - if (e.getResponse().getStatusLine().getStatusCode() == 404) { - return false; - } - throw e; - } - return true; + public static boolean clusterHasSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException { + return clusterHasInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference"); } public static void createRerankInferenceEndpoint(RestClient client) throws IOException { - Request request = new Request("PUT", "_inference/rerank/test_reranker"); - request.setJsonEntity(""" + createInferenceEndpoint(client, TaskType.RERANK, 
"test_reranker", """ { "service": "test_reranking_service", - "service_settings": { - "model_id": "my_model", - "api_key": "abc64" - }, - "task_settings": { - "use_text_length": true - } + "service_settings": { "model_id": "my_model", "api_key": "abc64" }, + "task_settings": { "use_text_length": true } } """); - client.performRequest(request); } public static void deleteRerankInferenceEndpoint(RestClient client) throws IOException { - try { - client.performRequest(new Request("DELETE", "_inference/rerank/test_reranker")); - } catch (ResponseException e) { - // 404 here means the endpoint was not created - if (e.getResponse().getStatusLine().getStatusCode() != 404) { - throw e; - } - } + deleteInferenceEndpoint(client, "test_reranker"); } public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throws IOException { - Request request = new Request("GET", "_inference/rerank/test_reranker"); + return clusterHasInferenceEndpoint(client, TaskType.RERANK, "test_reranker"); + } + + public static void createCompletionInferenceEndpoint(RestClient client) throws IOException { + createInferenceEndpoint(client, TaskType.COMPLETION, "test_completion", """ + { + "service": "completion_test_service", + "service_settings": { "model": "my_model", "api_key": "abc64" }, + "task_settings": { "temperature": 3 } + } + """); + } + + public static void deleteCompletionInferenceEndpoint(RestClient client) throws IOException { + deleteInferenceEndpoint(client, "test_completion"); + } + + public static boolean clusterHasCompletionInferenceEndpoint(RestClient client) throws IOException { + return clusterHasInferenceEndpoint(client, TaskType.COMPLETION, "test_completion"); + } + + private static void createInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId, String modelSettings) + throws IOException { + Request request = new Request("PUT", "_inference/" + taskType.name() + "/" + inferenceId); + request.setJsonEntity(modelSettings); + 
client.performRequest(request); + } + + private static boolean clusterHasInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId) throws IOException { + Request request = new Request("GET", "_inference/" + taskType.name() + "/" + inferenceId); try { client.performRequest(request); } catch (ResponseException e) { @@ -461,6 +475,17 @@ public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throw return true; } + private static void deleteInferenceEndpoint(RestClient client, String inferenceId) throws IOException { + try { + client.performRequest(new Request("DELETE", "_inference/" + inferenceId)); + } catch (ResponseException e) { + // 404 here means the endpoint was not created + if (e.getResponse().getStatusLine().getStatusCode() != 404) { + throw e; + } + } + } + private static void loadEnrichPolicy(RestClient client, String policyName, String policyFileName, Logger logger) throws IOException { URL policyMapping = getResource("/" + policyFileName); String entity = readTextFile(policyMapping); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/completion.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/completion.csv-spec new file mode 100644 index 0000000000000..bbb0278f2b021 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/completion.csv-spec @@ -0,0 +1,61 @@ +// Note: +// The "test_completion" service returns the prompt in uppercase, making the output easy to guess. + + +completion using a ROW source operator +required_capability: completion + +ROW prompt="Who is Victor Hugo?" +| COMPLETION prompt WITH test_completion AS completion_output +; + +prompt:keyword | completion_output:keyword +Who is Victor Hugo? | WHO IS VICTOR HUGO? 
+; + + +completion using a ROW source operator and prompt is a multi-valued field +required_capability: completion + +ROW prompt=["Answer the following question:", "Who is Victor Hugo?"] +| COMPLETION prompt WITH test_completion AS completion_output +; + +prompt:keyword | completion_output:keyword +[Answer the following question:, Who is Victor Hugo?] | ANSWER THE FOLLOWING QUESTION:\nWHO IS VICTOR HUGO? +; + + +completion after a search +required_capability: completion +required_capability: match_operator_colon + +FROM books METADATA _score +| WHERE title:"war and peace" AND author:"Tolstoy" +| SORT _score DESC +| LIMIT 2 +| COMPLETION title WITH test_completion +| KEEP title, completion +; + +title:text | completion:keyword +War and Peace | WAR AND PEACE +War and Peace (Signet Classics) | WAR AND PEACE (SIGNET CLASSICS) +; + +completion using a function as a prompt +required_capability: completion +required_capability: match_operator_colon + +FROM books METADATA _score +| WHERE title:"war and peace" AND author:"Tolstoy" +| SORT _score DESC +| LIMIT 2 +| COMPLETION CONCAT("This is a prompt: ", title) WITH test_completion +| KEEP title, completion +; + +title:text | completion:keyword +War and Peace | THIS IS A PROMPT: WAR AND PEACE +War and Peace (Signet Classics) | THIS IS A PROMPT: WAR AND PEACE (SIGNET CLASSICS) +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index a265e86adc943..08da4ba032c4c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -617,7 +617,7 @@ private LogicalPlan resolveCompletion(Completion p, List childrenOutp Expression prompt = p.prompt(); if (targetField instanceof UnresolvedAttribute ua) { - targetField = new ReferenceAttribute(ua.source(), ua.name(), TEXT); + 
targetField = new ReferenceAttribute(ua.source(), ua.name(), KEYWORD); } if (prompt.resolved() == false) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java new file mode 100644 index 0000000000000..fe6ab6e9a998c --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java @@ -0,0 +1,186 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AsyncOperator; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; + +import java.util.List; + +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; + +/** + * An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceRunner}. + *

+ * The {@code InferenceOperator} integrates with the compute framework and supports throttled bulk execution of inference requests. It + * transforms an input {@link Page} into inference requests, asynchronously executes them, and converts the responses into a new {@link Page}. + *

+ */ +public abstract class InferenceOperator extends AsyncOperator { + private final String inferenceId; + private final BlockFactory blockFactory; + private final BulkInferenceExecutor bulkInferenceExecutor; + + /** + * Constructs a new {@code InferenceOperator}. + * + * @param driverContext The driver context. + * @param inferenceRunner The runner used to execute inference requests. + * @param bulkExecutionConfig Configuration for inference execution. + * @param threadPool The thread pool used for executing async inference. + * @param inferenceId The ID of the inference model to use. + */ + public InferenceOperator( + DriverContext driverContext, + InferenceRunner inferenceRunner, + BulkInferenceExecutionConfig bulkExecutionConfig, + ThreadPool threadPool, + String inferenceId + ) { + super(driverContext, inferenceRunner.threadPool().getThreadContext(), bulkExecutionConfig.workers()); + this.blockFactory = driverContext.blockFactory(); + this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig); + this.inferenceId = inferenceId; + } + + /** + * Returns the {@link BlockFactory} used to create output data blocks. + */ + protected BlockFactory blockFactory() { + return blockFactory; + } + + /** + * Returns the inference model ID used for this operator. + */ + protected String inferenceId() { + return inferenceId; + } + + /** + * Initiates asynchronous inferences for the given input page. + */ + @Override + protected void performAsync(Page input, ActionListener listener) { + try { + BulkInferenceRequestIterator requests = requests(input); + listener = ActionListener.releaseBefore(requests, listener); + bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInferenceResult(input, responses))); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Releases resources associated with an ongoing inference. 
+ */ + @Override + protected void releaseFetchedOnAnyThread(OngoingInferenceResult ongoingInferenceResult) { + Releasables.close(ongoingInferenceResult); + } + + /** + * Returns the next available output page constructed from completed inference results. + */ + @Override + public Page getOutput() { + OngoingInferenceResult ongoingInferenceResult = fetchFromBuffer(); + if (ongoingInferenceResult == null) { + return null; + } + + try (OutputBuilder outputBuilder = outputBuilder(ongoingInferenceResult.inputPage)) { + for (InferenceAction.Response response : ongoingInferenceResult.responses) { + outputBuilder.addInferenceResponse(response); + } + return outputBuilder.buildOutput(); + + } finally { + releaseFetchedOnAnyThread(ongoingInferenceResult); + } + } + + /** + * Converts the given input page into a sequence of inference requests. + * + * @param input The input page to process. + */ + protected abstract BulkInferenceRequestIterator requests(Page input); + + /** + * Creates a new {@link OutputBuilder} instance used to build the output page. + * + * @param input The corresponding input page used to generate the inference requests. + */ + protected abstract OutputBuilder outputBuilder(Page input); + + /** + * An interface for accumulating inference responses and constructing a result {@link Page}. + */ + public interface OutputBuilder extends Releasable { + + /** + * Adds an inference response to the output. + *

+ * The responses must be added in the same order as the corresponding inference requests were generated. + * Failing to preserve order may lead to incorrect or misaligned output rows. + *

+ * + * @param inferenceResponse The inference response to include. + */ + void addInferenceResponse(InferenceAction.Response inferenceResponse); + + /** + * Builds the final output page from accumulated inference responses. + * + * @return The constructed output page. + */ + Page buildOutput(); + + static IR inferenceResults(InferenceAction.Response inferenceResponse, Class clazz) { + InferenceServiceResults results = inferenceResponse.getResults(); + if (clazz.isInstance(results)) { + return clazz.cast(results); + } + + throw new IllegalStateException( + format("Inference result has wrong type. Got [{}] while expecting [{}]", results.getClass().getName(), clazz.getName()) + ); + } + + default void releasePageOnAnyThread(Page page) { + InferenceOperator.releasePageOnAnyThread(page); + } + } + + /** + * Represents the result of an ongoing inference operation, including the original input page + * and the list of inference responses. + * + * @param inputPage The input page used to generate inference requests. + * @param responses The inference responses returned by the inference service. 
+ */ + public record OngoingInferenceResult(Page inputPage, List responses) implements Releasable { + + @Override + public void close() { + releasePageOnAnyThread(inputPage); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java index 8ee930e5560df..6d6f52a17d428 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java @@ -10,8 +10,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.esql.core.expression.FoldContext; @@ -27,13 +27,15 @@ public class InferenceRunner { private final Client client; + private final ThreadPool threadPool; - public InferenceRunner(Client client) { + public InferenceRunner(Client client, ThreadPool threadPool) { this.client = client; + this.threadPool = threadPool; } - public ThreadContext getThreadContext() { - return client.threadPool().getThreadContext(); + public ThreadPool threadPool() { + return threadPool; } public void resolveInferenceIds(List> plans, ActionListener listener) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java deleted file mode 100644 index 0b5d0384c56c1..0000000000000 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.inference; - -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.lucene.BytesRefs; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.AsyncOperator; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; -import org.elasticsearch.compute.operator.Operator; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; - -import java.util.List; - -public class RerankOperator extends AsyncOperator { - - // Move to a setting. 
- private static final int MAX_INFERENCE_WORKER = 10; - - public record Factory( - InferenceRunner inferenceRunner, - String inferenceId, - String queryText, - ExpressionEvaluator.Factory rowEncoderFactory, - int scoreChannel - ) implements OperatorFactory { - - @Override - public String describe() { - return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]"; - } - - @Override - public Operator get(DriverContext driverContext) { - return new RerankOperator( - driverContext, - inferenceRunner, - inferenceId, - queryText, - rowEncoderFactory().get(driverContext), - scoreChannel - ); - } - } - - private final InferenceRunner inferenceRunner; - private final BlockFactory blockFactory; - private final String inferenceId; - private final String queryText; - private final ExpressionEvaluator rowEncoder; - private final int scoreChannel; - - public RerankOperator( - DriverContext driverContext, - InferenceRunner inferenceRunner, - String inferenceId, - String queryText, - ExpressionEvaluator rowEncoder, - int scoreChannel - ) { - super(driverContext, inferenceRunner.getThreadContext(), MAX_INFERENCE_WORKER); - - assert inferenceRunner.getThreadContext() != null; - - this.blockFactory = driverContext.blockFactory(); - this.inferenceRunner = inferenceRunner; - this.inferenceId = inferenceId; - this.queryText = queryText; - this.rowEncoder = rowEncoder; - this.scoreChannel = scoreChannel; - } - - @Override - protected void performAsync(Page inputPage, ActionListener listener) { - // Ensure input page blocks are released when the listener is called. 
- listener = listener.delegateResponse((l, e) -> { - releasePageOnAnyThread(inputPage); - l.onFailure(e); - }); - try { - inferenceRunner.doInference(buildInferenceRequest(inputPage), listener.map(resp -> new OngoingRerank(inputPage, resp))); - } catch (Exception e) { - listener.onFailure(e); - } - } - - @Override - protected void doClose() { - Releasables.closeExpectNoException(rowEncoder); - } - - @Override - protected void releaseFetchedOnAnyThread(OngoingRerank result) { - releasePageOnAnyThread(result.inputPage); - } - - @Override - public Page getOutput() { - var fetched = fetchFromBuffer(); - if (fetched == null) { - return null; - } else { - return fetched.buildOutput(blockFactory, scoreChannel); - } - } - - @Override - public String toString() { - return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]"; - } - - private InferenceAction.Request buildInferenceRequest(Page inputPage) { - try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) { - assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount()); - String[] inputs = new String[inputPage.getPositionCount()]; - BytesRef buffer = new BytesRef(); - - for (int pos = 0; pos < inputPage.getPositionCount(); pos++) { - if (encodedRowsBlock.isNull(pos)) { - inputs[pos] = ""; - } else { - buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer); - inputs[pos] = BytesRefs.toString(buffer); - } - } - - return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build(); - } - } - - public static final class OngoingRerank { - final Page inputPage; - final Double[] rankedScores; - - OngoingRerank(Page inputPage, InferenceAction.Response resp) { - if (resp.getResults() instanceof RankedDocsResults == false) { - releasePageOnAnyThread(inputPage); - throw new IllegalStateException( - "Inference result has wrong type. 
Got [" - + resp.getResults().getClass() - + "] while expecting [" - + RankedDocsResults.class - + "]" - ); - - } - final var results = (RankedDocsResults) resp.getResults(); - this.inputPage = inputPage; - this.rankedScores = extractRankedScores(inputPage.getPositionCount(), results); - } - - private static Double[] extractRankedScores(int positionCount, RankedDocsResults rankedDocsResults) { - Double[] sortedRankedDocsScores = new Double[positionCount]; - for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) { - sortedRankedDocsScores[rankedDoc.index()] = (double) rankedDoc.relevanceScore(); - } - return sortedRankedDocsScores; - } - - Page buildOutput(BlockFactory blockFactory, int scoreChannel) { - int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1); - Block[] blocks = new Block[blockCount]; - Page outputPage = null; - try (Releasable ignored = inputPage::releaseBlocks) { - for (int b = 0; b < blockCount; b++) { - if (b == scoreChannel) { - blocks[b] = buildScoreBlock(blockFactory); - } else { - blocks[b] = inputPage.getBlock(b); - blocks[b].incRef(); - } - } - outputPage = new Page(blocks); - return outputPage; - } finally { - if (outputPage == null) { - Releasables.closeExpectNoException(blocks); - } - } - } - - private Block buildScoreBlock(BlockFactory blockFactory) { - try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(rankedScores.length)) { - for (Double rankedScore : rankedScores) { - if (rankedScore != null) { - scoreBlockFactory.appendDouble(rankedScore); - } else { - scoreBlockFactory.appendNull(); - } - } - return scoreBlockFactory.build(); - } - } - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionConfig.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionConfig.java new file mode 100644 index 0000000000000..8bc48a908fe22 --- /dev/null +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionConfig.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.bulk; + +public record BulkInferenceExecutionConfig(int workers, int maxOutstandingRequests) { + public static final int DEFAULT_WORKERS = 10; + public static final int DEFAULT_MAX_OUTSTANDING_REQUESTS = 50; + + public static final BulkInferenceExecutionConfig DEFAULT = new BulkInferenceExecutionConfig( + DEFAULT_WORKERS, + DEFAULT_MAX_OUTSTANDING_REQUESTS + ); +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java new file mode 100644 index 0000000000000..307dae6c425c2 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.bulk; + +import org.elasticsearch.compute.operator.FailureCollector; +import org.elasticsearch.index.seqno.LocalCheckpointTracker; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED; + +/** + * Tracks the state of a bulk inference execution, including sequencing, failure management, and buffering of inference responses for + * ordered output construction. + */ +public class BulkInferenceExecutionState { + private final LocalCheckpointTracker checkpoint = new LocalCheckpointTracker(NO_OPS_PERFORMED, NO_OPS_PERFORMED); + private final FailureCollector failureCollector = new FailureCollector(); + private final Map bufferedResponses; + private final AtomicBoolean finished = new AtomicBoolean(false); + + public BulkInferenceExecutionState(int bufferSize) { + this.bufferedResponses = new ConcurrentHashMap<>(bufferSize); + } + + /** + * Generates a new unique sequence number for an inference request. + */ + public long generateSeqNo() { + return checkpoint.generateSeqNo(); + } + + /** + * Returns the highest sequence number marked as persisted, such that all lower sequence numbers have also been marked as persisted. + */ + public long getPersistedCheckpoint() { + return checkpoint.getPersistedCheckpoint(); + } + + /** + * Returns the highest sequence number marked as processed, such that all lower sequence numbers have also been marked as processed. + */ + public long getProcessedCheckpoint() { + return checkpoint.getProcessedCheckpoint(); + } + + /** + * Highest generated sequence number. + */ + public long getMaxSeqNo() { + return checkpoint.getMaxSeqNo(); + } + + /** + * Marks an inference response as persisted. 
+ * + * @param seqNo The corresponding sequence number + */ + public void markSeqNoAsPersisted(long seqNo) { + checkpoint.markSeqNoAsPersisted(seqNo); + } + + /** + * Adds an inference response to the buffer (unless a failure has already been recorded) and marks the sequence number as processed. + * + * @param seqNo The sequence number of the inference request. + * @param response The inference response. + */ + public synchronized void onInferenceResponse(long seqNo, InferenceAction.Response response) { + if (failureCollector.hasFailure() == false) { + bufferedResponses.put(seqNo, response); + } + checkpoint.markSeqNoAsProcessed(seqNo); + } + + /** + * Handles an exception thrown during inference execution. + * Records the failure, marks the corresponding sequence number as processed and discards all buffered responses. + * + * @param seqNo The sequence number of the inference request. + * @param e The exception + */ + public synchronized void onInferenceException(long seqNo, Exception e) { + failureCollector.unwrapAndCollect(e); + checkpoint.markSeqNoAsProcessed(seqNo); + bufferedResponses.clear(); + } + + /** + * Retrieves and removes the buffered response by sequence number. + * + * @param seqNo The sequence number of the response to fetch. + */ + public synchronized InferenceAction.Response fetchBufferedResponse(long seqNo) { + return bufferedResponses.remove(seqNo); + } + + /** + * Returns whether any failure has been recorded during execution. + */ + public boolean hasFailure() { + return failureCollector.hasFailure(); + } + + /** + * Returns the recorded failure, if any. + */ + public Exception getFailure() { + return failureCollector.getFailure(); + } + + public void addFailure(Exception e) { + failureCollector.unwrapAndCollect(e); + } + + /** + * Indicates whether the entire bulk execution is marked as finished and all responses have been successfully persisted.
+ */ + public boolean finished() { + return finished.get() && getMaxSeqNo() == getPersistedCheckpoint(); + } + + /** + * Marks the bulk as finished, indicating that all inference requests have been sent. + */ + public void finish() { + this.finished.set(true); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java new file mode 100644 index 0000000000000..d05a9a57d5265 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java @@ -0,0 +1,258 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.bulk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Executes a sequence of inference requests in bulk with throttling and concurrency control. + */ +public class BulkInferenceExecutor { + private final ThrottledInferenceRunner throttledInferenceRunner; + private final BulkInferenceExecutionConfig bulkExecutionConfig; + + /** + * Constructs a new {@code BulkInferenceExecutor}. 
+ * + * @param inferenceRunner The inference runner used to execute individual inference requests. + * @param threadPool The thread pool for executing inference tasks. + * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). + */ + public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadPool, BulkInferenceExecutionConfig bulkExecutionConfig) { + this.throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService(threadPool), bulkExecutionConfig); + this.bulkExecutionConfig = bulkExecutionConfig; + } + + /** + * Executes the provided bulk inference requests. + *

+ * Each request is sent to the {@link ThrottledInferenceRunner} to be executed. + * The final listener is notified with all successful responses once all requests are completed. + * + * @param requests An iterator over the inference requests to be executed. + * @param listener A listener notified with the complete list of responses or a failure. + */ + public void execute(BulkInferenceRequestIterator requests, ActionListener> listener) { + if (requests.hasNext() == false) { + listener.onResponse(List.of()); + return; + } + + final BulkInferenceExecutionState bulkExecutionState = new BulkInferenceExecutionState( + bulkExecutionConfig.maxOutstandingRequests() + ); + final ResponseHandler responseHandler = new ResponseHandler(bulkExecutionState, listener, requests.estimatedSize()); + + while (bulkExecutionState.finished() == false && requests.hasNext()) { + InferenceAction.Request request = requests.next(); + long seqNo = bulkExecutionState.generateSeqNo(); + + if (requests.hasNext() == false) { + bulkExecutionState.finish(); + } + + throttledInferenceRunner.doInference( + request, + ActionListener.runAfter( + ActionListener.wrap( + r -> bulkExecutionState.onInferenceResponse(seqNo, r), + e -> bulkExecutionState.onInferenceException(seqNo, e) + ), + responseHandler::persistPendingResponses + ) + ); + } + } + + /** + * Handles collection and delivery of inference responses once they are complete. 
+ */ + private static class ResponseHandler { + private final List responses; + private final ActionListener> listener; + private final BulkInferenceExecutionState bulkExecutionState; + private final AtomicBoolean responseSent = new AtomicBoolean(false); + + private ResponseHandler( + BulkInferenceExecutionState bulkExecutionState, + ActionListener> listener, + int estimatedSize + ) { + this.listener = listener; + this.bulkExecutionState = bulkExecutionState; + this.responses = new ArrayList<>(estimatedSize); + } + + /** + * Persists all buffered responses that can be delivered in order, and sends the final response if all requests are finished. + */ + public synchronized void persistPendingResponses() { + long persistedSeqNo = bulkExecutionState.getPersistedCheckpoint(); + + while (persistedSeqNo < bulkExecutionState.getProcessedCheckpoint()) { + persistedSeqNo++; + if (bulkExecutionState.hasFailure() == false) { + try { + InferenceAction.Response response = bulkExecutionState.fetchBufferedResponse(persistedSeqNo); + assert response != null; + responses.add(response); + } catch (Exception e) { + bulkExecutionState.addFailure(e); + } + } + bulkExecutionState.markSeqNoAsPersisted(persistedSeqNo); + } + + sendResponseOnCompletion(); + } + + /** + * Sends the final response or failure once all inference tasks have completed. + */ + private void sendResponseOnCompletion() { + if (bulkExecutionState.finished() && responseSent.compareAndSet(false, true)) { + if (bulkExecutionState.hasFailure() == false) { + try { + listener.onResponse(responses); + return; + } catch (Exception e) { + bulkExecutionState.addFailure(e); + } + } + + listener.onFailure(bulkExecutionState.getFailure()); + } + } + } + + /** + * Manages throttled inference tasks execution. 
+ */ + private static class ThrottledInferenceRunner { + private final InferenceRunner inferenceRunner; + private final ExecutorService executorService; + private final BlockingQueue pendingRequestsQueue; + private final Semaphore permits; + + private ThrottledInferenceRunner(InferenceRunner inferenceRunner, ExecutorService executorService, int maxRunningTasks) { + this.executorService = executorService; + this.permits = new Semaphore(maxRunningTasks); + this.inferenceRunner = inferenceRunner; + this.pendingRequestsQueue = new ArrayBlockingQueue<>(maxRunningTasks); + } + + /** + * Creates a new {@code ThrottledInferenceRunner} with the specified configuration. + * + * @param inferenceRunner The inference runner used to execute individual inference requests. + * @param executorService The executor used for asynchronous execution. + * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). + */ + public static ThrottledInferenceRunner create( + InferenceRunner inferenceRunner, + ExecutorService executorService, + BulkInferenceExecutionConfig bulkExecutionConfig + ) { + return new ThrottledInferenceRunner(inferenceRunner, executorService, bulkExecutionConfig.maxOutstandingRequests()); + } + + /** + * Schedules the inference task for execution. The task is first queued, then queued tasks are started while permits are available. + * + * @param request The inference request. + * @param listener The listener to notify on response or failure. + */ + public void doInference(InferenceAction.Request request, ActionListener listener) { + enqueueTask(request, listener); + executePendingRequests(); + } + + /** + * Attempts to execute as many pending inference tasks as possible, limited by available permits.
+ */ + private void executePendingRequests() { + while (permits.tryAcquire()) { + AbstractRunnable task = pendingRequestsQueue.poll(); + + if (task == null) { + permits.release(); + return; + } + + try { + executorService.execute(task); + } catch (Exception e) { + task.onFailure(e); + permits.release(); /* NOTE(review): task.onFailure already releases a permit via its completion listener - possible double release, confirm */ + } + } + } + + /** + * Adds an inference task to the queue. + * + * @param request The inference request. + * @param listener The listener to notify on response or failure. + */ + private void enqueueTask(InferenceAction.Request request, ActionListener listener) { + try { + pendingRequestsQueue.put(createTask(request, listener)); + } catch (Exception e) { + listener.onFailure(new IllegalStateException("An error occurred while adding the inference request to the queue", e)); + } + } + + /** + * Wraps an inference request into an {@link AbstractRunnable} that releases its permit on completion and triggers any remaining + * queued tasks. + * + * @param request The inference request. + * @param listener The listener to notify on completion. + * @return A runnable task encapsulating the request.
+ */ + private AbstractRunnable createTask(InferenceAction.Request request, ActionListener listener) { + final ActionListener completionListener = ActionListener.runAfter(listener, () -> { + permits.release(); + executePendingRequests(); + }); + + return new AbstractRunnable() { + @Override + protected void doRun() { + try { + inferenceRunner.doInference(request, completionListener); + } catch (Throwable e) { + listener.onFailure(new RuntimeException("Unexpected failure while running inference", e)); + } + } + + @Override + public void onFailure(Exception e) { + completionListener.onFailure(e); + } + }; + } + } + + private static ExecutorService executorService(ThreadPool threadPool) { + return threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java new file mode 100644 index 0000000000000..7327b182d0b6c --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.bulk; + +import org.elasticsearch.core.Releasable; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.util.Iterator; + +public interface BulkInferenceRequestIterator extends Iterator, Releasable { + + /** + * Returns an estimate of the number of requests that will be produced. + * + *

This is typically used to pre-allocate buffers or output to the appropriate size.

+ */ + int estimatedSize(); + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java new file mode 100644 index 0000000000000..e53fda90c88b3 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.completion; + +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; + +import java.util.stream.IntStream; + +/** + * {@link CompletionOperator} is an {@link InferenceOperator} that performs inference using prompt-based model (e.g., text completion). + * It evaluates a prompt expression for each input row, constructs inference requests, and emits the model responses as output. 
+ */ +public class CompletionOperator extends InferenceOperator { + + private final ExpressionEvaluator promptEvaluator; + + public CompletionOperator( + DriverContext driverContext, + InferenceRunner inferenceRunner, + ThreadPool threadPool, + String inferenceId, + ExpressionEvaluator promptEvaluator + ) { + super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId); + this.promptEvaluator = promptEvaluator; + } + + @Override + protected void doClose() { + Releasables.close(promptEvaluator); + } + + @Override + public String toString() { + return "CompletionOperator[inference_id=[" + inferenceId() + "]]"; + } + + @Override + public void addInput(Page input) { + try { + super.addInput(input.appendBlock(promptEvaluator.eval(input))); + } catch (Exception e) { + releasePageOnAnyThread(input); + throw e; + } + } + + /** + * Constructs the completion inference requests iterator for the given input page by evaluating the prompt expression. + * + * @param inputPage The input data page. + */ + @Override + protected BulkInferenceRequestIterator requests(Page inputPage) { + int inputBlockChannel = inputPage.getBlockCount() - 1; + return new CompletionOperatorRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId()); + } + + /** + * Creates a new {@link CompletionOperatorOutputBuilder} to collect and emit the completion results. + * + * @param input The input page for which results will be constructed. + */ + @Override + protected CompletionOperatorOutputBuilder outputBuilder(Page input) { + BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(input.getPositionCount()); + return new CompletionOperatorOutputBuilder( + outputBlockBuilder, + input.projectBlocks(IntStream.range(0, input.getBlockCount() - 1).toArray()) + ); + } + + /** + * Factory for creating {@link CompletionOperator} instances. 
+ */ + public record Factory(InferenceRunner inferenceRunner, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory) + implements + OperatorFactory { + @Override + public String describe() { + return "CompletionOperator[inference_id=[" + inferenceId + "]]"; + } + + @Override + public Operator get(DriverContext driverContext) { + return new CompletionOperator( + driverContext, + inferenceRunner, + inferenceRunner.threadPool(), + inferenceId, + promptEvaluatorFactory.get(driverContext) + ); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java new file mode 100644 index 0000000000000..d44a13786437a --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.completion; + +import org.apache.lucene.util.BytesRefBuilder; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; + +/** + * {@link CompletionOperatorOutputBuilder} builds the output page for {@link CompletionOperator} by converting {@link ChatCompletionResults} + * into a {@link BytesRefBlock}. 
+ */ +public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder { + private final Page inputPage; + private final BytesRefBlock.Builder outputBlockBuilder; + private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder(); + + public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder, Page inputPage) { + this.inputPage = inputPage; + this.outputBlockBuilder = outputBlockBuilder; + } + + @Override + public void close() { + Releasables.close(outputBlockBuilder); + releasePageOnAnyThread(inputPage); + } + + /** + * Adds an inference response to the output builder. + * + *

+ * If the response is null or not of type {@link ChatCompletionResults} an {@link IllegalStateException} is thrown. + * Else, the result text is added to the output block. + *

+ * + *

+ * The responses must be added in the same order as the corresponding inference requests were generated. + * Failing to preserve order may lead to incorrect or misaligned output rows. + *

+ */ + @Override + public void addInferenceResponse(InferenceAction.Response inferenceResponse) { + ChatCompletionResults completionResults = inferenceResults(inferenceResponse); + + if (completionResults == null) { + throw new IllegalStateException("Received null inference result; expected a non-null result of type ChatCompletionResults"); + } + + outputBlockBuilder.beginPositionEntry(); + for (ChatCompletionResults.Result completionResult : completionResults.getResults()) { + bytesRefBuilder.copyChars(completionResult.content()); + outputBlockBuilder.appendBytesRef(bytesRefBuilder.get()); + bytesRefBuilder.clear(); + } + outputBlockBuilder.endPositionEntry(); + } + + /** + * Builds the final output page by appending the completion output block to a shallow copy of the input page. + */ + @Override + public Page buildOutput() { + Block outputBlock = outputBlockBuilder.build(); + assert outputBlock.getPositionCount() == inputPage.getPositionCount(); + return inputPage.shallowCopy().appendBlock(outputBlock); + } + + private ChatCompletionResults inferenceResults(InferenceAction.Response inferenceResponse) { + return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, ChatCompletionResults.class); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java new file mode 100644 index 0000000000000..d7755a310098a --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.completion; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; + +import java.util.List; +import java.util.NoSuchElementException; + +/** + * This iterator reads prompts from a {@link BytesRefBlock} and converts them into individual {@link InferenceAction.Request} instances + * of type {@link TaskType#COMPLETION}. + */ +public class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator { + + private final PromptReader promptReader; + private final String inferenceId; + private final int size; + private int currentPos = 0; + + /** + * Constructs a new iterator from the given block of prompts. + * + * @param promptBlock The input block containing prompts. + * @param inferenceId The ID of the inference model to invoke. + */ + public CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) { + this.promptReader = new PromptReader(promptBlock); + this.size = promptBlock.getPositionCount(); + this.inferenceId = inferenceId; + } + + @Override + public boolean hasNext() { + return currentPos < size; + } + + @Override + public InferenceAction.Request next() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + return inferenceRequest(promptReader.readPrompt(currentPos++)); + } + + /** + * Wraps a single prompt string into an {@link InferenceAction.Request}. 
+ */ + private InferenceAction.Request inferenceRequest(String prompt) { /* NOTE(review): List.of rejects null, but readPrompt returns null for a null position - confirm null prompts cannot reach here */ + return InferenceAction.Request.builder(inferenceId, TaskType.COMPLETION).setInput(List.of(prompt)).build(); + } + + @Override + public int estimatedSize() { + return promptReader.estimatedSize(); + } + + @Override + public void close() { + Releasables.close(promptReader); + } + + /** + * Helper class that reads prompts from a {@link BytesRefBlock}. + */ + private static class PromptReader implements Releasable { + private final BytesRefBlock promptBlock; + private final StringBuilder strBuilder = new StringBuilder(); + private BytesRef readBuffer = new BytesRef(); + + private PromptReader(BytesRefBlock promptBlock) { + this.promptBlock = promptBlock; + } + + /** + * Reads the prompt at the given position, joining multiple values with newlines; returns {@code null} if the position is null. + * + * @param pos the position index in the block + */ + public String readPrompt(int pos) { + if (promptBlock.isNull(pos)) { + return null; + } + + strBuilder.setLength(0); + + for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) { + readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer); + strBuilder.append(readBuffer.utf8ToString()); + if (valueIndex != promptBlock.getValueCount(pos) - 1) { + strBuilder.append("\n"); + } + } + + return strBuilder.toString(); + } + + /** + * Returns the total number of positions (prompts) in the block. + */ + public int estimatedSize() { + return promptBlock.getPositionCount(); + } + + @Override + public void close() { + + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java new file mode 100644 index 0000000000000..ca628fdba8a8f --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V.
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.rerank; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; + +import java.util.stream.IntStream; + +/** + * {@link RerankOperator} is an inference operator that computes scores for rows using a reranking model.
+ */ +public class RerankOperator extends InferenceOperator { + + // Default number of rows to include per inference request + private static final int DEFAULT_BATCH_SIZE = 20; + private final String queryText; + + // Encodes each input row into a string representation for the model + private final ExpressionEvaluator rowEncoder; + private final int scoreChannel; + + // Batch size used to group rows into a single inference request (currently fixed) + // TODO: make it configurable either in the command or as query pragmas + private final int batchSize = DEFAULT_BATCH_SIZE; + + public RerankOperator( + DriverContext driverContext, + InferenceRunner inferenceRunner, + ThreadPool threadPool, + String inferenceId, + String queryText, + ExpressionEvaluator rowEncoder, + int scoreChannel + ) { + super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId); + this.queryText = queryText; + this.rowEncoder = rowEncoder; + this.scoreChannel = scoreChannel; + } + + @Override + public void addInput(Page input) { + try { + Block inputBlock = rowEncoder.eval(input); + super.addInput(input.appendBlock(inputBlock)); + } catch (Exception e) { + releasePageOnAnyThread(input); + throw e; + } + } + + @Override + protected void doClose() { + Releasables.close(rowEncoder); + } + + @Override + public String toString() { + return "RerankOperator[inference_id=[" + inferenceId() + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]"; + } + + /** + * Returns the request iterator responsible for batching and converting input rows into inference requests. + */ + @Override + protected RerankOperatorRequestIterator requests(Page inputPage) { + int inputBlockChannel = inputPage.getBlockCount() - 1; + return new RerankOperatorRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), queryText, batchSize); + } + + /** + * Returns the output builder responsible for collecting inference responses and building the output page. 
+ */ + @Override + protected RerankOperatorOutputBuilder outputBuilder(Page input) { + DoubleBlock.Builder outputBlockBuilder = blockFactory().newDoubleBlockBuilder(input.getPositionCount()); + return new RerankOperatorOutputBuilder( + outputBlockBuilder, + input.projectBlocks(IntStream.range(0, input.getBlockCount() - 1).toArray()), + scoreChannel + ); + } + + /** + * Factory for creating {@link RerankOperator} instances + */ + public record Factory( + InferenceRunner inferenceRunner, + String inferenceId, + String queryText, + ExpressionEvaluator.Factory rowEncoderFactory, + int scoreChannel + ) implements OperatorFactory { + + @Override + public String describe() { + return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]"; + } + + @Override + public Operator get(DriverContext driverContext) { + return new RerankOperator( + driverContext, + inferenceRunner, + inferenceRunner.threadPool(), + inferenceId, + queryText, + rowEncoderFactory().get(driverContext), + scoreChannel + ); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java new file mode 100644 index 0000000000000..1813aa3e9fb59 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.rerank; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; + +import java.util.Comparator; +import java.util.Iterator; + +/** + * Builds the output page for the {@link RerankOperator} by adding + * reranked relevance scores into the specified score channel of the input page. + */ + +public class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder { + + private final Page inputPage; + private final DoubleBlock.Builder scoreBlockBuilder; + private final int scoreChannel; + + public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page inputPage, int scoreChannel) { + this.inputPage = inputPage; + this.scoreBlockBuilder = scoreBlockBuilder; + this.scoreChannel = scoreChannel; + } + + @Override + public void close() { + Releasables.close(scoreBlockBuilder); + releasePageOnAnyThread(inputPage); + } + + /** + * Constructs a new output {@link Page} which contains all original blocks from the input page, with the reranked scores + * inserted at {@code scoreChannel}.
+ */ + @Override + public Page buildOutput() { + int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1); + Block[] blocks = new Block[blockCount]; + + try { + for (int b = 0; b < blockCount; b++) { + if (b == scoreChannel) { + blocks[b] = scoreBlockBuilder.build(); + } else { + blocks[b] = inputPage.getBlock(b); + blocks[b].incRef(); + } + } + return new Page(blocks); + } catch (Exception e) { + Releasables.close(blocks); + throw (e); + } + } + + /** + * Extracts the ranked document results from the inference response and appends their relevance scores to the score block builder. + *

+ * If the response is not a {@link RankedDocsResults} (e.g. {@link ChatCompletionResults}) an {@link IllegalStateException} is thrown. + *

+ *

+ * The responses must be added in the same order as the corresponding inference requests were generated. + * Failing to preserve order may lead to incorrect or misaligned output rows. + *

+ */ + @Override + public void addInferenceResponse(InferenceAction.Response inferenceResponse) { + Iterator sortedRankedDocIterator = inferenceResults(inferenceResponse).getRankedDocs() + .stream() + .sorted(Comparator.comparingInt(RankedDocsResults.RankedDoc::index)) + .iterator(); + while (sortedRankedDocIterator.hasNext()) { + scoreBlockBuilder.appendDouble(sortedRankedDocIterator.next().relevanceScore()); + } + } + + private RankedDocsResults inferenceResults(InferenceAction.Response inferenceResponse) { + return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, RankedDocsResults.class); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java new file mode 100644 index 0000000000000..3e73bcc8bea1f --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.rerank; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; + +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * Iterator over input data blocks to create batched inference requests for the Rerank task. + * + *

This iterator reads from a {@link BytesRefBlock} containing input documents or items to be reranked. It slices the input into batches + * of configurable size and converts each batch into an {@link InferenceAction.Request} with the task type {@link TaskType#RERANK}. + */ +public class RerankOperatorRequestIterator implements BulkInferenceRequestIterator { + private final BytesRefBlock inputBlock; + private final String inferenceId; + private final String queryText; + private final int batchSize; + private int remainingPositions; + + public RerankOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, String queryText, int batchSize) { + this.inputBlock = inputBlock; + this.inferenceId = inferenceId; + this.queryText = queryText; + this.batchSize = batchSize; + this.remainingPositions = inputBlock.getPositionCount(); + } + + @Override + public boolean hasNext() { + return remainingPositions > 0; + } + + @Override + public InferenceAction.Request next() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + + final int inputSize = Math.min(remainingPositions, batchSize); + final List inputs = new ArrayList<>(inputSize); + BytesRef scratch = new BytesRef(); + + int startIndex = inputBlock.getPositionCount() - remainingPositions; + for (int i = 0; i < inputSize; i++) { + int pos = startIndex + i; + if (inputBlock.isNull(pos)) { + inputs.add(""); + } else { + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(pos), scratch); + inputs.add(BytesRefs.toString(scratch)); + } + } + + remainingPositions -= inputSize; + return inferenceRequest(inputs); + } + + @Override + public int estimatedSize() { + return inputBlock.getPositionCount(); + } + + private InferenceAction.Request inferenceRequest(List inputs) { + return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(inputs).setQuery(queryText).build(); + } + + @Override + public void close() { + + } +} diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java index 43e8e871cd021..a577229f51aef 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java @@ -138,7 +138,7 @@ protected AttributeSet computeReferences() { @Override public boolean expressionsResolved() { - return super.expressionsResolved() && prompt.resolved(); + return super.expressionsResolved() && prompt.resolved() && targetField.resolved(); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 3e66a55b58d94..412bdbaa6914e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -86,8 +86,9 @@ import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter; import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.inference.InferenceRunner; -import org.elasticsearch.xpack.esql.inference.RerankOperator; import org.elasticsearch.xpack.esql.inference.XContentRowEncoder; +import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator; +import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator; import org.elasticsearch.xpack.esql.plan.logical.Fork; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.ChangePointExec; @@ -116,6 +117,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ShowExec; import 
org.elasticsearch.xpack.esql.plan.physical.TimeSeriesSourceExec; import org.elasticsearch.xpack.esql.plan.physical.TopNExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; @@ -262,9 +264,12 @@ private PhysicalOperation plan(PhysicalPlan node, LocalExecutionPlannerContext c return planRerank(rerank, context); } else if (node instanceof ChangePointExec changePoint) { return planChangePoint(changePoint, context); + } else if (node instanceof CompletionExec completion) { + return planCompletion(completion, context); } else if (node instanceof SampleExec Sample) { return planSample(Sample, context); } + // source nodes else if (node instanceof EsQueryExec esQuery) { return planEsQueryNode(esQuery, context); @@ -301,6 +306,19 @@ else if (node instanceof OutputExec outputExec) { throw new EsqlIllegalArgumentException("unknown physical plan node [" + node.nodeName() + "]"); } + private PhysicalOperation planCompletion(CompletionExec completion, LocalExecutionPlannerContext context) { + PhysicalOperation source = plan(completion.child(), context); + String inferenceId = BytesRefs.toString(completion.inferenceId().fold(context.foldCtx())); + Layout outputLayout = source.layout.builder().append(completion.targetField()).build(); + EvalOperator.ExpressionEvaluator.Factory promptEvaluatorFactory = EvalMapper.toEvaluator( + context.foldCtx(), + completion.prompt(), + source.layout + ); + + return source.with(new CompletionOperator.Factory(inferenceRunner, inferenceId, promptEvaluatorFactory), outputLayout); + } + private PhysicalOperation planRrfScoreEvalExec(RrfScoreEvalExec rrf, LocalExecutionPlannerContext context) { PhysicalOperation source = plan(rrf.child(), context); diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 33146991609eb..fe8d200992b0d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -159,7 +159,7 @@ public TransportEsqlQueryAction( projectResolver, indexNameExpressionResolver, usageService, - new InferenceRunner(client) + new InferenceRunner(client, threadPool) ); this.computeService = new ComputeService( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index bf25feb9db553..15e4a6708e8d7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -268,6 +268,10 @@ public final void test() throws Throwable { "can't use rereank in csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.RERANK.capabilityName()) ); + assumeFalse( + "can't use completion in csv tests", + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.COMPLETION.capabilityName()) + ); assumeFalse( "can't use match in csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.MATCH_OPERATOR_COLON.capabilityName()) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 54dde6d716217..20ba27dd20ea2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -3764,7 +3764,7 @@ public 
void testResolveCompletionTargetField() { """, "mapping-books.json"); Completion completion = as(as(plan, Limit.class).child(), Completion.class); - assertThat(completion.targetField(), equalTo(referenceAttribute("translation", DataType.TEXT))); + assertThat(completion.targetField(), equalTo(referenceAttribute("translation", DataType.KEYWORD))); } public void testResolveCompletionDefaultTargetField() { @@ -3776,7 +3776,7 @@ public void testResolveCompletionDefaultTargetField() { """, "mapping-books.json"); Completion completion = as(as(plan, Limit.class).child(), Completion.class); - assertThat(completion.targetField(), equalTo(referenceAttribute("completion", DataType.TEXT))); + assertThat(completion.targetField(), equalTo(referenceAttribute("completion", DataType.KEYWORD))); } public void testResolveCompletionPrompt() { @@ -3814,7 +3814,7 @@ public void testResolveCompletionOutputField() { """, "mapping-books.json"); Completion completion = as(as(plan, Limit.class).child(), Completion.class); - assertThat(completion.targetField(), equalTo(referenceAttribute("description", DataType.TEXT))); + assertThat(completion.targetField(), equalTo(referenceAttribute("description", DataType.KEYWORD))); EsRelation esRelation = as(completion.child(), EsRelation.class); assertThat(getAttributeByName(completion.output(), "description"), equalTo(completion.targetField())); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java new file mode 100644 index 0000000000000..900e17f724156 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java @@ -0,0 +1,208 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.logging.LoggerMessageFormat; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AsyncOperator; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.test.AbstractBlockSourceOperator; +import org.elasticsearch.compute.test.OperatorTestCase; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; +import org.junit.After; +import org.junit.Before; + 
+import java.util.function.BiFunction; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.notNullValue; + +public abstract class InferenceOperatorTestCase extends OperatorTestCase { + private ThreadPool threadPool; + + @Before + public void setThreadPool() { + threadPool = new TestThreadPool( + getTestClass().getSimpleName(), + new FixedExecutorBuilder( + Settings.EMPTY, + EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME, + between(1, 10), + 1024, + "esql", + EsExecutors.TaskTrackingConfig.DEFAULT + ) + ); + } + + @After + public void shutdownThreadPool() { + terminate(threadPool); + } + + protected ThreadPool threadPool() { + return threadPool; + } + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) { + @Override + protected int remaining() { + return size - currentPosition; + } + + @Override + protected Page createPage(int positionOffset, int length) { + length = Integer.min(length, remaining()); + try (var builder = blockFactory.newBytesRefVectorBuilder(length)) { + for (int i = 0; i < length; i++) { + builder.appendBytesRef(new BytesRef(randomAlphaOfLength(10))); + } + currentPosition += length; + return new Page(builder.build().asBlock()); + } + } + }; + } + + @Override + public void testOperatorStatus() { + DriverContext driverContext = driverContext(); + try (var operator = simple().get(driverContext)) { + AsyncOperator.Status status = asInstanceOf(AsyncOperator.Status.class, operator.status()); + + assertThat(status, notNullValue()); + assertThat(status.receivedPages(), equalTo(0L)); + assertThat(status.completedPages(), equalTo(0L)); + assertThat(status.procesNanos(), greaterThanOrEqualTo(0L)); + } + } + + @SuppressWarnings("unchecked") + protected InferenceRunner mockedSimpleInferenceRunner() { + Client client = new 
NoOpClient(threadPool) { + @Override + protected void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + Runnable runnable = () -> { + if (action == InferenceAction.INSTANCE && request instanceof InferenceAction.Request inferenceRequest) { + InferenceAction.Response inferenceResponse = new InferenceAction.Response(mockInferenceResult(inferenceRequest)); + listener.onResponse((Response) inferenceResponse); + return; + } + + fail("Unexpected call to action [" + action.name() + "]"); + }; + + if (randomBoolean()) { + runnable.run(); + } else { + threadPool.schedule(runnable, TimeValue.timeValueNanos(between(1, 100)), threadPool.executor(ThreadPool.Names.SEARCH)); + } + } + }; + + return new InferenceRunner(client, threadPool); + } + + protected abstract InferenceResultsType mockInferenceResult(InferenceAction.Request request); + + protected void assertBlockContentEquals(Block input, Block result) { + BytesRef scratch = new BytesRef(); + switch (input.elementType()) { + case BOOLEAN -> assertBlockContentEquals(input, result, BooleanBlock::getBoolean, BooleanBlock.class); + case INT -> assertBlockContentEquals(input, result, IntBlock::getInt, IntBlock.class); + case LONG -> assertBlockContentEquals(input, result, LongBlock::getLong, LongBlock.class); + case FLOAT -> assertBlockContentEquals(input, result, FloatBlock::getFloat, FloatBlock.class); + case DOUBLE -> assertBlockContentEquals(input, result, DoubleBlock::getDouble, DoubleBlock.class); + case BYTES_REF -> assertByteRefsBlockContentEquals(input, result, scratch); + default -> throw new AssertionError(LoggerMessageFormat.format("Unexpected block type {}", input.elementType())); + } + } + + private void assertBlockContentEquals( + Block input, + Block result, + BiFunction valueReader, + Class blockClass + ) { + V inputBlock = asInstanceOf(blockClass, input); + V resultBlock = asInstanceOf(blockClass, result); + + assertAllPositions(inputBlock, (pos) -> { + if 
(inputBlock.isNull(pos)) { + assertThat(resultBlock.isNull(pos), equalTo(inputBlock.isNull(pos))); + } else { + assertThat(resultBlock.getValueCount(pos), equalTo(inputBlock.getValueCount(pos))); + assertThat(resultBlock.getFirstValueIndex(pos), equalTo(inputBlock.getFirstValueIndex(pos))); + for (int i = 0; i < inputBlock.getValueCount(pos); i++) { + assertThat( + valueReader.apply(resultBlock, resultBlock.getFirstValueIndex(pos) + i), + equalTo(valueReader.apply(inputBlock, inputBlock.getFirstValueIndex(pos) + i)) + ); + } + } + }); + } + + private void assertAllPositions(Block block, Consumer consumer) { + for (int pos = 0; pos < block.getPositionCount(); pos++) { + consumer.accept(pos); + } + } + + private void assertByteRefsBlockContentEquals(Block input, Block result, BytesRef readBuffer) { + assertBlockContentEquals(input, result, (BytesRefBlock b, Integer pos) -> b.getBytesRef(pos, readBuffer), BytesRefBlock.class); + } + + protected EvalOperator.ExpressionEvaluator.Factory evaluatorFactory(int channel) { + return context -> new EvalOperator.ExpressionEvaluator() { + @Override + public Block eval(Page page) { + return BlockUtils.deepCopyOf(page.getBlock(channel), blockFactory()); + } + + @Override + public void close() { + + } + }; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java index b72bce05c506a..1a22da2701b53 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java @@ -12,15 +12,23 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.settings.Settings; +import 
org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; +import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; +import org.junit.After; +import org.junit.Before; import java.util.List; @@ -34,8 +42,30 @@ import static org.mockito.Mockito.when; public class InferenceRunnerTests extends ESTestCase { + private TestThreadPool threadPool; + + @Before + public void setThreadPool() { + threadPool = new TestThreadPool( + getTestClass().getSimpleName(), + new FixedExecutorBuilder( + Settings.EMPTY, + EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME, + between(1, 10), + 1024, + "esql", + EsExecutors.TaskTrackingConfig.DEFAULT + ) + ); + } + + @After + public void shutdownThreadPool() { + terminate(threadPool); + } + public void testResolveInferenceIds() throws Exception { - InferenceRunner inferenceRunner = new InferenceRunner(mockClient()); + InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); List> inferencePlans = List.of(mockInferencePlan("rerank-plan")); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); @@ -52,7 +82,7 @@ public void testResolveInferenceIds() throws Exception { } public void testResolveMultipleInferenceIds() throws Exception { - InferenceRunner inferenceRunner = new InferenceRunner(mockClient()); + InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), 
threadPool); List> inferencePlans = List.of( mockInferencePlan("rerank-plan"), mockInferencePlan("rerank-plan"), @@ -80,7 +110,7 @@ public void testResolveMultipleInferenceIds() throws Exception { } public void testResolveMissingInferenceIds() throws Exception { - InferenceRunner inferenceRunner = new InferenceRunner(mockClient()); + InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); List> inferencePlans = List.of(mockInferencePlan("missing-plan")); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); @@ -100,17 +130,29 @@ public void testResolveMissingInferenceIds() throws Exception { } @SuppressWarnings({ "unchecked", "raw-types" }) - private static Client mockClient() { + private Client mockClient() { Client client = mock(Client.class); doAnswer(i -> { - GetInferenceModelAction.Request request = i.getArgument(1, GetInferenceModelAction.Request.class); - ActionListener listener = (ActionListener) i.getArgument(2, ActionListener.class); - ActionResponse response = getInferenceModelResponse(request); - - if (response == null) { - listener.onFailure(new ResourceNotFoundException("inference endpoint not found")); + Runnable sendResponse = () -> { + GetInferenceModelAction.Request request = i.getArgument(1, GetInferenceModelAction.Request.class); + ActionListener listener = (ActionListener) i.getArgument(2, ActionListener.class); + ActionResponse response = getInferenceModelResponse(request); + + if (response == null) { + listener.onFailure(new ResourceNotFoundException("inference endpoint not found")); + } else { + listener.onResponse(response); + } + }; + + if (randomBoolean()) { + sendResponse.run(); } else { - listener.onResponse(response); + threadPool.schedule( + sendResponse, + TimeValue.timeValueNanos(between(1, 1_000)), + threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME) + ); } return null; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java deleted file mode 100644 index b1335901361b2..0000000000000 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java +++ /dev/null @@ -1,306 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.inference; - -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.logging.LoggerMessageFormat; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BooleanBlock; -import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.FloatBlock; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.AsyncOperator; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.compute.operator.Operator; -import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.test.AbstractBlockSourceOperator; -import org.elasticsearch.compute.test.OperatorTestCase; -import org.elasticsearch.compute.test.RandomBlock; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.threadpool.FixedExecutorBuilder; -import org.elasticsearch.threadpool.TestThreadPool; -import 
org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; -import org.hamcrest.Matcher; -import org.junit.After; -import org.junit.Before; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.function.BiFunction; -import java.util.function.Consumer; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.notNullValue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class RerankOperatorTests extends OperatorTestCase { - private static final String ESQL_TEST_EXECUTOR = "esql_test_executor"; - private static final String SIMPLE_INFERENCE_ID = "test_reranker"; - private static final String SIMPLE_QUERY = "query text"; - private ThreadPool threadPool; - private List inputChannelElementTypes; - private XContentRowEncoder.Factory rowEncoderFactory; - private int scoreChannel; - - @Before - private void initChannels() { - int channelCount = randomIntBetween(2, 10); - scoreChannel = randomIntBetween(0, channelCount - 1); - inputChannelElementTypes = IntStream.range(0, channelCount).sorted().mapToObj(this::randomElementType).collect(Collectors.toList()); - rowEncoderFactory = mockRowEncoderFactory(); - } - - @Before - public void setThreadPool() { - int numThreads = randomBoolean() ? 
1 : between(2, 16); - threadPool = new TestThreadPool( - "test", - new FixedExecutorBuilder(Settings.EMPTY, ESQL_TEST_EXECUTOR, numThreads, 1024, "esql", EsExecutors.TaskTrackingConfig.DEFAULT) - ); - } - - @After - public void shutdownThreadPool() { - terminate(threadPool); - } - - @Override - protected Operator.OperatorFactory simple(SimpleOptions options) { - InferenceRunner inferenceRunner = mockedSimpleInferenceRunner(); - return new RerankOperator.Factory(inferenceRunner, SIMPLE_INFERENCE_ID, SIMPLE_QUERY, rowEncoderFactory, scoreChannel); - } - - private InferenceRunner mockedSimpleInferenceRunner() { - InferenceRunner inferenceRunner = mock(InferenceRunner.class); - when(inferenceRunner.getThreadContext()).thenReturn(threadPool.getThreadContext()); - doAnswer(invocation -> { - Runnable sendResponse = () -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArgument( - 1, - ActionListener.class - ); - InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class); - when(inferenceResponse.getResults()).thenReturn( - mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class)) - ); - listener.onResponse(inferenceResponse); - }; - if (randomBoolean()) { - sendResponse.run(); - } else { - threadPool.schedule(sendResponse, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.executor(ESQL_TEST_EXECUTOR)); - } - return null; - }).when(inferenceRunner).doInference(any(), any()); - - return inferenceRunner; - } - - private RankedDocsResults mockedRankedDocResults(InferenceAction.Request request) { - List rankedDocs = new ArrayList<>(); - for (int rank = 0; rank < request.getInput().size(); rank++) { - if (rank % 10 != 0) { - rankedDocs.add(new RankedDocsResults.RankedDoc(rank, 1f / rank, request.getInput().get(rank))); - } - } - return new RankedDocsResults(rankedDocs); - } - - @Override - protected Matcher expectedDescriptionOfSimple() { - return expectedToStringOfSimple(); - } - 
- @Override - protected Matcher expectedToStringOfSimple() { - return equalTo( - "RerankOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "], query=[" + SIMPLE_QUERY + "], score_channel=[" + scoreChannel + "]]" - ); - } - - @Override - protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { - final int minPageSize = Math.max(1, size / 100); - return new AbstractBlockSourceOperator(blockFactory, between(minPageSize, size)) { - @Override - protected int remaining() { - return size - currentPosition; - } - - @Override - protected Page createPage(int positionOffset, int length) { - Block[] blocks = new Block[inputChannelElementTypes.size()]; - try { - currentPosition += length; - for (int b = 0; b < inputChannelElementTypes.size(); b++) { - blocks[b] = RandomBlock.randomBlock( - blockFactory, - inputChannelElementTypes.get(b), - length, - randomBoolean(), - 0, - 10, - 0, - 10 - ).block(); - } - return new Page(blocks); - } catch (Exception e) { - Releasables.closeExpectNoException(blocks); - throw (e); - } - } - }; - } - - /** - * Ensures that the Operator.Status of this operator has the standard fields. - */ - public void testOperatorStatus() throws IOException { - DriverContext driverContext = driverContext(); - try (var operator = simple().get(driverContext)) { - AsyncOperator.Status status = asInstanceOf(AsyncOperator.Status.class, operator.status()); - - assertThat(status, notNullValue()); - assertThat(status.receivedPages(), equalTo(0L)); - assertThat(status.completedPages(), equalTo(0L)); - assertThat(status.procesNanos(), greaterThanOrEqualTo(0L)); - } - } - - @Override - protected void assertSimpleOutput(List inputPages, List resultPages) { - assertThat(inputPages, hasSize(resultPages.size())); - - for (int pageId = 0; pageId < inputPages.size(); pageId++) { - Page inputPage = inputPages.get(pageId); - Page resultPage = resultPages.get(pageId); - - // Check all rows are present and the output shape is unchanged. 
- assertThat(inputPage.getPositionCount(), equalTo(resultPage.getPositionCount())); - assertThat(inputPage.getBlockCount(), equalTo(resultPage.getBlockCount())); - - BytesRef readBuffer = new BytesRef(); - - for (int channel = 0; channel < inputPage.getBlockCount(); channel++) { - Block inputBlock = inputPage.getBlock(channel); - Block resultBlock = resultPage.getBlock(channel); - - assertThat(resultBlock.getPositionCount(), equalTo(resultPage.getPositionCount())); - assertThat(resultBlock.elementType(), equalTo(inputBlock.elementType())); - - if (channel == scoreChannel) { - assertExpectedScore(asInstanceOf(DoubleBlock.class, resultBlock)); - } else { - switch (inputBlock.elementType()) { - case BOOLEAN -> assertBlockContentEquals(inputBlock, resultBlock, BooleanBlock::getBoolean, BooleanBlock.class); - case INT -> assertBlockContentEquals(inputBlock, resultBlock, IntBlock::getInt, IntBlock.class); - case LONG -> assertBlockContentEquals(inputBlock, resultBlock, LongBlock::getLong, LongBlock.class); - case FLOAT -> assertBlockContentEquals(inputBlock, resultBlock, FloatBlock::getFloat, FloatBlock.class); - case DOUBLE -> assertBlockContentEquals(inputBlock, resultBlock, DoubleBlock::getDouble, DoubleBlock.class); - case BYTES_REF -> assertByteRefsBlockContentEquals(inputBlock, resultBlock, readBuffer); - default -> throw new AssertionError( - LoggerMessageFormat.format("Unexpected block type {}", inputBlock.elementType()) - ); - } - } - } - } - } - - private int inputChannelCount() { - return inputChannelElementTypes.size(); - } - - private ElementType randomElementType(int channel) { - return channel == scoreChannel ? 
ElementType.DOUBLE : randomFrom(ElementType.FLOAT, ElementType.DOUBLE, ElementType.LONG); - } - - private XContentRowEncoder.Factory mockRowEncoderFactory() { - XContentRowEncoder.Factory factory = mock(XContentRowEncoder.Factory.class); - doAnswer(factoryInvocation -> { - DriverContext driverContext = factoryInvocation.getArgument(0, DriverContext.class); - XContentRowEncoder rowEncoder = mock(XContentRowEncoder.class); - doAnswer(encoderInvocation -> { - Page inputPage = encoderInvocation.getArgument(0, Page.class); - return driverContext.blockFactory() - .newConstantBytesRefBlockWith(new BytesRef(randomRealisticUnicodeOfCodepointLength(4)), inputPage.getPositionCount()); - }).when(rowEncoder).eval(any(Page.class)); - - return rowEncoder; - }).when(factory).get(any(DriverContext.class)); - - return factory; - } - - private void assertExpectedScore(DoubleBlock scoreBlockResult) { - assertAllPositions(scoreBlockResult, (pos) -> { - if (pos % 10 == 0) { - assertThat(scoreBlockResult.isNull(pos), equalTo(true)); - } else { - assertThat(scoreBlockResult.getValueCount(pos), equalTo(1)); - assertThat(scoreBlockResult.getDouble(scoreBlockResult.getFirstValueIndex(pos)), equalTo((double) (1f / pos))); - } - }); - } - - void assertBlockContentEquals( - Block input, - Block result, - BiFunction valueReader, - Class blockClass - ) { - V inputBlock = asInstanceOf(blockClass, input); - V resultBlock = asInstanceOf(blockClass, result); - - assertAllPositions(inputBlock, (pos) -> { - if (inputBlock.isNull(pos)) { - assertThat(resultBlock.isNull(pos), equalTo(inputBlock.isNull(pos))); - } else { - assertThat(resultBlock.getValueCount(pos), equalTo(inputBlock.getValueCount(pos))); - assertThat(resultBlock.getFirstValueIndex(pos), equalTo(inputBlock.getFirstValueIndex(pos))); - for (int i = 0; i < inputBlock.getValueCount(pos); i++) { - assertThat( - valueReader.apply(resultBlock, resultBlock.getFirstValueIndex(pos) + i), - equalTo(valueReader.apply(inputBlock, 
inputBlock.getFirstValueIndex(pos) + i)) - ); - } - } - }); - } - - private void assertAllPositions(Block block, Consumer consumer) { - for (int pos = 0; pos < block.getPositionCount(); pos++) { - consumer.accept(pos); - } - } - - private void assertByteRefsBlockContentEquals(Block input, Block result, BytesRef readBuffer) { - assertBlockContentEquals(input, result, (BytesRefBlock b, Integer pos) -> b.getBytesRef(pos, readBuffer), BytesRefBlock.class); - } -} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java new file mode 100644 index 0000000000000..e539a8ec775b6 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java @@ -0,0 +1,205 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.bulk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; +import org.junit.After; +import org.junit.Before; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class BulkInferenceExecutorTests extends ESTestCase { + private ThreadPool threadPool; + + @Before + public void setThreadPool() { + threadPool = new TestThreadPool( + getTestClass().getSimpleName(), + new FixedExecutorBuilder( + Settings.EMPTY, + EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME, + between(1, 20), + 1024, + "esql", + EsExecutors.TaskTrackingConfig.DEFAULT + ) + ); + } + + @After + public void shutdownThreadPool() { + terminate(threadPool); + } + + public void testSuccessfulExecution() throws Exception { + List requests = randomInferenceRequestList(between(1, 10_000)); + List responses = randomInferenceResponseList(requests.size()); + + 
InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { + runWithRandomDelay(() -> { + ActionListener l = invocation.getArgument(1); + l.onResponse(responses.get(requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)))); + }); + return null; + }); + + AtomicReference> output = new AtomicReference<>(); + ActionListener> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); + + bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + + assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), equalTo(responses)))); + } + + public void testSuccessfulExecutionOnEmptyRequest() throws Exception { + BulkInferenceRequestIterator requestIterator = mock(BulkInferenceRequestIterator.class); + when(requestIterator.hasNext()).thenReturn(false); + + AtomicReference> output = new AtomicReference<>(); + ActionListener> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); + + bulkExecutor(mock(InferenceRunner.class)).execute(requestIterator, listener); + + assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), empty()))); + } + + public void testInferenceRunnerAlwaysFails() throws Exception { + List requests = randomInferenceRequestList(between(1, 10_000)); + + InferenceRunner inferenceRunner = mock(invocation -> { + runWithRandomDelay(() -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("inference failure")); + }); + return null; + }); + + AtomicReference exception = new AtomicReference<>(); + ActionListener> listener = ActionListener.wrap(r -> fail("Expected exception"), exception::set); + + bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + + assertBusy(() -> { + assertThat(exception.get(), notNullValue()); + assertThat(exception.get().getMessage(), equalTo("inference failure")); + }); + } + + public void testInferenceRunnerSometimesFails() throws Exception { + List requests = 
randomInferenceRequestList(between(1, 10_000)); + + InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { + ActionListener listener = invocation.getArgument(1); + runWithRandomDelay(() -> { + if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) { + listener.onFailure(new RuntimeException("inference failure")); + } else { + listener.onResponse(mockInferenceResponse()); + } + }); + + return null; + }); + + AtomicReference exception = new AtomicReference<>(); + ActionListener> listener = ActionListener.wrap(r -> fail("Expected exception"), exception::set); + + bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + + assertBusy(() -> { + assertThat(exception.get(), notNullValue()); + assertThat(exception.get().getMessage(), equalTo("inference failure")); + }); + } + + private BulkInferenceExecutor bulkExecutor(InferenceRunner inferenceRunner) { + return new BulkInferenceExecutor(inferenceRunner, threadPool, randomBulkExecutionConfig()); + } + + private InferenceAction.Request mockInferenceRequest() { + return mock(InferenceAction.Request.class); + } + + private InferenceAction.Response mockInferenceResponse() { + InferenceAction.Response response = mock(InferenceAction.Response.class); + when(response.getResults()).thenReturn(mock(RankedDocsResults.class)); + return response; + } + + private BulkInferenceExecutionConfig randomBulkExecutionConfig() { + return new BulkInferenceExecutionConfig(between(1, 100), between(1, 100)); + } + + private BulkInferenceRequestIterator requestIterator(List requests) { + final Iterator delegate = requests.iterator(); + BulkInferenceRequestIterator iterator = mock(BulkInferenceRequestIterator.class); + doAnswer(i -> delegate.hasNext()).when(iterator).hasNext(); + doAnswer(i -> delegate.next()).when(iterator).next(); + doAnswer(i -> requests.size()).when(iterator).estimatedSize(); + return iterator; + } + + private List 
randomInferenceRequestList(int size) { + List requests = new ArrayList<>(size); + while (requests.size() < size) { + requests.add(this.mockInferenceRequest()); + } + return requests; + + } + + private List randomInferenceResponseList(int size) { + List response = new ArrayList<>(size); + while (response.size() < size) { + response.add(mock(InferenceAction.Response.class)); + } + return response; + } + + private InferenceRunner mockInferenceRunner(Answer doInferenceAnswer) { + InferenceRunner inferenceRunner = mock(InferenceRunner.class); + doAnswer(doInferenceAnswer).when(inferenceRunner).doInference(any(), any()); + return inferenceRunner; + } + + private void runWithRandomDelay(Runnable runnable) { + if (randomBoolean()) { + runnable.run(); + } else { + threadPool.schedule( + runnable, + TimeValue.timeValueNanos(between(1, 1_000)), + threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME) + ); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java new file mode 100644 index 0000000000000..77d50ca5ee981 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.completion; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.compute.test.RandomBlock; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; + +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class CompletionOperatorOutputBuilderTests extends ComputeTestCase { + + public void testBuildSmallOutput() { + assertBuildOutput(between(1, 100)); + } + + public void testBuildLargeOutput() { + assertBuildOutput(between(10_000, 100_000)); + } + + private void assertBuildOutput(int size) { + final Page inputPage = randomInputPage(size, between(1, 20)); + try ( + CompletionOperatorOutputBuilder outputBuilder = new CompletionOperatorOutputBuilder( + blockFactory().newBytesRefBlockBuilder(size), + inputPage + ) + ) { + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) { + List results = List.of(new ChatCompletionResults.Result("Completion result #" + currentPos)); + outputBuilder.addInferenceResponse(new InferenceAction.Response(new ChatCompletionResults(results))); + } + + final Page outputPage = outputBuilder.buildOutput(); + assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1)); + assertOutputContent(outputPage.getBlock(outputPage.getBlockCount() - 1)); + + outputPage.releaseBlocks(); + + } finally { + inputPage.releaseBlocks(); + } + + } + + private void assertOutputContent(BytesRefBlock block) { + BytesRef scratch = new BytesRef(); + + for (int currentPos = 0; currentPos < 
block.getPositionCount(); currentPos++) { + assertThat(block.isNull(currentPos), equalTo(false)); + scratch = block.getBytesRef(block.getFirstValueIndex(currentPos), scratch); + assertThat(scratch.utf8ToString(), equalTo("Completion result #" + currentPos)); + } + } + + private Page randomInputPage(int positionCount, int columnCount) { + final Block[] blocks = new Block[columnCount]; + try { + for (int i = 0; i < columnCount; i++) { + blocks[i] = RandomBlock.randomBlock( + blockFactory(), + RandomBlock.randomElementType(), + positionCount, + randomBoolean(), + 0, + 0, + randomInt(10), + randomInt(10) + ).block(); + } + + return new Page(blocks); + } catch (Exception e) { + Releasables.close(blocks); + throw (e); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java new file mode 100644 index 0000000000000..5c0253e508553 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.completion; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import static org.hamcrest.Matchers.equalTo; + +public class CompletionOperatorRequestIteratorTests extends ComputeTestCase { + + public void testIterateSmallInput() { + assertIterate(between(1, 100)); + } + + public void testIterateLargeInput() { + assertIterate(between(10_000, 100_000)); + } + + private void assertIterate(int size) { + final String inferenceId = randomIdentifier(); + + try ( + BytesRefBlock inputBlock = randomInputBlock(size); + CompletionOperatorRequestIterator requestIterator = new CompletionOperatorRequestIterator(inputBlock, inferenceId) + ) { + BytesRef scratch = new BytesRef(); + + for (int currentPos = 0; requestIterator.hasNext(); currentPos++) { + InferenceAction.Request request = requestIterator.next(); + assertThat(request.getInferenceEntityId(), equalTo(inferenceId)); + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch); + assertThat(request.getInput().getFirst(), equalTo(scratch.utf8ToString())); + } + } + } + + private BytesRefBlock randomInputBlock(int size) { + try (BytesRefBlock.Builder blockBuilder = blockFactory().newBytesRefBlockBuilder(size)) { + for (int i = 0; i < size; i++) { + blockBuilder.appendBytesRef(new BytesRef(randomAlphaOfLength(10))); + } + + return blockBuilder.build(); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java new file mode 100644 index 0000000000000..add8155240ad1 --- /dev/null +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.completion; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase; +import org.hamcrest.Matcher; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class CompletionOperatorTests extends InferenceOperatorTestCase { + private static final String SIMPLE_INFERENCE_ID = "test_completion"; + + @Override + protected Operator.OperatorFactory simple(SimpleOptions options) { + return new CompletionOperator.Factory(mockedSimpleInferenceRunner(), SIMPLE_INFERENCE_ID, evaluatorFactory(0)); + } + + @Override + protected void assertSimpleOutput(List input, List results) { + assertThat(results, hasSize(input.size())); + + for (int curPage = 0; curPage < input.size(); curPage++) { + Page inputPage = input.get(curPage); + Page resultPage = results.get(curPage); + + assertEquals(inputPage.getPositionCount(), resultPage.getPositionCount()); + assertEquals(inputPage.getBlockCount() + 1, resultPage.getBlockCount()); + + for (int channel = 0; channel < inputPage.getBlockCount(); channel++) { + Block inputBlock = 
inputPage.getBlock(channel); + Block resultBlock = resultPage.getBlock(channel); + assertBlockContentEquals(inputBlock, resultBlock); + } + + assertCompletionResults(inputPage, resultPage); + } + } + + private void assertCompletionResults(Page inputPage, Page resultPage) { + BytesRefBlock inputBlock = resultPage.getBlock(0); + BytesRefBlock resultBlock = resultPage.getBlock(inputPage.getBlockCount()); + + BytesRef scratch = new BytesRef(); + StringBuilder inputBuilder = new StringBuilder(); + + for (int curPos = 0; curPos < inputPage.getPositionCount(); curPos++) { + inputBuilder.setLength(0); + int valueIndex = inputBlock.getFirstValueIndex(curPos); + while (valueIndex < inputBlock.getFirstValueIndex(curPos) + inputBlock.getValueCount(curPos)) { + scratch = inputBlock.getBytesRef(valueIndex, scratch); + inputBuilder.append(scratch.utf8ToString()); + if (valueIndex < inputBlock.getValueCount(curPos) - 1) { + inputBuilder.append("\n"); + } + valueIndex++; + } + scratch = resultBlock.getBytesRef(resultBlock.getFirstValueIndex(curPos), scratch); + + assertThat(scratch.utf8ToString(), equalTo(inputBuilder.toString().toUpperCase(Locale.ROOT))); + } + } + + @Override + protected ChatCompletionResults mockInferenceResult(InferenceAction.Request request) { + List results = new ArrayList<>(); + for (String input : request.getInput()) { + results.add(new ChatCompletionResults.Result(input.toUpperCase(Locale.ROOT))); + } + return new ChatCompletionResults(results); + } + + @Override + protected Matcher expectedDescriptionOfSimple() { + return expectedToStringOfSimple(); + } + + @Override + protected Matcher expectedToStringOfSimple() { + return equalTo("CompletionOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "]]"); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilderTests.java new 
file mode 100644 index 0000000000000..7117ccc19005e --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilderTests.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.rerank; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.compute.test.RandomBlock; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; + +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class RerankOperatorOutputBuilderTests extends ComputeTestCase { + + public void testBuildSmallOutput() { + assertBuildOutput(between(1, 100)); + } + + public void testBuildLargeOutput() { + assertBuildOutput(between(10_000, 100_000)); + } + + private void assertBuildOutput(int size) { + final Page inputPage = randomInputPage(size, between(1, 20)); + final int scoreChannel = randomIntBetween(0, inputPage.getBlockCount()); + try ( + RerankOperatorOutputBuilder outputBuilder = new RerankOperatorOutputBuilder( + blockFactory().newDoubleBlockBuilder(size), + inputPage, + scoreChannel + ) + ) { + int batchSize = randomIntBetween(1, size); + for (int currentPos = 0; currentPos < inputPage.getPositionCount();) { + List rankedDocs = new ArrayList<>(); + for (int rankedDocIndex = 0; rankedDocIndex < batchSize && currentPos < inputPage.getPositionCount(); 
rankedDocIndex++) { + rankedDocs.add(new RankedDocsResults.RankedDoc(rankedDocIndex, relevanceScore(currentPos), randomIdentifier())); + currentPos++; + } + + outputBuilder.addInferenceResponse(new InferenceAction.Response(new RankedDocsResults(rankedDocs))); + } + + final Page outputPage = outputBuilder.buildOutput(); + try { + assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + LogManager.getLogger(RerankOperatorOutputBuilderTests.class) + .info( + "{} , {}, {}, {}", + scoreChannel, + inputPage.getBlockCount(), + outputPage.getBlockCount(), + Math.max(scoreChannel + 1, inputPage.getBlockCount()) + ); + assertThat(outputPage.getBlockCount(), equalTo(Integer.max(scoreChannel + 1, inputPage.getBlockCount()))); + assertOutputContent(outputPage.getBlock(scoreChannel)); + } finally { + outputPage.releaseBlocks(); + } + + } finally { + inputPage.releaseBlocks(); + } + } + + private float relevanceScore(int position) { + return (float) 1 / (1 + position); + } + + private void assertOutputContent(DoubleBlock block) { + for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) { + assertThat(block.isNull(currentPos), equalTo(false)); + assertThat(block.getValueCount(currentPos), equalTo(1)); + assertThat(block.getDouble(block.getFirstValueIndex(currentPos)), equalTo((double) relevanceScore(currentPos))); + } + } + + private Page randomInputPage(int positionCount, int columnCount) { + final Block[] blocks = new Block[columnCount]; + try { + for (int i = 0; i < columnCount; i++) { + blocks[i] = RandomBlock.randomBlock( + blockFactory(), + RandomBlock.randomElementType(), + positionCount, + randomBoolean(), + 0, + 0, + randomInt(10), + randomInt(10) + ).block(); + } + + return new Page(blocks); + } catch (Exception e) { + Releasables.close(blocks); + throw (e); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java new file mode 100644 index 0000000000000..133bfeaaf02ad --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.rerank; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class RerankOperatorRequestIteratorTests extends ComputeTestCase { + + public void testIterateSmallInput() { + assertIterate(between(1, 100), randomIntBetween(1, 1_000)); + } + + public void testIterateLargeInput() { + assertIterate(between(10_000, 100_000), randomIntBetween(1, 1_000)); + } + + private void assertIterate(int size, int batchSize) { + final String inferenceId = randomIdentifier(); + final String queryText = randomIdentifier(); + + try ( + BytesRefBlock inputBlock = randomInputBlock(size); + RerankOperatorRequestIterator requestIterator = new RerankOperatorRequestIterator(inputBlock, inferenceId, queryText, batchSize) + ) { + BytesRef scratch = new BytesRef(); + + for (int currentPos = 0; requestIterator.hasNext();) { + InferenceAction.Request request = requestIterator.next(); + + assertThat(request.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request.getQuery(), equalTo(queryText)); + List inputs = request.getInput(); + for (String input : inputs) { + scratch = 
inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch); + assertThat(input, equalTo(scratch.utf8ToString())); + currentPos++; + } + } + } + } + + private BytesRefBlock randomInputBlock(int size) { + try (BytesRefBlock.Builder blockBuilder = blockFactory().newBytesRefBlockBuilder(size)) { + for (int i = 0; i < size; i++) { + blockBuilder.appendBytesRef(new BytesRef(randomAlphaOfLength(10))); + } + + return blockBuilder.build(); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorTests.java new file mode 100644 index 0000000000000..f5dc1b3c05fd9 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.rerank; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase; +import org.hamcrest.Matcher; + +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class RerankOperatorTests extends InferenceOperatorTestCase { + + private static final String SIMPLE_INFERENCE_ID = "test_reranker"; + private static final String SIMPLE_QUERY = "query text"; + + @Override + protected Operator.OperatorFactory simple(SimpleOptions options) { + return new RerankOperator.Factory(mockedSimpleInferenceRunner(), SIMPLE_INFERENCE_ID, SIMPLE_QUERY, evaluatorFactory(0), 1); + } + + @Override + protected void assertSimpleOutput(List inputPages, List resultPages) { + assertThat(inputPages, hasSize(resultPages.size())); + + for (int pageId = 0; pageId < inputPages.size(); pageId++) { + Page inputPage = inputPages.get(pageId); + Page resultPage = resultPages.get(pageId); + + assertThat(resultPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(resultPage.getBlockCount(), equalTo(Integer.max(2, inputPage.getBlockCount()))); + + for (int channel = 0; channel < inputPage.getBlockCount(); channel++) { + Block inputBlock = inputPage.getBlock(channel); + Block resultBlock = resultPage.getBlock(channel); + + assertThat(resultBlock.getPositionCount(), equalTo(resultPage.getPositionCount())); + assertThat(resultBlock.elementType(), equalTo(inputBlock.elementType())); + + if (channel != 1) { + 
assertBlockContentEquals(inputBlock, resultBlock); + } + + if (channel == 0) { + assertExpectedScore((BytesRefBlock) inputBlock, resultPage.getBlock(1)); + } + } + } + } + + private void assertExpectedScore(BytesRefBlock inputBlock, DoubleBlock scoreBlock) { + assertThat(scoreBlock.getPositionCount(), equalTo(inputBlock.getPositionCount())); + for (int pos = 0; pos < inputBlock.getPositionCount(); pos++) { + double score = scoreBlock.getDouble(scoreBlock.getFirstValueIndex(pos)); + double expectedScore = score(pos); + assertThat(score, equalTo(expectedScore)); + } + } + + @Override + protected Matcher expectedDescriptionOfSimple() { + return expectedToStringOfSimple(); + } + + @Override + protected Matcher expectedToStringOfSimple() { + return equalTo( + "RerankOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "], query=[" + SIMPLE_QUERY + "], score_channel=[" + 1 + "]]" + ); + } + + @Override + protected RankedDocsResults mockInferenceResult(InferenceAction.Request request) { + List rankedDocs = new ArrayList<>(); + for (int rank = 0; rank < request.getInput().size(); rank++) { + rankedDocs.add(new RankedDocsResults.RankedDoc(rank, score(rank), request.getInput().get(rank))); + } + + return new RankedDocsResults(rankedDocs); + } + + private float score(int rank) { + return 1f / (rank % 20); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 4de3c9f31d38e..d4231863e31ae 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -32,7 +32,7 @@ public static void init() { public void 
testGetServicesWithoutTaskType() throws IOException { List services = getAllServices(); - assertThat(services.size(), equalTo(23)); + assertThat(services.size(), equalTo(24)); var providers = providers(services); @@ -57,6 +57,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "mistral", "openai", "streaming_completion_test_service", + "completion_test_service", "test_reranking_service", "test_service", "text_embedding_test_service", @@ -134,7 +135,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(13)); + assertThat(services.size(), equalTo(14)); var providers = providers(services); @@ -153,6 +154,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "googleaistudio", "openai", "streaming_completion_test_service", + "completion_test_service", "hugging_face", "amazon_sagemaker" ).toArray() diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java new file mode 100644 index 0000000000000..9c15ac77cc13f --- /dev/null +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java @@ -0,0 +1,233 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.mock; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +public class TestCompletionServiceExtension implements InferenceServiceExtension { + @Override + public List getInferenceServiceFactories() { + return List.of(TestInferenceService::new); + } + + public static class TestInferenceService extends AbstractTestInferenceService { + private static final String NAME = 
"completion_test_service"; + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION); + + public TestInferenceService(InferenceServiceFactoryContext context) {} + + @Override + public String name() { + return NAME; + } + + @Override + protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { + return TestServiceSettings.fromMap(serviceSettingsMap); + } + + @Override + @SuppressWarnings("unchecked") + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); + var serviceSettings = TestSparseInferenceServiceExtension.TestServiceSettings.fromMap(serviceSettingsMap); + var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap); + + var taskSettingsMap = getTaskSettingsMap(config); + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings)); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public void infer( + Model model, + String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + List input, + boolean stream, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + switch (model.getConfigurations().getTaskType()) { + case COMPLETION -> listener.onResponse(makeChatCompletionResults(input)); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + } + + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue 
timeout, + ActionListener listener + ) { + listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + + @Override + public void chunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + + private InferenceServiceResults makeChatCompletionResults(List inputs) { + List results = new ArrayList<>(); + for (String text : inputs) { + results.add(new ChatCompletionResults.Result(text.toUpperCase(Locale.ROOT))); + } + + return new ChatCompletionResults(results); + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + "model_id", + new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION)).setDescription("") + .setLabel("Model ID") + .setRequired(true) + .setSensitive(true) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } + } + + public record TestServiceSettings(String modelId) implements ServiceSettings { + public static final String NAME = "completion_test_service_settings"; + + public TestServiceSettings(StreamInput in) throws IOException { + this(in.readString()); + } + + public static TestServiceSettings fromMap(Map map) { + var modelId = map.remove("model").toString(); + + if (modelId == null) { + ValidationException 
validationException = new ValidationException(); + validationException.addValidationError("missing model id"); + throw validationException; + } + + return new TestServiceSettings(modelId); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId()); + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field("model", modelId()).endObject(); + } + } +} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java index 1d04aab022f91..4cfc7e388a911 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java @@ -54,6 +54,11 @@ public List getNamedWriteables() { ServiceSettings.class, TestStreamingCompletionServiceExtension.TestServiceSettings.NAME, TestStreamingCompletionServiceExtension.TestServiceSettings::new + ), + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + TestCompletionServiceExtension.TestServiceSettings.NAME, + TestCompletionServiceExtension.TestServiceSettings::new ) ); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension index c996a33d1e916..a481e2a4a0451 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension @@ -2,3 +2,4 @@ org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension org.elasticsearch.xpack.inference.mock.TestStreamingCompletionServiceExtension +org.elasticsearch.xpack.inference.mock.TestCompletionServiceExtension