diff --git a/server/src/main/java/org/elasticsearch/inference/RerankingInferenceService.java b/server/src/main/java/org/elasticsearch/inference/RerankingInferenceService.java
new file mode 100644
index 0000000000000..d65a0970b1d92
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/inference/RerankingInferenceService.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.inference;
+
+/** Implemented by inference services that rerank; reports how much text the reranking model accepts. */
+public interface RerankingInferenceService {
+ /**
+ * Conservative default window size, in words, for small reranking models (512 input tokens).
+ */
+ int CONSERVATIVE_DEFAULT_WINDOW_SIZE = 300;
+
+ /**
+ * Returns the reranking model's maximum window, or an approximation of it,
+ * measured in the number of words.
+ * @param modelId The model ID
+ * @return Window size in words
+ */
+ int rerankerWindowSize(String modelId);
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetRerankerWindowSizeAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetRerankerWindowSizeAction.java
new file mode 100644
index 0000000000000..5035461f5f2a0
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetRerankerWindowSizeAction.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.action;
+
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/** Action type pairing a {@link Request} naming an inference entity with a {@link Response} carrying a window size. */
+public class GetRerankerWindowSizeAction extends ActionType {
+
+    public static final String NAME = "cluster:internal/xpack/inference/rerankwindowsize/get";
+    public static final GetRerankerWindowSizeAction INSTANCE = new GetRerankerWindowSizeAction();
+
+    public GetRerankerWindowSizeAction() {
+        super(NAME);
+    }
+
+    /** Request carrying the id of the inference entity being asked about. */
+    public static class Request extends ActionRequest {
+        private final String inferenceEntityId;
+
+        public Request(String inferenceEntityId) {
+            this.inferenceEntityId = inferenceEntityId;
+        }
+
+        /** Delegates to the parent constructor first, then reads the entity id string, mirroring {@link #writeTo}. */
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            inferenceEntityId = in.readString();
+        }
+
+        public String getInferenceEntityId() {
+            return inferenceEntityId;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(inferenceEntityId);
+        }
+
+        /** Performs no validation; every request is reported as valid by returning {@code null}. */
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            return obj != null && obj.getClass() == getClass() && Objects.equals(inferenceEntityId, ((Request) obj).inferenceEntityId);
+        }
+
+        @Override
+        public int hashCode() {
+            return inferenceEntityId == null ? 0 : inferenceEntityId.hashCode();
+        }
+    }
+
+    /** Response carrying the window size as a single int. */
+    public static class Response extends ActionResponse {
+        private final int windowSize;
+
+        public Response(int windowSize) {
+            this.windowSize = windowSize;
+        }
+
+        /** Reads the window size with {@code readVInt}, mirroring the {@code writeVInt} call in {@link #writeTo}. */
+        public Response(StreamInput in) throws IOException {
+            windowSize = in.readVInt();
+        }
+
+        public int getWindowSize() {
+            return windowSize;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeVInt(windowSize);
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            return obj != null && obj.getClass() == getClass() && windowSize == ((Response) obj).windowSize;
+        }
+
+        @Override
+        public int hashCode() {
+            return Integer.hashCode(windowSize);
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
index 1244548597003..c1cf64b9f2ae8 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
@@ -25,6 +25,7 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskSettings;
@@ -48,6 +49,8 @@
public class TestRerankingServiceExtension implements InferenceServiceExtension {
+ /** Fixed window size that {@code TestInferenceService#rerankerWindowSize} returns for every model id. */
+ public static final int RERANK_WINDOW_SIZE = 333;
+
@Override
public List getInferenceServiceFactories() {
return List.of(TestInferenceService::new);
@@ -62,7 +65,7 @@ public TestRerankingModel(String inferenceEntityId, TestServiceSettings serviceS
}
}
- public static class TestInferenceService extends AbstractTestInferenceService {
+ public static class TestInferenceService extends AbstractTestInferenceService implements RerankingInferenceService {
public static final String NAME = "test_reranking_service";
private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.RERANK);
@@ -200,6 +203,11 @@ protected ServiceSettings getServiceSettingsFromMap(Map serviceS
return TestServiceSettings.fromMap(serviceSettingsMap);
}
+ /** Ignores {@code modelId} and always returns {@code RERANK_WINDOW_SIZE}. */
+ @Override
+ public int rerankerWindowSize(String modelId) {
+ return RERANK_WINDOW_SIZE;
+ }
+
public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java
index 33b9adb431a0a..30b8a636b9ac6 100644
--- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java
@@ -61,8 +61,9 @@ public static Iterable
*/
-public abstract class AbstractInferenceServiceTests extends ESTestCase {
+public abstract class AbstractInferenceServiceTests extends InferenceServiceTestCase {
protected final MockWebServer webServer = new MockWebServer();
protected ThreadPool threadPool;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/InferenceServiceTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/InferenceServiceTestCase.java
new file mode 100644
index 0000000000000..b24535133107c
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/InferenceServiceTestCase.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services;
+
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.RerankingInferenceService;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+
+/** Base test case adding reranking-contract checks that run against the service from {@link #createInferenceService()}. */
+public abstract class InferenceServiceTestCase extends ESTestCase {
+
+    /** Creates the inference service the checks in this class run against; each check closes the instance it obtains. */
+    public abstract InferenceService createInferenceService();
+
+    /** Fails unless implementing {@link RerankingInferenceService} and supporting {@link TaskType#RERANK} go hand in hand. */
+    public void testRerankersImplementRerankInterface() throws IOException {
+        try (InferenceService service = createInferenceService()) {
+            boolean isRerankingService = service instanceof RerankingInferenceService;
+            boolean supportsRerank = service.supportedTaskTypes().contains(TaskType.RERANK);
+            if (isRerankingService == supportsRerank) {
+                return;
+            }
+            String mismatch =
+                "Reranking inference services should implement RerankingInferenceService and support the RERANK task type. "
+                    + "Service [" + service.name() + "] supports task type: [" + supportsRerank + "] and implements"
+                    + " RerankingInferenceService: [" + isRerankingService + "]";
+            fail(mismatch);
+        }
+    }
+
+    /** Hands reranking services to {@link #assertRerankerWindowSize}; any other service passes without assertions. */
+    public void testRerankersHaveWindowSize() throws IOException {
+        try (InferenceService service = createInferenceService()) {
+            if (service instanceof RerankingInferenceService reranker) {
+                assertRerankerWindowSize(reranker);
+            }
+        }
+    }
+
+    /** Default hook that always fails, so test classes for reranking services must override it with a real assertion. */
+    protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+        fail("Reranking services should override this test method to verify window size");
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
index 88e0ea3287336..cb9731d31910b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
@@ -19,6 +19,7 @@
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -561,4 +562,8 @@ private Map getRequestConfigMap(Map serviceSetti
return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings));
}
+ /** Supplies the service instance for the inherited test hook by delegating to {@code createService()}. */
+ @Override
+ public InferenceService createInferenceService() {
+ return createService();
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
index f0258e9f66ed5..90adf6085734f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
@@ -19,14 +19,15 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentType;
@@ -42,6 +43,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.action.AlibabaCloudSearchActionVisitor;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
@@ -73,7 +75,7 @@
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.mock;
-public class AlibabaCloudSearchServiceTests extends ESTestCase {
+public class AlibabaCloudSearchServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private ThreadPool threadPool;
private HttpClientManager clientManager;
@@ -710,4 +712,18 @@ private Map getRequestConfigMap(
Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)
);
}
+
+ /** Builds the service handed to the {@code InferenceServiceTestCase} checks, using a mocked sender factory. */
+ @Override
+ public InferenceService createInferenceService() {
+ return new AlibabaCloudSearchService(
+ mock(HttpRequestSender.Factory.class),
+ createWithEmptySettings(threadPool),
+ mockClusterServiceEmpty()
+ );
+ }
+
+ /** Pins the window size reported by the service at 5500; the model id passed here is a placeholder string. */
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(5500));
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
index c3b1cab4b4e0a..71d7cd5c5c1cd 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
@@ -23,6 +23,7 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -31,7 +32,6 @@
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentType;
@@ -43,6 +43,7 @@
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockMockRequestSender;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;
@@ -92,7 +93,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class AmazonBedrockServiceTests extends ESTestCase {
+public class AmazonBedrockServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private ThreadPool threadPool;
@@ -1405,6 +1406,11 @@ private AmazonBedrockService createAmazonBedrockService() {
);
}
+ /** Exposes the service from {@code createAmazonBedrockService()} to the {@code InferenceServiceTestCase} checks. */
+ @Override
+ public InferenceService createInferenceService() {
+ return createAmazonBedrockService();
+ }
+
private Map getRequestConfigMap(
Map serviceSettings,
Map taskSettings,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
index 9111866d29c88..531239aeb5431 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
@@ -16,13 +16,13 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -35,6 +35,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModelTests;
@@ -74,7 +75,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class AnthropicServiceTests extends ESTestCase {
+public class AnthropicServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
@@ -688,4 +689,9 @@ public void testSupportsStreaming() throws IOException {
private AnthropicService createServiceWithMockSender() {
return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
+
+ /** Exposes the service from {@code createServiceWithMockSender()} to the {@code InferenceServiceTestCase} checks. */
+ @Override
+ public InferenceService createInferenceService() {
+ return createServiceWithMockSender();
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
index 3383762a9f332..08c31539be888 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
@@ -22,14 +22,15 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -46,6 +47,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettingsTests;
@@ -96,7 +98,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class AzureAiStudioServiceTests extends ESTestCase {
+public class AzureAiStudioServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -1682,6 +1684,16 @@ private AzureAiStudioService createService() {
);
}
+ /** Exposes the service from {@code createService()} to the {@code InferenceServiceTestCase} checks. */
+ @Override
+ public InferenceService createInferenceService() {
+ return createService();
+ }
+
+ /** Pins the window size reported by the service at 300; the model id passed here is a placeholder string. */
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ // NOTE(review): 300 equals RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE in this change;
+ // confirm whether the constant should be referenced here instead of the literal.
+ assertThat(rerankingInferenceService.rerankerWindowSize("Any model"), is(300));
+ }
+
private Map getRequestConfigMap(
Map serviceSettings,
Map taskSettings,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
index f3d65c5589169..4eb3b6a53b9ba 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
@@ -22,6 +22,7 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -29,7 +30,6 @@
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -44,6 +44,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests;
@@ -89,7 +90,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class AzureOpenAiServiceTests extends ESTestCase {
+public class AzureOpenAiServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -1223,6 +1224,11 @@ private AzureOpenAiService createAzureOpenAiService() {
);
}
+ /** Exposes the service from {@code createAzureOpenAiService()} to the {@code InferenceServiceTestCase} checks. */
+ @Override
+ public InferenceService createInferenceService() {
+ return createAzureOpenAiService();
+ }
+
private Map getRequestConfigMap(
Map serviceSettings,
Map taskSettings,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index 8f189baa33b20..e39dc02c238cb 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -23,14 +23,15 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -46,6 +47,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
@@ -92,7 +94,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class CohereServiceTests extends ESTestCase {
+public class CohereServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -1635,4 +1637,13 @@ private CohereService createCohereService() {
return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
+ /** Exposes the service from {@code createCohereService()} to the {@code InferenceServiceTestCase} checks. */
+ @Override
+ public InferenceService createInferenceService() {
+ return createCohereService();
+ }
+
+ /** Pins the window size reported by the service at 2800; the model id passed here is a placeholder string. */
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(2800));
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
index a707030a34189..55bb98705a2a3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
@@ -16,9 +16,11 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.WeightedToken;
@@ -805,4 +807,17 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
assertThat(requestMap.get("input"), is(List.of("a")));
}
}
+
+ /** Supplies the service instance for the inherited test hook via {@code createService(threadPool, clientManager)}. */
+ @Override
+ public InferenceService createInferenceService() {
+ return createService(threadPool, clientManager);
+ }
+
+ /** Expects the service to report {@code RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE} for a placeholder model id. */
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ assertThat(
+ rerankingInferenceService.rerankerWindowSize("any model"),
+ CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE)
+ );
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java
index 908451b8e681f..d15fdeb962fdc 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java
@@ -16,12 +16,12 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -35,6 +35,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.junit.After;
import org.junit.Before;
@@ -61,7 +62,7 @@
import static org.hamcrest.Matchers.isA;
import static org.mockito.Mockito.mock;
-public class DeepSeekServiceTests extends ESTestCase {
+public class DeepSeekServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -365,6 +366,11 @@ private DeepSeekService createService() {
);
}
+ /** Exposes the service from {@code createService()} to the {@code InferenceServiceTestCase} checks. */
+ @Override
+ public InferenceService createInferenceService() {
+ return createService();
+ }
+
private void parseRequestConfig(String json, ActionListener listener) throws IOException {
try (var service = createService()) {
service.parseRequestConfig("inference-id", TaskType.CHAT_COMPLETION, map(json), listener);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
index 4111cab05b7c2..88459133ddc71 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
@@ -31,12 +31,14 @@
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.telemetry.InferenceStats;
@@ -82,7 +84,9 @@
import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
@@ -131,7 +135,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class ElasticsearchInternalServiceTests extends ESTestCase {
+public class ElasticsearchInternalServiceTests extends InferenceServiceTestCase {
private String randomInferenceEntityId;
private InferenceStats inferenceStats;
@@ -2090,6 +2094,19 @@ private ElasticsearchInternalService createService(Client client) {
return new ElasticsearchInternalService(context);
}
+ @Override
+ public InferenceService createInferenceService() {
+ return createService(mock(Client.class)); // bare Mockito Client mock; nothing is stubbed on it here
+ }
+
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ assertThat(
+ rerankingInferenceService.rerankerWindowSize("any model"), // arbitrary model id; only this one lookup is checked
+ CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) // shared default declared on the interface
+ );
+ }
+
private ElasticsearchInternalService createService(Client client, BaseElasticsearchInternalService.PreferredModelVariant modelVariant) {
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(
client,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
index 435ea9de5911b..d6ebe4dfde8d8 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
@@ -23,6 +23,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -30,7 +31,6 @@
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -45,6 +45,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests;
@@ -91,7 +92,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class GoogleAiStudioServiceTests extends ESTestCase {
+public class GoogleAiStudioServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
@@ -1177,4 +1178,9 @@ private GoogleAiStudioService createGoogleAiStudioService() {
mockClusterServiceEmpty()
);
}
+
+ @Override
+ public InferenceService createInferenceService() {
+ return createGoogleAiStudioService(); // reuses this class's existing private service factory
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java
index 26fd076e72462..4cb5ff6d7c68b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java
@@ -16,13 +16,14 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
@@ -30,6 +31,7 @@
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
@@ -64,7 +66,7 @@
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
-public class GoogleVertexAiServiceTests extends ESTestCase {
+public class GoogleVertexAiServiceTests extends InferenceServiceTestCase {
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -1046,6 +1048,28 @@ private GoogleVertexAiService createGoogleVertexAiService() {
return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
+    @Override
+    public InferenceService createInferenceService() {
+        return createGoogleVertexAiService();
+    }
+
+    /**
+     * Verifies the window size Google Vertex AI reports per rerank model: {@code semantic-ranker-default-004}
+     * is expected to yield 600 words, while the other ids checked here fall back to the conservative default.
+     */
+    @Override
+    protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+        assertThat(
+            rerankingInferenceService.rerankerWindowSize("semantic-ranker-default-003"),
+            is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE)
+        );
+        assertThat(rerankingInferenceService.rerankerWindowSize("semantic-ranker-default-004"), is(600));
+        assertThat(
+            rerankingInferenceService.rerankerWindowSize("any other"),
+            is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE)
+        );
+    }
+
private Map getRequestConfigMap(
Map serviceSettings,
Map taskSettings,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
index c770672c5d5f2..a1fbffa69b7d5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
@@ -23,16 +23,17 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -50,6 +51,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettingsTests;
@@ -97,7 +99,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class HuggingFaceServiceTests extends ESTestCase {
+public class HuggingFaceServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
@@ -1347,6 +1349,19 @@ private HuggingFaceService createHuggingFaceService() {
);
}
+ @Override
+ public InferenceService createInferenceService() {
+ return createHuggingFaceService(); // reuses this class's existing private service factory
+ }
+
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ assertThat(
+ rerankingInferenceService.rerankerWindowSize("any model"), // arbitrary model id; only this one lookup is checked
+ is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) // shared default declared on the interface
+ );
+ }
+
private Map getRequestConfigMap(
Map serviceSettings,
Map chunkingSettings,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
index ddc62b5a412b9..b10192b25face 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
@@ -23,6 +23,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -30,7 +31,6 @@
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -46,6 +46,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator;
@@ -92,7 +93,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class IbmWatsonxServiceTests extends ESTestCase {
+public class IbmWatsonxServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -1021,6 +1022,11 @@ private IbmWatsonxService createIbmWatsonxService() {
return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
+ @Override
+ public InferenceService createInferenceService() {
+ return createIbmWatsonxService(); // factory above wires in a mocked HttpRequestSender.Factory
+ }
+
private static class IbmWatsonxServiceWithoutAuth extends IbmWatsonxService {
IbmWatsonxServiceWithoutAuth(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents, mockClusterServiceEmpty());
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
index d36c574e0aa99..d2f3406085cb1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
@@ -22,14 +22,15 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -43,6 +44,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests;
@@ -86,7 +88,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class JinaAIServiceTests extends ESTestCase {
+public class JinaAIServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -1844,4 +1846,13 @@ private JinaAIService createJinaAIService() {
return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
+ @Override
+ public InferenceService createInferenceService() {
+ return createJinaAIService(); // factory above wires in a mocked HttpRequestSender.Factory
+ }
+
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(5500)); // 5500 expected for an arbitrary model id
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
index 442058171bf50..f4baee98192f4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
@@ -24,6 +24,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -836,4 +837,9 @@ private Map getRequestConfigMap(Map serviceSetti
private static Map getEmbeddingsServiceSettingsMap() {
return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null);
}
+
+ @Override
+ public InferenceService createInferenceService() {
+ return createService(); // createService() is declared outside this hunk — NOTE(review): confirm it builds a fresh service
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index 602378f2b9783..936cedf1bf272 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -23,6 +23,7 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -32,7 +33,6 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -50,6 +50,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests;
@@ -101,7 +102,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class MistralServiceTests extends ESTestCase {
+public class MistralServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -1326,4 +1327,8 @@ private static Map getSecretSettingsMap(String apiKey) {
return new HashMap<>(Map.of(API_KEY_FIELD, apiKey));
}
+ @Override
+ public InferenceService createInferenceService() {
+ return createService(); // createService() is declared outside this hunk — NOTE(review): confirm it builds a fresh service
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 83455861198d3..365602275797b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -24,6 +24,7 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -32,7 +33,6 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -50,6 +50,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests;
@@ -101,7 +102,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class OpenAiServiceTests extends ESTestCase {
+public class OpenAiServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -1658,4 +1659,9 @@ public void testGetConfiguration() throws Exception {
private OpenAiService createOpenAiService() {
return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
+
+ @Override
+ public InferenceService createInferenceService() {
+ return createOpenAiService(); // factory above wires in a mocked HttpRequestSender.Factory
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
index 18ac37e54c321..5d6bec1bcfbff 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
@@ -19,19 +19,21 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchema;
@@ -40,6 +42,7 @@
import org.junit.Before;
import java.io.IOException;
+import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -56,6 +59,7 @@
import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest;
import static org.elasticsearch.xpack.inference.Utils.mockClusterService;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -74,7 +78,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class SageMakerServiceTests extends ESTestCase {
+public class SageMakerServiceTests extends InferenceServiceTestCase {
private static final String QUERY = "query";
private static final List INPUT = List.of("input");
@@ -524,4 +528,17 @@ public void testClose() throws IOException {
verify(client, only()).close();
}
+ @Override
+ public InferenceService createInferenceService() {
+ when(schemas.supportedTaskTypes()).thenReturn(EnumSet.of(TaskType.RERANK, TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)); // stub first
+ return sageMakerService; // existing member of this test class, not a new instance — NOTE(review): confirm setup order
+ }
+
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ assertThat(
+ rerankingInferenceService.rerankerWindowSize("any model"), // arbitrary model id; only this one lookup is checked
+ is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) // shared default declared on the interface
+ );
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
index 72a3b530ab647..69378b899c98a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
@@ -21,14 +21,15 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@@ -42,6 +43,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettingsTests;
@@ -84,7 +86,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class VoyageAIServiceTests extends ESTestCase {
+public class VoyageAIServiceTests extends InferenceServiceTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
@@ -1789,4 +1791,18 @@ private VoyageAIService createVoyageAIService() {
return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
+    @Override
+    public InferenceService createInferenceService() {
+        return createVoyageAIService();
+    }
+
+    /**
+     * Verifies the window size Voyage AI reports per rerank model: {@code rerank-lite-1} is expected to
+     * yield 2800 words, and the other model id checked here 5500 words.
+     */
+    @Override
+    protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+        assertThat(rerankingInferenceService.rerankerWindowSize("rerank-lite-1"), is(2800));
+        assertThat(rerankingInferenceService.rerankerWindowSize("any other model"), is(5500));
+    }
}
diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
index 233c7278eed98..7c3074d9a9fbc 100644
--- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
+++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
@@ -326,6 +326,7 @@ public class Constants {
"cluster:admin/xpack/watcher/watch/put",
"cluster:internal/remote_cluster/nodes",
"cluster:internal/xpack/inference",
+ "cluster:internal/xpack/inference/rerankwindowsize/get",
"cluster:internal/xpack/inference/unified",
"cluster:internal/xpack/ml/coordinatedinference",
"cluster:internal/xpack/ml/datafeed/isolate",