diff --git a/docs/changelog/125239.yaml b/docs/changelog/125239.yaml new file mode 100644 index 0000000000000..60ec9bb0b7177 --- /dev/null +++ b/docs/changelog/125239.yaml @@ -0,0 +1,6 @@ +pr: 125239 +summary: Adding common rerank options to Perform Inference API +area: Machine Learning +type: enhancement +issues: + - 111273 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 0005cf11f7265..4e46fd33e16b7 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -154,6 +154,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL_8_19 = def(8_841_0_12); public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13); public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14); + public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -200,6 +201,7 @@ static TransportVersion def(int id) { public static final TransportVersion INDEXING_STATS_INCLUDES_RECENT_WRITE_LOAD = def(9_034_0_00); public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(9_035_0_00); public static final TransportVersion INDEX_METADATA_INCLUDES_RECENT_WRITE_LOAD = def(9_036_0_00); + public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED = def(9_037_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index de1925cb641e9..309db20083ece 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -91,18 +91,22 @@ default boolean hideFromConfigurationApi() { /** * Perform inference on the model. * - * @param model The model - * @param query Inference query, mainly for re-ranking - * @param input Inference input - * @param stream Stream inference results - * @param taskSettings Settings in the request to override the model's defaults - * @param inputType For search, ingest etc - * @param timeout The timeout for the request - * @param listener Inference result listener + * @param model The model + * @param query Inference query, mainly for re-ranking + * @param returnDocuments For re-ranking task type, whether to return documents + * @param topN For re-ranking task type, how many docs to return + * @param input Inference input + * @param stream Stream inference results + * @param taskSettings Settings in the request to override the model's defaults + * @param inputType For search, ingest etc + * @param timeout The timeout for the request + * @param listener Inference result listener */ void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index b6d9689086dc4..e9ccb1baeb8bc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -60,6 +60,8 @@ public static class Request extends BaseInferenceActionRequest { public static final ParseField INPUT_TYPE = new ParseField("input_type"); public static final ParseField TASK_SETTINGS = new ParseField("task_settings"); public static final ParseField QUERY = new ParseField("query"); + public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents"); + public static final ParseField TOP_N = new ParseField("top_n"); public static final ParseField TIMEOUT = new ParseField("timeout"); static final ObjectParser PARSER = new ObjectParser<>(NAME, Request.Builder::new); @@ -68,6 +70,8 @@ public static class Request extends BaseInferenceActionRequest { PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE); PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS); PARSER.declareString(Request.Builder::setQuery, QUERY); + PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS); + PARSER.declareInt(Request.Builder::setTopN, TOP_N); PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT); } @@ -89,6 +93,8 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType, private final TaskType taskType; private final String inferenceEntityId; private final String query; + private final Boolean returnDocuments; + private final Integer topN; private final List input; private final Map taskSettings; private final InputType inputType; @@ -99,6 +105,8 @@ public Request( TaskType taskType, String inferenceEntityId, String query, + Boolean returnDocuments, + Integer topN, List input, Map taskSettings, InputType inputType, @@ -109,6 +117,8 @@ public Request( taskType, inferenceEntityId, query, + returnDocuments, + topN, input, taskSettings, inputType, @@ -122,6 +132,8 @@ public Request( TaskType taskType, String inferenceEntityId, String query, + Boolean 
returnDocuments, + Integer topN, List input, Map taskSettings, InputType inputType, @@ -133,6 +145,8 @@ public Request( this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; this.query = query; + this.returnDocuments = returnDocuments; + this.topN = topN; this.input = input; this.taskSettings = taskSettings; this.inputType = inputType; @@ -164,6 +178,15 @@ public Request(StreamInput in) throws IOException { this.inferenceTimeout = DEFAULT_TIMEOUT; } + if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED) + || in.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) { + this.returnDocuments = in.readOptionalBoolean(); + this.topN = in.readOptionalInt(); + } else { + this.returnDocuments = null; + this.topN = null; + } + // streaming is not supported yet for transport traffic this.stream = false; } @@ -184,6 +207,14 @@ public String getQuery() { return query; } + public Boolean getReturnDocuments() { + return returnDocuments; + } + + public Integer getTopN() { + return topN; + } + public Map getTaskSettings() { return taskSettings; } @@ -225,6 +256,17 @@ public ActionRequestValidationException validate() { e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK)); return e; } + } else if (taskType.equals(TaskType.ANY) == false) { + if (returnDocuments != null) { + var e = new ActionRequestValidationException(); + e.addValidationError(format("Field [return_documents] cannot be specified for task type [%s]", taskType)); + return e; + } + if (topN != null) { + var e = new ActionRequestValidationException(); + e.addValidationError(format("Field [top_n] cannot be specified for task type [%s]", taskType)); + return e; + } } if (taskType.equals(TaskType.TEXT_EMBEDDING) == false @@ -258,6 +300,12 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(query); out.writeTimeValue(inferenceTimeout); } + + if 
(out.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED) + || out.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) { + out.writeOptionalBoolean(returnDocuments); + out.writeOptionalInt(topN); + } } // default for easier testing @@ -283,6 +331,8 @@ public boolean equals(Object o) { && taskType == request.taskType && Objects.equals(inferenceEntityId, request.inferenceEntityId) && Objects.equals(query, request.query) + && Objects.equals(returnDocuments, request.returnDocuments) + && Objects.equals(topN, request.topN) && Objects.equals(input, request.input) && Objects.equals(taskSettings, request.taskSettings) && inputType == request.inputType @@ -296,6 +346,8 @@ public int hashCode() { taskType, inferenceEntityId, query, + returnDocuments, + topN, input, taskSettings, inputType, @@ -312,6 +364,8 @@ public static class Builder { private InputType inputType = InputType.UNSPECIFIED; private Map taskSettings = Map.of(); private String query; + private Boolean returnDocuments; + private Integer topN; private TimeValue timeout = DEFAULT_TIMEOUT; private boolean stream = false; private InferenceContext context; @@ -338,6 +392,16 @@ public Builder setQuery(String query) { return this; } + public Builder setReturnDocuments(Boolean returnDocuments) { + this.returnDocuments = returnDocuments; + return this; + } + + public Builder setTopN(Integer topN) { + this.topN = topN; + return this; + } + public Builder setInputType(InputType inputType) { this.inputType = inputType; return this; @@ -373,7 +437,19 @@ public Builder setContext(InferenceContext context) { } public Request build() { - return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context); + return new Request( + taskType, + inferenceEntityId, + query, + returnDocuments, + topN, + input, + taskSettings, + inputType, + timeout, + stream, + context + ); } } @@ -384,6 +460,10 @@ public String toString() { + 
this.getInferenceEntityId() + ", query=" + this.getQuery() + + ", returnDocuments=" + + this.getReturnDocuments() + + ", topN=" + + this.getTopN() + ", input=" + this.getInput() + ", taskSettings=" diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index 024205b365a71..2e2b9bf9b0d23 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -44,6 +44,8 @@ protected InferenceAction.Request createTestInstance() { randomFrom(TaskType.values()), randomAlphaOfLength(6), randomAlphaOfLengthOrNull(10), + randomBoolean(), + randomIntBetween(0, 10), randomList(1, 5, () -> randomAlphaOfLength(8)), randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), @@ -85,6 +87,8 @@ public void testValidation_TextEmbedding() { TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of("input"), null, null, @@ -100,6 +104,8 @@ public void testValidation_Rerank() { TaskType.RERANK, "model", "query", + Boolean.TRUE, + 34, List.of("input"), null, null, @@ -119,6 +125,8 @@ public void testValidation_TextEmbedding_Null() { null, null, null, + null, + null, false ); ActionRequestValidationException inputNullError = inputNullRequest.validate(); @@ -131,6 +139,8 @@ public void testValidation_TextEmbedding_Empty() { TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of(), null, null, @@ -142,11 +152,52 @@ public void testValidation_TextEmbedding_Empty() { assertThat(inputEmptyError.getMessage(), is("Validation Failed: 1: Field [input] cannot be an empty array;")); } + public void testValidation_TextEmbedding_WithReturnDocument() { + InferenceAction.Request 
inputRequest = new InferenceAction.Request( + TaskType.TEXT_EMBEDDING, + "model", + null, + Boolean.TRUE, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException inputError = inputRequest.validate(); + assertNotNull(inputError); + assertThat( + inputError.getMessage(), + is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [text_embedding];") + ); + } + + public void testValidation_TextEmbedding_WithTopN() { + InferenceAction.Request inputRequest = new InferenceAction.Request( + TaskType.TEXT_EMBEDDING, + "model", + null, + null, + 12, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException inputError = inputRequest.validate(); + assertNotNull(inputError); + assertThat(inputError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [text_embedding];")); + } + public void testValidation_Rerank_Null() { InferenceAction.Request queryNullRequest = new InferenceAction.Request( TaskType.RERANK, "model", null, + null, + null, List.of("input"), null, null, @@ -163,6 +214,8 @@ public void testValidation_Rerank_Empty() { TaskType.RERANK, "model", "", + null, + null, List.of("input"), null, null, @@ -179,6 +232,8 @@ public void testValidation_Rerank_WithInputType() { TaskType.RERANK, "model", "query", + null, + null, List.of("input"), null, InputType.SEARCH, @@ -195,6 +250,8 @@ public void testValidation_SparseEmbedding_WithInputType() { TaskType.SPARSE_EMBEDDING, "model", "", + null, + null, List.of("input"), null, InputType.SEARCH, @@ -209,11 +266,56 @@ public void testValidation_SparseEmbedding_WithInputType() { ); } + public void testValidation_SparseEmbedding_WithReturnDocument() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, + "model", + "", + Boolean.FALSE, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = 
queryRequest.validate(); + assertNotNull(queryError); + assertThat( + queryError.getMessage(), + is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [sparse_embedding];") + ); + + } + + public void testValidation_SparseEmbedding_WithTopN() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, + "model", + "", + null, + 22, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat( + queryError.getMessage(), + is("Validation Failed: 1: Field [top_n] cannot be specified for task type [sparse_embedding];") + ); + } + public void testValidation_Completion_WithInputType() { InferenceAction.Request queryRequest = new InferenceAction.Request( TaskType.COMPLETION, "model", "", + null, + null, List.of("input"), null, InputType.SEARCH, @@ -225,11 +327,52 @@ public void testValidation_Completion_WithInputType() { assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];")); } + public void testValidation_Completion_WithReturnDocuments() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.COMPLETION, + "model", + "", + Boolean.TRUE, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat( + queryError.getMessage(), + is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [completion];") + ); + } + + public void testValidation_Completion_WithTopN() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.COMPLETION, + "model", + "", + null, + 77, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + 
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [completion];")); + } + public void testValidation_ChatCompletion_WithInputType() { InferenceAction.Request queryRequest = new InferenceAction.Request( TaskType.CHAT_COMPLETION, "model", "", + null, + null, List.of("input"), null, InputType.SEARCH, @@ -244,6 +387,45 @@ public void testValidation_ChatCompletion_WithInputType() { ); } + public void testValidation_ChatCompletion_WithReturnDocuments() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.CHAT_COMPLETION, + "model", + "", + Boolean.TRUE, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat( + queryError.getMessage(), + is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [chat_completion];") + ); + } + + public void testValidation_ChatCompletion_WithTopN() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.CHAT_COMPLETION, + "model", + "", + null, + 11, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [chat_completion];")); + } + public void testParseRequest_DefaultsInputTypeToIngest() throws IOException { String singleInputRequest = """ { @@ -271,6 +453,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc nextTask, instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -283,6 +467,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskType(), 
instance.getInferenceEntityId() + "foo", instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -297,6 +483,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), changedInputs, instance.getTaskSettings(), instance.getInputType(), @@ -317,6 +505,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), taskSettings, instance.getInputType(), @@ -331,6 +521,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), nextInputType, @@ -343,6 +535,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery() == null ? 
randomAlphaOfLength(10) : instance.getQuery() + randomAlphaOfLength(1), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -360,6 +554,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -374,6 +570,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -395,6 +593,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskType(), instance.getInferenceEntityId(), null, + null, + null, instance.getInput().subList(0, 1), instance.getTaskSettings(), InputType.UNSPECIFIED, @@ -406,6 +606,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskType(), instance.getInferenceEntityId(), null, + null, + null, instance.getInput(), instance.getTaskSettings(), InputType.UNSPECIFIED, @@ -420,6 +622,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskType(), instance.getInferenceEntityId(), null, + null, + null, instance.getInput(), instance.getTaskSettings(), InputType.INGEST, @@ -432,6 +636,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskType(), instance.getInferenceEntityId(), null, + null, + null, instance.getInput(), instance.getTaskSettings(), InputType.UNSPECIFIED, @@ -443,6 +649,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskType(), instance.getInferenceEntityId(), null, + 
null, + null, instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -455,6 +663,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + null, + null, instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -462,9 +672,24 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque false, InferenceContext.EMPTY_INSTANCE ); - } else { - mutated = instance; - } + } else if (version.before(TransportVersions.RERANK_COMMON_OPTIONS_ADDED) + && version.isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19) == false) { + mutated = new InferenceAction.Request( + instance.getTaskType(), + instance.getInferenceEntityId(), + instance.getQuery(), + null, + null, + instance.getInput(), + instance.getTaskSettings(), + instance.getInputType(), + instance.getInferenceTimeout(), + false, + instance.getContext() + ); + } else { + mutated = instance; + } // We always assume that a request has been rerouted, if it came from a node without adaptive rate limiting if (version.before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { @@ -481,6 +706,8 @@ public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOExceptio TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of(), Map.of(), InputType.UNSPECIFIED, @@ -503,6 +730,8 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of(), Map.of(), InputType.INGEST, @@ -525,6 +754,8 @@ public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeen TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of("input"), Map.of(), InputType.UNSPECIFIED, @@ -548,6 +779,8 @@ public void testWriteTo_WhenVersionIsBeforeInferenceContext_ShouldSetContextToEm TaskType.TEXT_EMBEDDING, "model", null, + null, + null, 
List.of("input"), Map.of(), InputType.UNSPECIFIED, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 51ae8b5437b44..ad6f1b88de328 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -110,6 +110,8 @@ public EnumSet supportedTaskTypes() { public void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 8be2317a9ee6f..d4e3642affddb 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -102,6 +102,8 @@ public EnumSet supportedTaskTypes() { public void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index b860bb85ebd0e..6f533d83884ea 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -103,6 +103,8 @@ public EnumSet supportedTaskTypes() { public void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 8c876e9947bba..6bcec22bb50b3 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -103,6 +104,8 @@ public EnumSet supportedTaskTypes() { public void infer( Model model, String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index e8f52e42f5708..7d24b7766baa3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -77,6 +77,8 @@ protected void doInference( service.infer( model, request.getQuery(), + request.getReturnDocuments(), + request.getTopN(), request.getInput(), request.isStreaming(), request.getTaskSettings(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java index e87c7f6eb014a..ceac02e3985bb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java @@ -75,7 +75,13 @@ public ExecutableAction create(VoyageAIRerankModel model, Map ta serviceComponents.threadPool(), overriddenModel, RERANK_HANDLER, - (rerankInput) -> new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model), + (rerankInput) -> new VoyageAIRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN(), + model + ), QueryAndDocsInputs.class ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java index 446db40aa5ae5..b50b1e3fbad87 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java @@ -69,6 +69,8 @@ public void execute( account, rerankInput.getQuery(), rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN(), model ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java index d27812b17399b..4d379a5c8fee0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java @@ -49,7 +49,13 @@ public void execute( ActionListener listener ) { var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - CohereRerankRequest request = new CohereRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + CohereRerankRequest request = new CohereRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN(), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java index e74f0049fffb0..f499917a8c93d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java @@ -62,7 +62,13 @@ public void execute( ActionListener listener ) { var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN(), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java index 26f134873bca0..4fc49eaf442ec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java @@ -49,7 +49,13 @@ public void execute( ActionListener listener ) { var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - JinaAIRerankRequest request = new JinaAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + JinaAIRerankRequest request = new JinaAIRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + 
rerankInput.getTopN(), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 5af5245ac5b40..d755ac982ac31 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.external.http.sender; +import org.elasticsearch.core.Nullable; + import java.util.List; import java.util.Objects; @@ -22,15 +24,25 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { private final String query; private final List chunks; + private final Boolean returnDocuments; + private final Integer topN; public QueryAndDocsInputs(String query, List chunks) { - this(query, chunks, false); + this(query, chunks, null, null, false); } - public QueryAndDocsInputs(String query, List chunks, boolean stream) { + public QueryAndDocsInputs( + String query, + List chunks, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + boolean stream + ) { super(stream); this.query = Objects.requireNonNull(query); this.chunks = Objects.requireNonNull(chunks); + this.returnDocuments = returnDocuments; + this.topN = topN; } public String getQuery() { @@ -41,6 +53,14 @@ public List getChunks() { return chunks; } + public Boolean getReturnDocuments() { + return returnDocuments; + } + + public Integer getTopN() { + return topN; + } + public int inputSize() { return chunks.size(); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java index 878bcc6e6a0db..5e392725b9f49 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java @@ -12,6 +12,7 @@ import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; @@ -32,6 +33,8 @@ public class AlibabaCloudSearchRerankRequest implements Request { private final AlibabaCloudSearchAccount account; private final String query; private final List input; + private final Boolean returnDocuments; + private final Integer topN; private final URI uri; private final AlibabaCloudSearchRerankTaskSettings taskSettings; private final String model; @@ -44,6 +47,8 @@ public AlibabaCloudSearchRerankRequest( AlibabaCloudSearchAccount account, String query, List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, AlibabaCloudSearchRerankModel rerankModel ) { Objects.requireNonNull(rerankModel); @@ -51,6 +56,8 @@ public AlibabaCloudSearchRerankRequest( this.account = Objects.requireNonNull(account); this.query = Objects.requireNonNull(query); this.input = Objects.requireNonNull(input); + this.returnDocuments = returnDocuments; + this.topN = topN; taskSettings = rerankModel.getTaskSettings(); model = 
rerankModel.getServiceSettings().getCommonSettings().modelId(); host = rerankModel.getServiceSettings().getCommonSettings().getHost(); @@ -67,7 +74,8 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(uri); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, taskSettings)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, returnDocuments, topN, taskSettings)) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java index 054e373e3e525..a5731f29d93e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings; @@ -15,9 +16,13 @@ import java.util.List; import java.util.Objects; -public record AlibabaCloudSearchRerankRequestEntity(String query, List input, AlibabaCloudSearchRerankTaskSettings taskSettings) - implements - ToXContentObject { +public record AlibabaCloudSearchRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + 
AlibabaCloudSearchRerankTaskSettings taskSettings +) implements ToXContentObject { private static final String SEARCH_QUERY = "query"; private static final String TEXTS_FIELD = "docs"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java index 4ec04c0187329..283ed759884ce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java @@ -11,6 +11,7 @@ import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xpack.inference.external.cohere.CohereAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -28,16 +29,26 @@ public class CohereRerankRequest extends CohereRequest { private final CohereAccount account; private final String query; private final List input; + private final Boolean returnDocuments; + private final Integer topN; private final CohereRerankTaskSettings taskSettings; private final String model; private final String inferenceEntityId; - public CohereRerankRequest(String query, List input, CohereRerankModel model) { + public CohereRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankModel model + ) { Objects.requireNonNull(model); this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri); this.input = Objects.requireNonNull(input); this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; 
taskSettings = model.getTaskSettings(); this.model = model.getServiceSettings().modelId(); inferenceEntityId = model.getInferenceEntityId(); @@ -48,7 +59,8 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model)) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java index e7abe0990eb0c..085aa0a14316e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.cohere; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; @@ -15,9 +16,14 @@ import java.util.List; import java.util.Objects; -public record CohereRerankRequestEntity(String model, String query, List documents, CohereRerankTaskSettings taskSettings) - implements - ToXContentObject { +public record CohereRerankRequestEntity( + String model, + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankTaskSettings taskSettings +) implements ToXContentObject { private static final String DOCUMENTS_FIELD = 
"documents"; private static final String QUERY_FIELD = "query"; @@ -29,8 +35,15 @@ public record CohereRerankRequestEntity(String model, String query, List Objects.requireNonNull(taskSettings); } - public CohereRerankRequestEntity(String query, List input, CohereRerankTaskSettings taskSettings, String model) { - this(model, query, input, taskSettings); + public CohereRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankTaskSettings taskSettings, + String model + ) { + this(model, query, input, returnDocuments, topN, taskSettings); } @Override @@ -41,11 +54,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(QUERY_FIELD, query); builder.field(DOCUMENTS_FIELD, documents); - if (taskSettings.getDoesReturnDocuments() != null) { + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments); + } else if (taskSettings.getDoesReturnDocuments() != null) { builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments()); } - if (taskSettings.getTopNDocumentsOnly() != null) { + // prefer the root level top_n over task settings + if (topN != null) { + builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, topN); + } else if (taskSettings.getTopNDocumentsOnly() != null) { builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java index 79606c63e0ed6..9004c061423c1 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java @@ -11,6 +11,7 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -29,10 +30,22 @@ public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest { private final List input; - public GoogleVertexAiRerankRequest(String query, List input, GoogleVertexAiRerankModel model) { + private final Boolean returnDocuments; + + private final Integer topN; + + public GoogleVertexAiRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + GoogleVertexAiRerankModel model + ) { this.model = Objects.requireNonNull(model); this.query = Objects.requireNonNull(query); this.input = Objects.requireNonNull(input); + this.returnDocuments = returnDocuments; + this.topN = topN; } @Override @@ -41,7 +54,13 @@ public HttpRequest createHttpRequest() { ByteArrayEntity byteEntity = new ByteArrayEntity( Strings.toString( - new GoogleVertexAiRerankRequestEntity(query, input, model.getServiceSettings().modelId(), model.getTaskSettings().topN()) + new GoogleVertexAiRerankRequestEntity( + query, + input, + returnDocuments, + topN != null ? 
topN : model.getTaskSettings().topN(), + model.getServiceSettings().modelId() + ) ).getBytes(StandardCharsets.UTF_8) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java index 2cac067f622cc..13f6b1da9fc86 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java @@ -15,9 +15,13 @@ import java.util.List; import java.util.Objects; -public record GoogleVertexAiRerankRequestEntity(String query, List inputs, @Nullable String model, @Nullable Integer topN) - implements - ToXContentObject { +public record GoogleVertexAiRerankRequestEntity( + String query, + List inputs, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + @Nullable String model +) implements ToXContentObject { private static final String MODEL_FIELD = "model"; private static final String QUERY_FIELD = "query"; @@ -26,6 +30,7 @@ public record GoogleVertexAiRerankRequestEntity(String query, List input private static final String CONTENT_FIELD = "content"; private static final String TOP_N_FIELD = "topN"; + private static final String IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD = "ignoreRecordDetailsInResponse"; public GoogleVertexAiRerankRequestEntity { Objects.requireNonNull(query); @@ -57,10 +62,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endArray(); + // prefer the root level top_n over task settings if (topN != null) { builder.field(TOP_N_FIELD, topN); } + if (returnDocuments != null) { + // if returnDocuments = true, we do not want to ignore record details 
+ builder.field(IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD, returnDocuments == Boolean.TRUE ? Boolean.FALSE : Boolean.TRUE); + } + builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java index 93d4ab830c604..8994a23f42726 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java @@ -11,6 +11,7 @@ import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -28,16 +29,26 @@ public class JinaAIRerankRequest extends JinaAIRequest { private final JinaAIAccount account; private final String query; private final List input; + private final Boolean returnDocuments; + private final Integer topN; private final JinaAIRerankTaskSettings taskSettings; private final String model; private final String inferenceEntityId; - public JinaAIRerankRequest(String query, List input, JinaAIRerankModel model) { + public JinaAIRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + JinaAIRerankModel model + ) { Objects.requireNonNull(model); this.account = JinaAIAccount.of(model, JinaAIRerankRequest::buildDefaultUri); this.input = Objects.requireNonNull(input); this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; taskSettings = 
model.getTaskSettings(); this.model = model.getServiceSettings().modelId(); inferenceEntityId = model.getInferenceEntityId(); @@ -48,7 +59,8 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new JinaAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new JinaAIRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model)) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java index 7f470d5fa91f5..1a770026f9d2d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.jinaai; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; @@ -15,9 +16,14 @@ import java.util.List; import java.util.Objects; -public record JinaAIRerankRequestEntity(String model, String query, List documents, JinaAIRerankTaskSettings taskSettings) - implements - ToXContentObject { +public record JinaAIRerankRequestEntity( + String model, + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + JinaAIRerankTaskSettings taskSettings +) implements ToXContentObject { private static final String DOCUMENTS_FIELD = "documents"; private 
static final String QUERY_FIELD = "query"; @@ -30,8 +36,15 @@ public record JinaAIRerankRequestEntity(String model, String query, List Objects.requireNonNull(taskSettings); } - public JinaAIRerankRequestEntity(String query, List input, JinaAIRerankTaskSettings taskSettings, String model) { - this(model, query, input, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS); + public JinaAIRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + JinaAIRerankTaskSettings taskSettings, + String model + ) { + this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS); } @Override @@ -42,13 +55,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(QUERY_FIELD, query); builder.field(DOCUMENTS_FIELD, documents); - if (taskSettings.getTopNDocumentsOnly() != null) { + // prefer the root level top_n over task settings + if (topN != null) { + builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, topN); + } else if (taskSettings.getTopNDocumentsOnly() != null) { builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly()); } - var return_documents = taskSettings.getDoesReturnDocuments(); - if (return_documents != null) { - builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, return_documents); + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments); + } else if (taskSettings.getDoesReturnDocuments() != null) { + builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments()); } builder.endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java index 9fb50720e4d55..9b0b4268fc703 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java @@ -10,6 +10,7 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; @@ -23,13 +24,23 @@ public class VoyageAIRerankRequest extends VoyageAIRequest { private final String query; private final List input; + private final Boolean returnDocuments; + private final Integer topN; private final VoyageAIRerankModel model; - public VoyageAIRerankRequest(String query, List input, VoyageAIRerankModel model) { + public VoyageAIRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + VoyageAIRerankModel model + ) { this.model = Objects.requireNonNull(model); this.input = Objects.requireNonNull(input); this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; } @Override @@ -37,8 +48,16 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(model.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new VoyageAIRerankRequestEntity(query, input, model.getTaskSettings(), model.getServiceSettings().modelId())) - .getBytes(StandardCharsets.UTF_8) + Strings.toString( + new VoyageAIRerankRequestEntity( + query, + input, + returnDocuments, + topN, + model.getTaskSettings(), + 
model.getServiceSettings().modelId() + ) + ).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java index 0f7baaa35044e..a52013f5d6f07 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.voyageai; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; @@ -15,15 +16,19 @@ import java.util.List; import java.util.Objects; -public record VoyageAIRerankRequestEntity(String model, String query, List documents, VoyageAIRerankTaskSettings taskSettings) - implements - ToXContentObject { +public record VoyageAIRerankRequestEntity( + String model, + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + VoyageAIRerankTaskSettings taskSettings +) implements ToXContentObject { private static final String DOCUMENTS_FIELD = "documents"; private static final String QUERY_FIELD = "query"; private static final String MODEL_FIELD = "model"; public static final String TRUNCATION_FIELD = "truncation"; - public static final String RETURN_DOCUMENTS_FIELD = "return_documents"; public VoyageAIRerankRequestEntity { Objects.requireNonNull(query); @@ -32,8 +37,15 @@ public record VoyageAIRerankRequestEntity(String model, String query, List input, VoyageAIRerankTaskSettings taskSettings, 
String model) { - this(model, query, input, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS); + public VoyageAIRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + VoyageAIRerankTaskSettings taskSettings, + String model + ) { + this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS); } @Override @@ -44,11 +56,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(QUERY_FIELD, query); builder.field(DOCUMENTS_FIELD, documents); - if (taskSettings.getDoesReturnDocuments() != null) { + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments); + } else if (taskSettings.getDoesReturnDocuments() != null) { builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments()); } - if (taskSettings.getTopKDocumentsOnly() != null) { + // prefer the root level top_n over task settings + if (topN != null) { + builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, topN); + } else if (taskSettings.getTopKDocumentsOnly() != null) { builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java index 78673277797d2..f3c08c60c53b1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java @@ -103,10 +103,6 @@ private static List doParse(XContentParser parser) return parseList(parser, (listParser, index) -> { var parsedRankedDoc = RankedDoc.parse(parser); - if (parsedRankedDoc.content == null) { - throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.CONTENT.getPreferredName())); - } - if (parsedRankedDoc.score == null) { throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName())); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 96dbd3948cdc5..182c083ef1c26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -232,6 +232,8 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu TaskType.ANY, inferenceId, null, + null, + null, List.of(query), Map.of(), InputType.INTERNAL_SEARCH, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index dae79f0105811..7f245ae854eac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -153,6 
+153,8 @@ protected InferenceAction.Request generateRequest(List docFeatures) { TaskType.RERANK, inferenceId, inferenceText, + null, + null, docFeatures, Map.of(), InputType.INTERNAL_SEARCH, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 6258e47517432..ddde0699ec6c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -60,6 +60,8 @@ protected ServiceComponents getServiceComponents() { public void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, @@ -68,7 +70,7 @@ public void infer( ActionListener listener ) { init(); - var inferenceInput = createInput(this, model, input, inputType, query, stream); + var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); doInfer(model, inferenceInput, taskSettings, timeout, listener); } @@ -78,11 +80,20 @@ private static InferenceInputs createInput( List input, InputType inputType, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, boolean stream ) { return switch (model.getTaskType()) { case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); - case RERANK -> new QueryAndDocsInputs(query, input, stream); + case RERANK -> { + ValidationException validationException = new ValidationException(); + service.validateRerankParameters(returnDocuments, topN, validationException); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream); + } case TEXT_EMBEDDING, SPARSE_EMBEDDING -> { 
ValidationException validationException = new ValidationException(); service.validateInputType(inputType, model, validationException); @@ -141,6 +152,8 @@ protected abstract void doInfer( protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException); + protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {} + protected abstract void doUnifiedCompletionInfer( Model model, UnifiedChatInput inputs, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 1ca63908ec5f1..a0c77599b6ce6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -735,6 +735,8 @@ public static void getEmbeddingSize(Model model, InferenceService service, Actio service.infer( model, null, + null, + null, List.of(TEST_EMBEDDING_INPUT), false, Map.of(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index fe844bbe0c1a3..bf1fbda2b826b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import 
org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; @@ -300,6 +301,24 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException); } + @Override + protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) { + if (returnDocuments != null) { + validationException.addValidationError( + Strings.format( + "Invalid return_documents [%s]. The return_documents option is not supported by this service", + returnDocuments + ) + ); + } + + if (topN != null) { + validationException.addValidationError( + Strings.format("Invalid top_n [%s]. The top_n option is not supported by this service", topN) + ); + } + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 1962315325562..cbf203ee4a68b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -620,6 +620,8 @@ public void unifiedCompletionInfer( public void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, @@ -632,7 +634,7 @@ public void infer( if (TaskType.TEXT_EMBEDDING.equals(taskType)) { inferTextEmbedding(esModel, input, inputType, timeout, listener); } else if (TaskType.RERANK.equals(taskType)) { - inferRerank(esModel, query, input, 
inputType, timeout, taskSettings, listener); + inferRerank(esModel, query, input, returnDocuments, topN, inputType, timeout, taskSettings, listener); } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { inferSparseEmbedding(esModel, input, inputType, timeout, listener); } else { @@ -693,6 +695,8 @@ public void inferRerank( ElasticsearchInternalModel model, String query, List inputs, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, InputType inputType, TimeValue timeout, Map requestTaskSettings, @@ -701,7 +705,9 @@ public void inferRerank( var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout); var returnDocs = Boolean.TRUE; - if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) { + if (returnDocuments != null) { + returnDocs = returnDocuments; + } else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) { var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings); returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments(); } @@ -709,7 +715,9 @@ public void inferRerank( Function inputSupplier = returnDocs == Boolean.TRUE ? 
inputs::get : i -> null; ActionListener mlResultsListener = listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)) + (l, inferenceResult) -> l.onResponse( + textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN) + ) ); var maybeDeployListener = mlResultsListener.delegateResponse( @@ -824,7 +832,8 @@ public String name() { private RankedDocsResults textSimilarityResultsToRankedDocs( List results, - Function inputSupplier + Function inputSupplier, + @Nullable Integer topN ) { List rankings = new ArrayList<>(results.size()); for (int i = 0; i < results.size(); i++) { @@ -851,7 +860,8 @@ private RankedDocsResults textSimilarityResultsToRankedDocs( } Collections.sort(rankings); - return new RankedDocsResults(rankings); + /* top_n may exceed the number of ranked docs; clamp it because List.subList throws when toIndex > size() */ + return new RankedDocsResults(topN != null ? rankings.subList(0, Math.min(topN, rankings.size())) : rankings); } public List defaultConfigIds() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java index 08cb9933c2b3c..4c48e3018b956 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java @@ -30,6 +30,8 @@ public void validate(InferenceService service, Model model, ActionListener { - listenerAction.accept(ans.getArgument(7)); + listenerAction.accept(ans.getArgument(9)); return null; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); doAnswer(ans -> { 
listenerAction.accept(ans.getArgument(3)); return null; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java index 12d67ae3dc960..e91b0b3451a77 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -23,7 +23,7 @@ public void testCastToSucceeds() { var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null); assertThat(new UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class)); assertThat( - new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class), + new QueryAndDocsInputs("hello", List.of(), Boolean.TRUE, 33, false).castTo(QueryAndDocsInputs.class), Matchers.instanceOf(QueryAndDocsInputs.class) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java index 8f981d64d36eb..0d48e7692b2e8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java @@ -22,7 +22,13 @@ public class AlibabaCloudSearchRerankRequestEntityTests extends ESTestCase { public void testXContent_WritesAllFields_WhenTheyAreDefined() throws 
IOException { - var entity = new AlibabaCloudSearchRerankRequestEntity("query", List.of("abc"), new AlibabaCloudSearchRerankTaskSettings()); + var entity = new AlibabaCloudSearchRerankRequestEntity( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + new AlibabaCloudSearchRerankTaskSettings() + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntityTests.java new file mode 100644 index 0000000000000..c33d72d6bd746 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntityTests.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.cohere; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class CohereRerankRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new CohereRerankRequestEntity( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + new CohereRerankTaskSettings(null, null, 3), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}""")); + } + + public void testXContent_WritesMinimalFields() throws IOException { + var entity = new CohereRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new CohereRerankTaskSettings(null, null, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"model":"model","query":"query","documents":["abc"]}""")); + } + + public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException { + var entity = new CohereRerankRequestEntity( + "query", + List.of("abc"), + Boolean.FALSE, + 99, + new CohereRerankTaskSettings(33, Boolean.TRUE, null), + "model" + ); + + XContentBuilder 
builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}""")); + } + + public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOException { + var entity = new CohereRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new CohereRerankTaskSettings(33, Boolean.TRUE, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java index fd18d2573efcc..764aedfc5a190 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java @@ -20,8 +20,8 @@ import static org.hamcrest.MatcherAssert.assertThat; public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase { - public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException { - var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), "model", 8); + public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException { + var entity = new 
GoogleVertexAiRerankRequestEntity("query", List.of("abc"), Boolean.TRUE, 10, "model"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -37,13 +37,14 @@ public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOEx "content": "abc" } ], - "topN": 8 + "topN": 10, + "ignoreRecordDetailsInResponse": false } """)); } - public void testXContent_SingleRequest_DoesNotWriteModelAndTopNIfNull() throws IOException { - var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null); + public void testXContent_SingleRequest_WritesMinimalFields() throws IOException { + var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null, null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -62,8 +63,8 @@ public void testXContent_SingleRequest_DoesNotWriteModelAndTopNIfNull() throws I """)); } - public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException { - var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), "model", 8); + public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException { + var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), Boolean.FALSE, 12, "model"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -83,13 +84,14 @@ public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws I "content": "def" } ], - "topN": 8 + "topN": 12, + "ignoreRecordDetailsInResponse": true } """)); } - public void testXContent_MultipleRequests_DoesNotWriteModelAndTopNIfNull() throws IOException { - var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null); + public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException { + var entity = new 
GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null, null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -111,5 +113,4 @@ public void testXContent_MultipleRequests_DoesNotWriteModelAndTopNIfNull() throw } """)); } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java index 811adb6612a4e..20aa270c08086 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java @@ -29,11 +29,11 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase { private static final String AUTH_HEADER_VALUE = "foo"; - public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException { + public void testCreateRequest_WithMinimalFieldsSet() throws IOException { var input = "input"; var query = "query"; - var request = createRequest(query, input, null, null); + var request = createRequest(query, input, null, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -53,8 +53,9 @@ public void testCreateRequest_WithTopNSet() throws IOException { var input = "input"; var query = "query"; var topN = 1; + var taskSettingsTopN = 3; - var request = createRequest(query, input, null, topN); + var request = createRequest(query, input, null, topN, null, taskSettingsTopN); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -71,12 +72,55 @@ public void 
testCreateRequest_WithTopNSet() throws IOException { assertThat(requestMap.get("topN"), is(topN)); } + public void testCreateRequest_UsesTaskSettingsTopNWhenRootLevelIsNull() throws IOException { + var input = "input"; + var query = "query"; + var topN = 1; + + var request = createRequest(query, input, null, null, null, topN); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("records"), is(List.of(Map.of("id", "0", "content", input)))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("topN"), is(topN)); + } + + public void testCreateRequest_WithReturnDocumentsSet() throws IOException { + var input = "input"; + var query = "query"; + + var request = createRequest(query, input, null, null, Boolean.TRUE, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("records"), is(List.of(Map.of("id", "0", "content", input)))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("ignoreRecordDetailsInResponse"), is(Boolean.FALSE)); + } + public void 
testCreateRequest_WithModelSet() throws IOException { var input = "input"; var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -94,24 +138,37 @@ public void testCreateRequest_WithModelSet() throws IOException { } public void testTruncate_DoesNotTruncate() { - var request = createRequest("query", "input", null, null); + var request = createRequest("query", "input", null, null, null, null); var truncatedRequest = request.truncate(); assertThat(truncatedRequest, sameInstance(request)); } - private static GoogleVertexAiRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) { - var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, topN); - - return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel); + private static GoogleVertexAiRerankRequest createRequest( + String query, + String input, + @Nullable String modelId, + @Nullable Integer topN, + @Nullable Boolean returnDocuments, + @Nullable Integer taskSettingsTopN + ) { + var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, taskSettingsTopN); + + return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel, topN, returnDocuments); } /** * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest} */ private static class GoogleVertexAiRerankWithoutAuthRequest extends GoogleVertexAiRerankRequest { - GoogleVertexAiRerankWithoutAuthRequest(String query, List input, GoogleVertexAiRerankModel model) { - super(query, input, model); + GoogleVertexAiRerankWithoutAuthRequest( + String query, + List input, + GoogleVertexAiRerankModel model, + @Nullable Integer topN, + @Nullable Boolean returnDocuments + ) { + 
super(query, input, returnDocuments, topN, model); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java index 7fd738fa2a8e4..11f2810e13e01 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java @@ -21,8 +21,15 @@ import static org.hamcrest.MatcherAssert.assertThat; public class JinaAIRerankRequestEntityTests extends ESTestCase { - public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, null), "model"); + public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException { + var entity = new JinaAIRerankRequestEntity( + "query", + List.of("abc"), + Boolean.TRUE, + 12, + new JinaAIRerankTaskSettings(8, Boolean.FALSE), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -35,13 +42,14 @@ public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOEx "documents": [ "abc" ], - "top_n": 8 + "top_n": 12, + "return_documents": true } """)); } - public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsTrue() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, true), "model"); + public void testXContent_SingleRequest_WritesMinimalFields() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, null, new JinaAIRerankTaskSettings(null, null), 
"model"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -53,15 +61,20 @@ public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumen "query": "query", "documents": [ "abc" - ], - "top_n": 8, - "return_documents": true + ] } """)); } - public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsFalse() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, false), "model"); + public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException { + var entity = new JinaAIRerankRequestEntity( + "query", + List.of("abc", "def"), + Boolean.FALSE, + 12, + new JinaAIRerankTaskSettings(8, Boolean.TRUE), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -72,16 +85,17 @@ public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumen "model": "model", "query": "query", "documents": [ - "abc" + "abc", + "def" ], - "top_n": 8, + "top_n": 12, "return_documents": false } """)); } - public void testXContent_SingleRequest_DoesNotWriteTopNIfNull() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, "model"); + public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -92,14 +106,22 @@ public void testXContent_SingleRequest_DoesNotWriteTopNIfNull() throws IOExcepti "model": "model", "query": "query", "documents": [ - "abc" + "abc", + "def" ] } """)); } - public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", 
List.of("abc", "def"), new JinaAIRerankTaskSettings(8, null), "model"); + public void testXContent_SingleRequest_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException { + var entity = new JinaAIRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new JinaAIRerankTaskSettings(8, Boolean.FALSE), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -110,29 +132,10 @@ public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws I "model": "model", "query": "query", "documents": [ - "abc", - "def" + "abc" ], - "top_n": 8 - } - """)); - } - - public void testXContent_MultipleRequests_DoesNotWriteTopNIfNull() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, "model"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - { - "model": "model", - "query": "query", - "documents": [ - "abc", - "def" - ] + "top_n": 8, + "return_documents": false } """)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java index 819362d397ba5..439bcf3ae006a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java @@ -27,12 +27,12 @@ public class JinaAIRerankRequestTests extends ESTestCase { private static final String API_KEY = "foo"; - public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws 
IOException { + public void testCreateRequest_WithMinimalFieldsSet() throws IOException { var input = "input"; var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -49,13 +49,14 @@ public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOExce assertThat(requestMap.get("model"), is(modelId)); } - public void testCreateRequest_WithTopNSet() throws IOException { + public void testCreateRequest_WithAllFieldsSet() throws IOException { var input = "input"; var query = "query"; var topN = 1; + var taskSettingsTopN = 2; var modelId = "model"; - var request = createRequest(query, input, modelId, topN); + var request = createRequest(query, input, modelId, topN, Boolean.FALSE, taskSettingsTopN); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -66,10 +67,11 @@ public void testCreateRequest_WithTopNSet() throws IOException { var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap, aMapWithSize(5)); assertThat(requestMap.get("documents"), is(List.of(input))); assertThat(requestMap.get("query"), is(query)); assertThat(requestMap.get("top_n"), is(topN)); + assertThat(requestMap.get("return_documents"), is(Boolean.FALSE)); assertThat(requestMap.get("model"), is(modelId)); } @@ -78,7 +80,7 @@ public void testCreateRequest_WithModelSet() throws IOException { var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -96,15 +98,22 @@ public void 
testCreateRequest_WithModelSet() throws IOException { } public void testTruncate_DoesNotTruncate() { - var request = createRequest("query", "input", "null", null); + var request = createRequest("query", "input", "null", null, null, null); var truncatedRequest = request.truncate(); assertThat(truncatedRequest, sameInstance(request)); } - private static JinaAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) { - var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, topN); - return new JinaAIRerankRequest(query, List.of(input), rerankModel); + private static JinaAIRerankRequest createRequest( + String query, + String input, + @Nullable String modelId, + @Nullable Integer topN, + @Nullable Boolean returnDocuments, + @Nullable Integer taskSettingsTopN + ) { + var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopN); + return new JinaAIRerankRequest(query, List.of(input), returnDocuments, topN, rerankModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java index ae431b4b7bb13..f05e9052861f8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java @@ -20,8 +20,15 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; public class VoyageAIRerankRequestEntityTests extends ESTestCase { - public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new 
VoyageAIRerankTaskSettings(8, null, null), "model"); + public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc"), + Boolean.TRUE, + 12, + new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -34,13 +41,21 @@ public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOEx "documents": [ "abc" ], - "top_k": 8 + "return_documents": true, + "top_k": 12 } """)); } - public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsTrue() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, true, null), "model"); + public void testXContent_SingleRequest_WritesMinimalFields() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new VoyageAIRerankTaskSettings(null, true, null), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -53,14 +68,20 @@ public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumen "documents": [ "abc" ], - "return_documents": true, - "top_k": 8 + "return_documents": true } """)); } - public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsFalse() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, null), "model"); + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new VoyageAIRerankTaskSettings(8, false, true), + "model" + ); XContentBuilder builder = 
XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -74,13 +95,21 @@ public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumen "abc" ], "return_documents": false, - "top_k": 8 + "top_k": 8, + "truncation": true } """)); } - public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model"); + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new VoyageAIRerankTaskSettings(8, false, false), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -95,13 +124,20 @@ public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTru ], "return_documents": false, "top_k": 8, - "truncation": true + "truncation": false } """)); } - public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model"); + public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc", "def"), + Boolean.FALSE, + 11, + new VoyageAIRerankTaskSettings(8, null, null), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -112,17 +148,17 @@ public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFal "model": "model", "query": "query", "documents": [ - "abc" + "abc", + "def" ], "return_documents": false, - "top_k": 8, - "truncation": false + "top_k": 11 } """)); } - public void 
testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model"); + public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -133,17 +169,20 @@ public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOExcepti "model": "model", "query": "query", "documents": [ - "abc" + "abc", + "def" ] } """)); } - public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws IOException { + public void testXContent_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException { var entity = new VoyageAIRerankRequestEntity( "query", - List.of("abc", "def"), - new VoyageAIRerankTaskSettings(8, null, null), + List.of("abc"), + null, + null, + new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null), "model" ); @@ -156,31 +195,12 @@ public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws I "model": "model", "query": "query", "documents": [ - "abc", - "def" + "abc" ], + "return_documents": false, "top_k": 8 } """)); } - public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, "model"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - { - "model": "model", - "query": "query", - "documents": [ - "abc", - "def" - ] - } - """)); - } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java index a11d259200b98..00237496304d0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java @@ -27,12 +27,12 @@ public class VoyageAIRerankRequestTests extends ESTestCase { private static final String API_KEY = "foo"; - public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException { + public void testCreateRequest_WithMinimalFields() throws IOException { var input = "input"; var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -49,13 +49,14 @@ public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOExce assertThat(requestMap.get("model"), is(modelId)); } - public void testCreateRequest_WithTopNSet() throws IOException { + public void testCreateRequest_WithAllFieldsDefined() throws IOException { var input = "input"; var query = "query"; var topK = 1; + var taskSettingsTopK = 2; var modelId = "model"; - var request = createRequest(query, input, modelId, topK); + var request = createRequest(query, input, modelId, topK, Boolean.FALSE, taskSettingsTopK); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -66,11 +67,12 @@ public void testCreateRequest_WithTopNSet() throws IOException { var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap, aMapWithSize(5)); assertThat(requestMap.get("documents"), 
is(List.of(input))); assertThat(requestMap.get("query"), is(query)); assertThat(requestMap.get("top_k"), is(topK)); assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("return_documents"), is(Boolean.FALSE)); } public void testCreateRequest_WithModelSet() throws IOException { @@ -78,7 +80,7 @@ public void testCreateRequest_WithModelSet() throws IOException { var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -96,15 +98,22 @@ public void testCreateRequest_WithModelSet() throws IOException { } public void testTruncate_DoesNotTruncate() { - var request = createRequest("query", "input", "null", null); + var request = createRequest("query", "input", "null", null, null, null); var truncatedRequest = request.truncate(); assertThat(truncatedRequest, sameInstance(request)); } - private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) { - var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK); - return new VoyageAIRerankRequest(query, List.of(input), rerankModel); + private static VoyageAIRerankRequest createRequest( + String query, + String input, + @Nullable String modelId, + @Nullable Integer topK, + @Nullable Boolean returnDocuments, + @Nullable Integer taskSettingsTopK + ) { + var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopK); + return new VoyageAIRerankRequest(query, List.of(input), returnDocuments, topK, rerankModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java index 7ff79e2618425..eba6887fe5c4e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java @@ -42,6 +42,26 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2")))); } + public void testFromResponse_CreatesResultsForASingleItem_NoContent() throws IOException { + String responseJson = """ + { + "records": [ + { + "id": "2", + "title": "title 2", + "score": 0.97 + } + ] + } + """; + + RankedDocsResults parsedResults = GoogleVertexAiRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null)))); + } + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { String responseJson = """ { @@ -72,40 +92,38 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException ); } - public void testFromResponse_FailsWhenRecordsFieldIsNotPresent() { + public void testFromResponse_CreatesResultsForMultipleItems_NoContent() throws IOException { String responseJson = """ { - "not_records": [ + "records": [ { "id": "2", "title": "title 2", - "content": "content 2", "score": 0.97 }, { "id": "1", "title": "title 1", - "content": "content 1", "score": 0.90 } ] } """; - var thrownException = expectThrows( - IllegalStateException.class, - () -> GoogleVertexAiRerankResponseEntity.fromResponse( - new 
HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + RankedDocsResults parsedResults = GoogleVertexAiRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(thrownException.getMessage(), is("Failed to find required field [records] in Google Vertex AI rerank response")); + assertThat( + parsedResults.getRankedDocs(), + is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null), new RankedDocsResults.RankedDoc(1, 0.90F, null))) + ); } - public void testFromResponse_FailsWhenContentFieldIsNotPresent() { + public void testFromResponse_FailsWhenRecordsFieldIsNotPresent() { String responseJson = """ { - "records": [ + "not_records": [ { "id": "2", "title": "title 2", @@ -113,10 +131,10 @@ public void testFromResponse_FailsWhenContentFieldIsNotPresent() { "score": 0.97 }, { - "id": "1", - "title": "title 1", - "not_content": "content 1", - "score": 0.97 + "id": "1", + "title": "title 1", + "content": "content 1", + "score": 0.90 } ] } @@ -129,7 +147,7 @@ public void testFromResponse_FailsWhenContentFieldIsNotPresent() { ) ); - assertThat(thrownException.getMessage(), is("Failed to find required field [content] in Google Vertex AI rerank response")); + assertThat(thrownException.getMessage(), is("Failed to find required field [records] in Google Vertex AI rerank response")); } public void testFromResponse_FailsWhenScoreFieldIsNotPresent() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index c4b3c07839923..ae6e5fb5a53a4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -98,6 +98,8 @@ protected InferenceAction.Request generateRequest(List docFeatures) { TaskType.RERANK, this.inferenceId, inferenceText, + null, + null, docFeatures, Map.of("inferenceResultCount", inferenceResultCount), InputType.INTERNAL_SEARCH, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 0d821f411d0b5..dc0e2cc10501d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -225,6 +225,8 @@ protected InferenceAction.Request generateRequest(List docFeatures) { TaskType.RERANK, inferenceId, inferenceText, + null, + null, docFeatures, Map.of("throwing", true), InputType.INTERNAL_SEARCH, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index dfc64b8fb9324..190520fbc3b68 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -910,11 +910,11 @@ public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty() when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); + ActionListener listener = invocation.getArgument(9); listener.onResponse(new TextEmbeddingFloatResults(List.of())); 
return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -932,11 +932,11 @@ public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingByteResults_IsEmp when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); + ActionListener listener = invocation.getArgument(9); listener.onResponse(new TextEmbeddingByteResults(List.of())); return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -956,11 +956,11 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingResults() { var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); + ActionListener listener = invocation.getArgument(9); listener.onResponse(textEmbedding); return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -979,11 +979,11 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() { var textEmbedding = TextEmbeddingByteResultsTests.createRandomResults(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); + ActionListener listener = invocation.getArgument(9); listener.onResponse(textEmbedding); return Void.TYPE; - 
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 2e0f64ed4ef9f..bcf2fb85ae9d8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -389,6 +389,8 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType_TextEmbedding() t () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -431,6 +433,8 @@ public void testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -446,6 +450,53 @@ public void testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi } } + public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + Map serviceSettingsMap = new HashMap<>(); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); + serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); + + Map taskSettingsMap = new HashMap<>(); + + Map secretSettingsMap = new 
HashMap<>(); + secretSettingsMap.put("api_key", "secret"); + + var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( + "service", + TaskType.RERANK, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap + ); + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + var thrownException = expectThrows( + ValidationException.class, + () -> service.infer( + model, + "hi", + Boolean.TRUE, + 10, + List.of("a"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ) + ); + assertThat( + thrownException.getMessage(), + is( + "Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this " + + "service;2: Invalid top_n [10]. The top_n option is not supported by this service;" + ) + ); + } + } + public void testChunkedInfer_TextEmbeddingChunkingSettingsSet() throws IOException { testChunkedInfer(TaskType.TEXT_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 1b9bba3fa1b01..0ec3799cb7dc3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -932,6 +932,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -979,6 +981,8 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForProviderTh () -> service.infer( model, null, + null, + null, 
List.of(""), false, new HashMap<>(), @@ -1029,6 +1033,8 @@ public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1071,6 +1077,8 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1414,6 +1422,8 @@ public void testInfer_UnauthorizedResponse() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 405aba35e8d3a..71e35aa211d3c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -458,6 +458,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -513,6 +515,8 @@ public void testInfer_SendsCompletionRequest() throws IOException { service.infer( model, null, + null, + null, List.of("input"), false, new HashMap<>(), @@ -571,6 +575,8 @@ private InferenceEventsAssertion streamChatCompletion() throws Exception { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index e20fc54598aab..00cfa2a53f8b3 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -1096,6 +1096,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1134,6 +1136,8 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept () -> service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1296,6 +1300,8 @@ public void testInfer_WithChatCompletionModel() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1347,6 +1353,8 @@ public void testInfer_UnauthorisedResponse() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1403,6 +1411,8 @@ private InferenceEventsAssertion streamChatCompletion() throws Exception { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 2e0c0d04fa9cf..ffda34f0e8fdd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -766,6 +766,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -822,6 +824,8 @@ public void testInfer_SendsRequest() throws IOException, 
URISyntaxException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1286,6 +1290,8 @@ public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxExcept service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1453,6 +1459,8 @@ private InferenceEventsAssertion streamChatCompletion() throws Exception { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index dec1052589c93..b17a8b29bce26 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -788,6 +788,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -856,6 +858,8 @@ public void testInfer_SendsRequest() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1147,6 +1151,8 @@ public void testInfer_UnauthorisedResponse() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1207,6 +1213,8 @@ public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsA service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1281,6 +1289,8 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs service.infer( model, null, + null, + null, List.of("abc"), false, CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null), @@ -1353,6 +1363,8 @@ public void 
testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1629,6 +1641,8 @@ private InferenceEventsAssertion streamChatCompletion() throws Exception { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index aa1313793274b..88a2fc76aadcf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -232,7 +232,7 @@ public void testDoInfer() throws Exception { try (var service = createService()) { var model = createModel(service, TaskType.COMPLETION); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + service.infer(model, null, null, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result, isA(ChatCompletionResults.class)); var completionResults = (ChatCompletionResults) result; @@ -255,7 +255,7 @@ public void testDoInferStream() throws Exception { try (var service = createService()) { var model = createModel(service, TaskType.COMPLETION); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + service.infer(model, null, null, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); 
InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent(""" {"completion":[{"delta":"hello, world"}]}"""); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index f9fb6521e979b..4f61269fcc6c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -368,6 +368,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -404,6 +406,8 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -443,6 +447,8 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -494,6 +500,8 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { service.infer( model, null, + null, + null, List.of("input text"), false, new HashMap<>(), @@ -551,6 +559,8 @@ public void testInfer_PropagatesProductUseCaseHeader() throws IOException { service.infer( model, null, + null, + null, List.of("input text"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 3a50e716ab160..a1430d36a0f5b 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -662,6 +662,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -700,6 +702,8 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -775,6 +779,8 @@ public void testInfer_SendsCompletionRequest() throws IOException { service.infer( model, null, + null, + null, List.of("input"), false, new HashMap<>(), @@ -832,6 +838,8 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { service.infer( model, null, + null, + null, List.of(input), false, new HashMap<>(), @@ -1005,6 +1013,8 @@ public void testInfer_ResourceNotFound() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index fe8172cf5db07..3be4b72c1237f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -65,6 +65,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 7c4b0de656c35..9575321494923 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -556,6 +556,8 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -593,6 +595,8 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { () -> service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -627,6 +631,8 @@ public void testInfer_SendsElserRequest() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 7a78dbce6310f..3f508c6cbca50 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -602,6 +602,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -641,6 +643,8 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), 
@@ -697,6 +701,8 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { service.infer( model, null, + null, + null, List.of(input), false, new HashMap<>(), @@ -840,6 +846,8 @@ public void testInfer_ResourceNotFound() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index fabcca09d3e31..e1446c36b893e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -782,6 +782,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1044,6 +1046,8 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1076,6 +1080,8 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2"), false, new HashMap<>(), @@ -1132,6 +1138,8 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1201,6 +1209,8 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1254,6 +1264,8 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1320,7 +1332,18 @@ 
public void testInfer_Embedding_Get_Response_NullInputType() throws IOException JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + service.infer( + model, + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -1371,6 +1394,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3"), false, new HashMap<>(), @@ -1454,6 +1479,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3", "candidate4"), false, new HashMap<>(), @@ -1549,6 +1576,8 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3"), false, new HashMap<>(), @@ -1630,6 +1659,8 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3", "candidate4"), false, new HashMap<>(), @@ -1724,6 +1755,8 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index a2ee005d719fc..db771f13cc0a6 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -586,6 +586,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -625,6 +627,8 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -781,6 +785,8 @@ public void testInfer_UnauthorisedResponse() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 219ff210cfa9d..70452db7c171a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -852,6 +852,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -890,6 +892,8 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -925,6 +929,8 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -964,6 +970,8 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws service.infer( mockModel, null, + null, + 
null, List.of(""), false, new HashMap<>(), @@ -1024,6 +1032,8 @@ public void testInfer_SendsRequest() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1263,6 +1273,8 @@ private InferenceEventsAssertion streamCompletion() throws Exception { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), @@ -1794,6 +1806,8 @@ public void testInfer_UnauthorisedResponse() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java index 44de4a3d9ccdd..9ee2201b4f02b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java @@ -63,6 +63,8 @@ public void testValidate_ServiceThrowsException() { .infer( eq(mockModel), eq(null), + eq(null), + eq(null), eq(TEST_INPUT), eq(false), eq(Map.of()), @@ -97,13 +99,15 @@ public void testValidate_SuccessfulCallToServiceForReRankTaskType() { private void mockSuccessfulCallToService(String query, InferenceServiceResults result) { doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(7); + ActionListener responseListener = ans.getArgument(9); responseListener.onResponse(result); return null; }).when(mockInferenceService) .infer( eq(mockModel), eq(query), + eq(null), + eq(null), eq(TEST_INPUT), eq(false), eq(Map.of()), @@ -120,6 +124,8 @@ private void verifyCallToService(boolean withQuery) { verify(mockInferenceService).infer( eq(mockModel), eq(withQuery ? 
TEST_QUERY : null), + eq(null), + eq(null), eq(TEST_INPUT), eq(false), eq(Map.of()), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index bddcc27194c47..521d042bb8615 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -722,6 +722,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -768,6 +770,8 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1017,6 +1021,8 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1049,6 +1055,8 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2"), false, new HashMap<>(), @@ -1103,6 +1111,8 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1183,6 +1193,8 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1260,7 +1272,18 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, List.of("abc"), false, new HashMap<>(), 
null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + service.infer( + model, + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -1315,6 +1338,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3"), false, new HashMap<>(), @@ -1401,6 +1426,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3", "candidate4"), false, new HashMap<>(), @@ -1493,6 +1520,8 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3"), false, new HashMap<>(), @@ -1569,6 +1598,8 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3", "candidate4"), false, new HashMap<>(), @@ -1663,6 +1694,8 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index 65b5d3b3110fd..a3ad596447725 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -123,6 +123,8 @@ private void doInferenceServiceModel(CoordinatedInferenceAction.Request request, TaskType.ANY, 
request.getModelId(), null, + null, + null, request.getInputs(), request.getTaskSettings(), inputType,