diff --git a/server/src/main/java/org/elasticsearch/inference/RerankingInferenceService.java b/server/src/main/java/org/elasticsearch/inference/RerankingInferenceService.java new file mode 100644 index 0000000000000..d65a0970b1d92 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/RerankingInferenceService.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +public interface RerankingInferenceService { + + /** + * The default window size for small reranking models (512 input tokens). + */ + int CONSERVATIVE_DEFAULT_WINDOW_SIZE = 300; + + /** + * The reranking model's max window or an approximation of + * measured in the number of words. + * @param modelId The model ID + * @return Window size in words + */ + int rerankerWindowSize(String modelId); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetRerankerWindowSizeAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetRerankerWindowSizeAction.java new file mode 100644 index 0000000000000..5035461f5f2a0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetRerankerWindowSizeAction.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.Objects; + +public class GetRerankerWindowSizeAction extends ActionType { + + public static final GetRerankerWindowSizeAction INSTANCE = new GetRerankerWindowSizeAction(); + public static final String NAME = "cluster:internal/xpack/inference/rerankwindowsize/get"; + + public GetRerankerWindowSizeAction() { + super(NAME); + } + + public static class Request extends ActionRequest { + + private final String inferenceEntityId; + + public Request(String inferenceEntityId) { + this.inferenceEntityId = inferenceEntityId; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.inferenceEntityId = in.readString(); + } + + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(inferenceEntityId); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(inferenceEntityId, request.inferenceEntityId); + } + + @Override + public int hashCode() { + return Objects.hashCode(inferenceEntityId); + } + } + + public static class Response extends ActionResponse { + + private final int windowSize; + + public Response(int windowSize) { + this.windowSize = windowSize; + } + + public Response(StreamInput in) throws IOException { + this.windowSize = in.readVInt(); + } + + public int getWindowSize() { + return windowSize; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(windowSize); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Response response = (Response) o; + return windowSize == response.windowSize; + } + + @Override + public int hashCode() { + return Objects.hashCode(windowSize); + } + } +} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 1244548597003..c1cf64b9f2ae8 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -25,6 +25,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettings; @@ -48,6 +49,8 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension { + public static final int RERANK_WINDOW_SIZE = 333; + @Override public List getInferenceServiceFactories() { return List.of(TestInferenceService::new); @@ -62,7 +65,7 @@ public TestRerankingModel(String inferenceEntityId, TestServiceSettings serviceS } } - public static class TestInferenceService extends AbstractTestInferenceService { + public static class TestInferenceService extends AbstractTestInferenceService implements RerankingInferenceService { public static final String NAME = "test_reranking_service"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.RERANK); @@ -200,6 +203,11 @@ protected ServiceSettings getServiceSettingsFromMap(Map serviceS return TestServiceSettings.fromMap(serviceSettingsMap); } + @Override + public int rerankerWindowSize(String modelId) { + return RERANK_WINDOW_SIZE; + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java index 33b9adb431a0a..30b8a636b9ac6 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java @@ -61,8 +61,9 @@ public static Iterable parameters() { @Before public void setup() throws Exception { ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class); - Utils.storeSparseModel(modelRegistry); + Utils.storeSparseModel("sparse-endpoint", modelRegistry); Utils.storeDenseModel( + "dense-endpoint", modelRegistry, randomIntBetween(1, 100), // dot product means that we need normalized vectors; it's not worth doing that in this test diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java index 7ddbf4fc55ffd..12f422ae8c2e3 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java @@ -90,8 +90,8 @@ public void setup() throws Exception { () -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType)) ); int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100); - Utils.storeSparseModel(modelRegistry); - Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType); + Utils.storeSparseModel("sparse-endpoint", modelRegistry); + Utils.storeDenseModel("dense-endpoint", modelRegistry, dimensions, similarity, elementType); } @Override @@ -122,27 +122,20 @@ public Settings indexSettings() { } public void testBulkOperations() throws Exception { - prepareCreate(INDEX_NAME).setMapping( - String.format( - Locale.ROOT, - """ - { - "properties": { - "sparse_field": { - "type": "semantic_text", - "inference_id": "%s" - }, - "dense_field": { - "type": "semantic_text", - "inference_id": "%s" - } - } + prepareCreate(INDEX_NAME).setMapping(String.format(Locale.ROOT, """ + { + "properties": { + "sparse_field": { + "type": "semantic_text", + "inference_id": "%s" + }, + "dense_field": { + "type": "semantic_text", + "inference_id": "%s" } - """, - TestSparseInferenceServiceExtension.TestInferenceService.NAME, - TestDenseInferenceServiceExtension.TestInferenceService.NAME - ) - ).get(); + } + } + """, "sparse-endpoint", "dense-endpoint")).get(); assertRandomBulkOperations(INDEX_NAME, isIndexRequest -> { Map map = new HashMap<>(); map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput()); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/RerankWindowSizeIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/RerankWindowSizeIT.java new file mode 100644 index 0000000000000..020bbf0cfe752 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/RerankWindowSizeIT.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; +import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.Before; + +import java.util.Collection; +import java.util.List; + +import static org.hamcrest.Matchers.containsString; + +@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 +public class RerankWindowSizeIT extends ESIntegTestCase { + + @Before + public void setup() throws Exception { + ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class); + Utils.storeRerankModel("rerank-endpoint", modelRegistry); + Utils.storeSparseModel("sparse-endpoint", modelRegistry); + } + + @Override + protected Collection> nodePlugins() { + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class); + } + + public void testRerankWindowSizeAction() { + var response = client().execute(GetRerankerWindowSizeAction.INSTANCE, new GetRerankerWindowSizeAction.Request("rerank-endpoint")) + .actionGet(); + assertEquals(TestRerankingServiceExtension.RERANK_WINDOW_SIZE, response.getWindowSize()); + } + + public void testActionNotAReranker() { + var e = expectThrows( + ElasticsearchStatusException.class, + () -> client().execute(GetRerankerWindowSizeAction.INSTANCE, new GetRerankerWindowSizeAction.Request("sparse-endpoint")) + .actionGet() + ); + assertThat(e.getMessage(), containsString("Inference endpoint [sparse-endpoint] does not have the rerank task type")); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java index 8986b0a158e9f..ba63402158e6e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java @@ -31,8 +31,6 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.Utils; -import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; -import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.Before; @@ -68,8 +66,8 @@ public void setup() throws Exception { () -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType)) ); int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100); - Utils.storeSparseModel(modelRegistry); - Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType); + Utils.storeSparseModel("sparse-endpoint", modelRegistry); + Utils.storeDenseModel("dense-endpoint", modelRegistry, dimensions, similarity, elementType); Set availableVersions = IndexVersionUtils.allReleasedVersions() .stream() @@ -113,11 +111,11 @@ public void testSemanticText() throws Exception { .startObject("properties") .startObject(SPARSE_SEMANTIC_FIELD) .field("type", "semantic_text") - .field("inference_id", TestSparseInferenceServiceExtension.TestInferenceService.NAME) + .field("inference_id", "sparse-endpoint") .endObject() .startObject(DENSE_SEMANTIC_FIELD) .field("type", "semantic_text") - .field("inference_id", TestDenseInferenceServiceExtension.TestInferenceService.NAME) + .field("inference_id", "dense-endpoint") .endObject() .endObject() .endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index c3ae4f0d9d6d6..52c0b6b9fa9b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -62,6 +62,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; +import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; @@ -72,6 +73,7 @@ import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction; +import org.elasticsearch.xpack.inference.action.TransportGetRerankerWindowSizeAction; import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; @@ -234,7 +236,8 @@ public List getActions() { new ActionHandler(XPackUsageFeatureAction.INFERENCE, TransportInferenceUsageAction.class), new ActionHandler(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class), new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class), - new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class) + new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class), + new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetRerankerWindowSizeAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetRerankerWindowSizeAction.java new file mode 100644 index 0000000000000..8e0a0d6696167 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetRerankerWindowSizeAction.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +public class TransportGetRerankerWindowSizeAction extends HandledTransportAction< + GetRerankerWindowSizeAction.Request, + GetRerankerWindowSizeAction.Response> { + + private final ModelRegistry modelRegistry; + private final InferenceServiceRegistry serviceRegistry; + + @Inject + public TransportGetRerankerWindowSizeAction( + TransportService transportService, + ActionFilters actionFilters, + ThreadPool threadPool, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry + ) { + super( + GetRerankerWindowSizeAction.NAME, + transportService, + actionFilters, + GetRerankerWindowSizeAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.modelRegistry = modelRegistry; + this.serviceRegistry = serviceRegistry; + } + + @Override + protected void doExecute( + Task task, + GetRerankerWindowSizeAction.Request request, + ActionListener listener + ) { + + SubscribableListener.newForked(l -> modelRegistry.getModel(request.getInferenceEntityId(), l)).< + GetRerankerWindowSizeAction.Response>andThen((l, unparsedModel) -> { + if (unparsedModel.taskType() != TaskType.RERANK) { + throw new ElasticsearchStatusException( + "Inference endpoint [{}] does not have the {} task type", + RestStatus.BAD_REQUEST, + request.getInferenceEntityId(), + TaskType.RERANK + ); + } + + var service = serviceRegistry.getService(unparsedModel.service()); + if (service.isEmpty()) { + throw new ElasticsearchStatusException( + "Unknown service [{}] for inference endpoint [{}]", + RestStatus.BAD_REQUEST, + unparsedModel.service(), + request.getInferenceEntityId() + ); + } + + if (service.get() instanceof RerankingInferenceService rerankingInferenceService) { + var model = service.get() + .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + + l.onResponse( + new GetRerankerWindowSizeAction.Response( + rerankWindowSize(rerankingInferenceService, model.getServiceSettings().modelId()) + ) + ); + } else { + throw new IllegalStateException( + "Inference endpoint [" + + request.getInferenceEntityId() + + "] has task type [" + + TaskType.RERANK + + "] but the service [" + + service.get().name() + + "] does not support reranking" + ); + } + }).addListener(listener); + } + + private int rerankWindowSize(RerankingInferenceService service, String modelId) { + return service.rerankerWindowSize(modelId); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index da608779fee0a..5383a4bfb2eec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -69,7 +70,7 @@ import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.SERVICE_ID; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.WORKSPACE_NAME; -public class AlibabaCloudSearchService extends SenderService { +public class AlibabaCloudSearchService extends SenderService implements RerankingInferenceService { public static final String NAME = AlibabaCloudSearchUtils.SERVICE_NAME; private static final String SERVICE_NAME = "AlibabaCloud AI Search"; @@ -390,6 +391,14 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_16_0; } + @Override + public int rerankerWindowSize(String modelId) { + // Alibaba's mGTE models support long context windows of up to 8192 tokens. + // Using 1 token = 0.75 words, this translates to approximately 6144 words. + // https://huggingface.co/Alibaba-NLP/gte-multilingual-reranker-base + return 5500; + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 4a5a8be8b6633..718757d9e2697 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -72,7 +73,7 @@ import static org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS; import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; -public class AzureAiStudioService extends SenderService { +public class AzureAiStudioService extends SenderService implements RerankingInferenceService { public static final String NAME = "azureaistudio"; @@ -400,6 +401,13 @@ private static void checkProviderAndEndpointTypeForTask( } } + @Override + public int rerankerWindowSize(String modelId) { + // Window size is model dependent and the values are not known for Azure AI Studio models. + // TODO make the rerank window size configurable + return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE; + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index fb6c630bd60c9..4963c8646e5d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -66,7 +67,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.EMBEDDING_MAX_BATCH_SIZE; -public class CohereService extends SenderService { +public class CohereService extends SenderService implements RerankingInferenceService { public static final String NAME = "cohere"; private static final String SERVICE_NAME = "Cohere"; @@ -361,6 +362,14 @@ public Set supportedStreamingTasks() { return COMPLETION_ONLY; } + @Override + public int rerankerWindowSize(String modelId) { + // Cohere rerank model truncates at 4096 tokens https://docs.cohere.com/reference/rerank + // Using 1 token = 0.75 words as a rough estimate, we get 3072 words + // allowing for some headroom, we set the window size below 3072 + return 2800; + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 5f5078affa9d3..7cd069ac2e3e0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -64,7 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; -public class CustomService extends SenderService { +public class CustomService extends SenderService implements RerankingInferenceService { public static final String NAME = "custom"; private static final String SERVICE_NAME = "Custom"; @@ -366,6 +367,14 @@ public boolean hideFromConfigurationApi() { return true; } + @Override + public int rerankerWindowSize(String modelId) { + // The model's max input length is not known at this point, + // return a small default that will work with the smallest models + // TODO add a way to configure this setting + return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE; + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 8ad41e3567c04..89258d5716e8e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; @@ -84,7 +85,7 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86; -public class ElasticsearchInternalService extends BaseElasticsearchInternalService { +public class ElasticsearchInternalService extends BaseElasticsearchInternalService implements RerankingInferenceService { public static final String NAME = "elasticsearch"; public static final String OLD_ELSER_SERVICE_NAME = "elser"; @@ -1060,6 +1061,14 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) { } } + @Override + public int rerankerWindowSize(String modelId) { + // The Elastic reranker has a window size of 512 tokens. + // Return 300 words as a default that comfortably fits in the window. + // TODO custom rerank models may have larger windows, make this configurable + return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE; + } + /** * Iterates over the batch executing a limited number requests at a time to avoid * filling the ML node inference queue. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 2c2c667cd6eee..4e58e063eeebc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -25,6 +25,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; @@ -69,7 +70,7 @@ import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX; -public class GoogleVertexAiService extends SenderService { +public class GoogleVertexAiService extends SenderService implements RerankingInferenceService { public static final String NAME = "googlevertexai"; @@ -383,6 +384,20 @@ private static GoogleVertexAiModel createModel( }; } + @Override + public int rerankerWindowSize(String modelId) { + // The -003 version rerankers have a content window of 512 tokens, + // the later -004 models support 1024 tokens. + // https://cloud.google.com/generative-ai-app-builder/docs/ranking + // TODO make the rerank window size configurable + + if (modelId != null && modelId.endsWith("-004")) { + return 600; + } else { + return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE; + } + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index bc64e832d182a..e0ad3f7460477 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -57,7 +58,7 @@ * This class is responsible for managing the Hugging Face inference service. * It manages model creation, as well as chunked, non-chunked, and unified completion inference. */ -public class HuggingFaceService extends HuggingFaceBaseService { +public class HuggingFaceService extends HuggingFaceBaseService implements RerankingInferenceService { public static final String NAME = "hugging_face"; private static final String SERVICE_NAME = "Hugging Face"; @@ -228,6 +229,13 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_15_0; } + @Override + public int rerankerWindowSize(String modelId) { + // Assume a small window size as the true value is not known. + // TODO make the rerank window size configurable + return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE; + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 00e1aede95a2b..bed21c9ccb8bf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -25,6 +25,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -63,7 +64,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceFields.EMBEDDING_MAX_BATCH_SIZE; -public class JinaAIService extends SenderService { +public class JinaAIService extends SenderService implements RerankingInferenceService { public static final String NAME = "jinaai"; private static final String SERVICE_NAME = "Jina AI"; @@ -347,6 +348,14 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.JINA_AI_INTEGRATION_ADDED; } + @Override + public int rerankerWindowSize(String modelId) { + // Jina AI rerank models have an 8000 token input length https://jina.ai/models/jina-reranker-v2-base-multilingual + // Using 1 token = 0.75 words as a rough estimate, we get 6000 words + // allowing for some headroom, we set the window size below 6000 words + return 5500; + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index 393e3503b5cfc..676a1edec126b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -25,6 +25,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; @@ -48,7 +49,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails; -public class SageMakerService implements InferenceService { +public class SageMakerService implements InferenceService, RerankingInferenceService { public static final String NAME = "amazon_sagemaker"; private static final String DISPLAY_NAME = "Amazon SageMaker"; private static final List ALIASES = List.of("sagemaker", "amazonsagemaker"); @@ -328,4 +329,11 @@ public TransportVersion getMinimalSupportedVersion() { public void close() throws IOException { client.close(); } + + @Override + public int rerankerWindowSize(String modelId) { + // Assume a small window size as the true value is not known. + // TODO make the rerank window size configurable + return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 9698ee4c0d4bb..75d568c6477fd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -25,6 +25,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -61,7 +62,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; -public class VoyageAIService extends SenderService { +public class VoyageAIService extends SenderService implements RerankingInferenceService { public static final String NAME = "voyageai"; private static final String SERVICE_NAME = "Voyage AI"; @@ -91,6 +92,17 @@ public class VoyageAIService extends SenderService { 72 ); + private static final Map RERANKERS_INPUT_SIZE = Map.of( + "rerank-lite-1", + 2800 // The smallest model has a 4K context length https://docs.voyageai.com/docs/reranker + ); + + /** + * Apart from rerank-lite-1 all other models have a context length of at least 8k. + * This value is based on 1 token == 0.75 words and allowing for some overhead + */ + private static final int DEFAULT_RERANKER_INPUT_SIZE_WORDS = 5500; + public static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, InputType.SEARCH, @@ -369,6 +381,12 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; } + @Override + public int rerankerWindowSize(String modelId) { + Integer inputSize = RERANKERS_INPUT_SIZE.get(modelId); + return inputSize != null ? inputSize : DEFAULT_RERANKER_INPUT_SIZE_WORDS; + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java index 5aa42520d74bd..c2253c7f5424b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin; import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import java.nio.file.Path; @@ -47,7 +48,8 @@ protected XPackLicenseState getLicenseState() { public List getInferenceServiceFactories() { return List.of( TestSparseInferenceServiceExtension.TestInferenceService::new, - TestDenseInferenceServiceExtension.TestInferenceService::new + TestDenseInferenceServiceExtension.TestInferenceService::new, + TestRerankingServiceExtension.TestInferenceService::new ); } }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index 53ada98b69cfe..413fc80169ea0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -27,6 +27,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.hamcrest.Matchers; @@ -81,27 +82,36 @@ public static ScalingExecutorBuilder inferenceUtilityPool() { ); } - public static void storeSparseModel(ModelRegistry modelRegistry) throws Exception { + public static void storeSparseModel(String inferenceId, ModelRegistry modelRegistry) throws Exception { Model model = new TestSparseInferenceServiceExtension.TestSparseModel( - TestSparseInferenceServiceExtension.TestInferenceService.NAME, + inferenceId, new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false) ); storeModel(modelRegistry, model); } public static void storeDenseModel( + String inferenceId, ModelRegistry modelRegistry, int dimensions, SimilarityMeasure similarityMeasure, DenseVectorFieldMapper.ElementType elementType ) throws Exception { Model model = new TestDenseInferenceServiceExtension.TestDenseModel( - TestDenseInferenceServiceExtension.TestInferenceService.NAME, + inferenceId, new TestDenseInferenceServiceExtension.TestServiceSettings("dense_model", dimensions, similarityMeasure, elementType) ); storeModel(modelRegistry, model); } + public static void storeRerankModel(String inferenceId, ModelRegistry modelRegistry) throws Exception { + Model model = new TestRerankingServiceExtension.TestRerankingModel( + inferenceId, + new TestRerankingServiceExtension.TestServiceSettings("rerank-model") + ); + storeModel(modelRegistry, model); + } + public static void storeModel(ModelRegistry modelRegistry, Model model) throws Exception { PlainActionFuture listener = new PlainActionFuture<>(); modelRegistry.storeModel(model, listener, AcknowledgedRequest.DEFAULT_ACK_TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetRerankerWindowSizeActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetRerankerWindowSizeActionRequestTests.java new file mode 100644 index 0000000000000..665caa833962f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetRerankerWindowSizeActionRequestTests.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction; + +import java.io.IOException; + +public class GetRerankerWindowSizeActionRequestTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return GetRerankerWindowSizeAction.Request::new; + } + + @Override + protected GetRerankerWindowSizeAction.Request createTestInstance() { + return new GetRerankerWindowSizeAction.Request(randomAlphaOfLength(8)); + } + + @Override + protected GetRerankerWindowSizeAction.Request mutateInstance(GetRerankerWindowSizeAction.Request instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetRerankerWindowSizeActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetRerankerWindowSizeActionResponseTests.java new file mode 100644 index 0000000000000..310e8e307f058 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetRerankerWindowSizeActionResponseTests.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction; + +import java.io.IOException; + +public class GetRerankerWindowSizeActionResponseTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return GetRerankerWindowSizeAction.Response::new; + } + + @Override + protected GetRerankerWindowSizeAction.Response createTestInstance() { + return new GetRerankerWindowSizeAction.Response(randomNonNegativeInt()); + } + + @Override + protected GetRerankerWindowSizeAction.Response mutateInstance(GetRerankerWindowSizeAction.Response instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java index 293860903badf..c89c59c963bfd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.Utils; -import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.Before; @@ -26,7 +25,7 @@ public class SemanticTextNonDynamicFieldMapperTests extends NonDynamicFieldMappe @Before public void setup() throws Exception { ModelRegistry modelRegistry = node().injector().getInstance(ModelRegistry.class); - Utils.storeSparseModel(modelRegistry); + Utils.storeSparseModel("sparse-endpoint", modelRegistry); } @Override @@ -49,6 +48,6 @@ protected String getMapping() { return String.format(Locale.ROOT, """ "type": "%s", "inference_id": "%s" - """, SemanticTextFieldMapper.CONTENT_TYPE, TestSparseInferenceServiceExtension.TestInferenceService.NAME); + """, SemanticTextFieldMapper.CONTENT_TYPE, "sparse-endpoint"); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index 4a4c59f091abf..effcce962aa93 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -17,7 +17,6 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; @@ -52,7 +51,7 @@ * To use this class, extend it and pass the constructor a configuration. *

*/ -public abstract class AbstractInferenceServiceTests extends ESTestCase { +public abstract class AbstractInferenceServiceTests extends InferenceServiceTestCase { protected final MockWebServer webServer = new MockWebServer(); protected ThreadPool threadPool; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/InferenceServiceTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/InferenceServiceTestCase.java new file mode 100644 index 0000000000000..b24535133107c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/InferenceServiceTestCase.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services; + +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +public abstract class InferenceServiceTestCase extends ESTestCase { + + public abstract InferenceService createInferenceService(); + + public void testRerankersImplementRerankInterface() throws IOException { + try (InferenceService inferenceService = createInferenceService()) { + boolean implementsReranking = inferenceService instanceof RerankingInferenceService; + boolean hasRerankTaskType = inferenceService.supportedTaskTypes().contains(TaskType.RERANK); + if (implementsReranking != hasRerankTaskType) { + fail( + "Reranking inference services should implement RerankingInferenceService and support the RERANK task type. " + + "Service [" + + inferenceService.name() + + "] supports task type: [" + + hasRerankTaskType + + "] and implements" + + " RerankingInferenceService: [" + + implementsReranking + + "]" + ); + } + } + } + + public void testRerankersHaveWindowSize() throws IOException { + try (InferenceService inferenceService = createInferenceService()) { + if (inferenceService instanceof RerankingInferenceService rerankingInferenceService) { + assertRerankerWindowSize(rerankingInferenceService); + } + } + } + + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + fail("Reranking services should override this test method to verify window size"); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java index 88e0ea3287336..cb9731d31910b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -561,4 +562,8 @@ private Map getRequestConfigMap(Map serviceSetti return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); } + @Override + public InferenceService createInferenceService() { + return createService(); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index f0258e9f66ed5..90adf6085734f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -19,14 +19,15 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; @@ -42,6 +43,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.action.AlibabaCloudSearchActionVisitor; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; @@ -73,7 +75,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; -public class AlibabaCloudSearchServiceTests extends ESTestCase { +public class AlibabaCloudSearchServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -710,4 +712,18 @@ private Map getRequestConfigMap( Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) ); } + + @Override + public InferenceService createInferenceService() { + return new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(5500)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index c3b1cab4b4e0a..71d7cd5c5c1cd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -31,7 +32,6 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; @@ -43,6 +43,7 @@ import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockMockRequestSender; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; @@ -92,7 +93,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class AmazonBedrockServiceTests extends ESTestCase { +public class AmazonBedrockServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private ThreadPool threadPool; @@ -1405,6 +1406,11 @@ private AmazonBedrockService createAmazonBedrockService() { ); } + @Override + public InferenceService createInferenceService() { + return createAmazonBedrockService(); + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 9111866d29c88..531239aeb5431 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -16,13 +16,13 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModelTests; @@ -74,7 +75,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class AnthropicServiceTests extends ESTestCase { +public class AnthropicServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); @@ -688,4 +689,9 @@ public void testSupportsStreaming() throws IOException { private AnthropicService createServiceWithMockSender() { return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } + + @Override + public InferenceService createInferenceService() { + return createServiceWithMockSender(); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 3383762a9f332..08c31539be888 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -22,14 +22,15 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -46,6 +47,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettingsTests; @@ -96,7 +98,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class AzureAiStudioServiceTests extends ESTestCase { +public class AzureAiStudioServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -1682,6 +1684,16 @@ private AzureAiStudioService createService() { ); } + @Override + public InferenceService createInferenceService() { + return createService(); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize("Any model"), is(300)); + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index f3d65c5589169..4eb3b6a53b9ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -29,7 +30,6 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -44,6 +44,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests; @@ -89,7 +90,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class AzureOpenAiServiceTests extends ESTestCase { +public class AzureOpenAiServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -1223,6 +1224,11 @@ private AzureOpenAiService createAzureOpenAiService() { ); } + @Override + public InferenceService createInferenceService() { + return createAzureOpenAiService(); + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 8f189baa33b20..e39dc02c238cb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -23,14 +23,15 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -46,6 +47,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; @@ -92,7 +94,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class CohereServiceTests extends ESTestCase { +public class CohereServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -1635,4 +1637,13 @@ private CohereService createCohereService() { return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } + @Override + public InferenceService createInferenceService() { + return createCohereService(); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(2800)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index a707030a34189..55bb98705a2a3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -16,9 +16,11 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.WeightedToken; @@ -805,4 +807,17 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { assertThat(requestMap.get("input"), is(List.of("a"))); } } + + @Override + public InferenceService createInferenceService() { + return createService(threadPool, clientManager); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat( + rerankingInferenceService.rerankerWindowSize("any model"), + CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 908451b8e681f..d15fdeb962fdc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -16,12 +16,12 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.junit.After; import org.junit.Before; @@ -61,7 +62,7 @@ import static org.hamcrest.Matchers.isA; import static org.mockito.Mockito.mock; -public class DeepSeekServiceTests extends ESTestCase { +public class DeepSeekServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -365,6 +366,11 @@ private DeepSeekService createService() { ); } + @Override + public InferenceService createInferenceService() { + return createService(); + } + private void parseRequestConfig(String json, ActionListener listener) throws IOException { try (var service = createService()) { service.parseRequestConfig("inference-id", TaskType.CHAT_COMPLETION, map(json), listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 4111cab05b7c2..88459133ddc71 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -31,12 +31,14 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.telemetry.InferenceStats; @@ -82,7 +84,9 @@ import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -131,7 +135,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class ElasticsearchInternalServiceTests extends ESTestCase { +public class ElasticsearchInternalServiceTests extends InferenceServiceTestCase { private String randomInferenceEntityId; private InferenceStats inferenceStats; @@ -2090,6 +2094,19 @@ private ElasticsearchInternalService createService(Client client) { return new ElasticsearchInternalService(context); } + @Override + public InferenceService createInferenceService() { + return createService(mock(Client.class)); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat( + rerankingInferenceService.rerankerWindowSize("any model"), + CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) + ); + } + private ElasticsearchInternalService createService(Client client, BaseElasticsearchInternalService.PreferredModelVariant modelVariant) { var context = new InferenceServiceExtension.InferenceServiceFactoryContext( client, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 435ea9de5911b..d6ebe4dfde8d8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -30,7 +31,6 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -45,6 +45,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; @@ -91,7 +92,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class GoogleAiStudioServiceTests extends ESTestCase { +public class GoogleAiStudioServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); @@ -1177,4 +1178,9 @@ private GoogleAiStudioService createGoogleAiStudioService() { mockClusterServiceEmpty() ); } + + @Override + public InferenceService createInferenceService() { + return createGoogleAiStudioService(); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 26fd076e72462..4cb5ff6d7c68b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -16,13 +16,14 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; @@ -30,6 +31,7 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; @@ -64,7 +66,7 @@ import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; -public class GoogleVertexAiServiceTests extends ESTestCase { +public class GoogleVertexAiServiceTests extends InferenceServiceTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -1046,6 +1048,23 @@ private GoogleVertexAiService createGoogleVertexAiService() { return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } + @Override + public InferenceService createInferenceService() { + return createGoogleVertexAiService(); + } + + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat( + rerankingInferenceService.rerankerWindowSize("semantic-ranker-default-003"), + CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) + ); + assertThat(rerankingInferenceService.rerankerWindowSize("semantic-ranker-default-004"), CoreMatchers.is(600)); + assertThat( + rerankingInferenceService.rerankerWindowSize("any other"), + CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) + ); + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index c770672c5d5f2..a1fbffa69b7d5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -23,16 +23,17 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -50,6 +51,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettingsTests; @@ -97,7 +99,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class HuggingFaceServiceTests extends ESTestCase { +public class HuggingFaceServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); @@ -1347,6 +1349,19 @@ private HuggingFaceService createHuggingFaceService() { ); } + @Override + public InferenceService createInferenceService() { + return createHuggingFaceService(); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat( + rerankingInferenceService.rerankerWindowSize("any model"), + is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) + ); + } + private Map getRequestConfigMap( Map serviceSettings, Map chunkingSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index ddc62b5a412b9..b10192b25face 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -30,7 +31,6 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -46,6 +46,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator; @@ -92,7 +93,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class IbmWatsonxServiceTests extends ESTestCase { +public class IbmWatsonxServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -1021,6 +1022,11 @@ private IbmWatsonxService createIbmWatsonxService() { return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } + @Override + public InferenceService createInferenceService() { + return createIbmWatsonxService(); + } + private static class IbmWatsonxServiceWithoutAuth extends IbmWatsonxService { IbmWatsonxServiceWithoutAuth(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents, mockClusterServiceEmpty()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index d36c574e0aa99..d2f3406085cb1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -22,14 +22,15 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -43,6 +44,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; @@ -86,7 +88,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class JinaAIServiceTests extends ESTestCase { +public class JinaAIServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -1844,4 +1846,13 @@ private JinaAIService createJinaAIService() { return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } + @Override + public InferenceService createInferenceService() { + return createJinaAIService(); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(5500)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index 442058171bf50..f4baee98192f4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -836,4 +837,9 @@ private Map getRequestConfigMap(Map serviceSetti private static Map getEmbeddingsServiceSettingsMap() { return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null); } + + @Override + public InferenceService createInferenceService() { + return createService(); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 602378f2b9783..936cedf1bf272 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -32,7 +33,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -50,6 +50,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests; @@ -101,7 +102,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class MistralServiceTests extends ESTestCase { +public class MistralServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -1326,4 +1327,8 @@ private static Map getSecretSettingsMap(String apiKey) { return new HashMap<>(Map.of(API_KEY_FIELD, apiKey)); } + @Override + public InferenceService createInferenceService() { + return createService(); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 83455861198d3..365602275797b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -32,7 +33,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -50,6 +50,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; @@ -101,7 +102,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class OpenAiServiceTests extends ESTestCase { +public class OpenAiServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -1658,4 +1659,9 @@ public void testGetConfiguration() throws Exception { private OpenAiService createOpenAiService() { return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } + + @Override + public InferenceService createInferenceService() { + return createOpenAiService(); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index 18ac37e54c321..5d6bec1bcfbff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -19,19 +19,21 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchema; @@ -40,6 +42,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.EnumSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -56,6 +59,7 @@ import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest; import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -74,7 +78,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class SageMakerServiceTests extends ESTestCase { +public class SageMakerServiceTests extends InferenceServiceTestCase { private static final String QUERY = "query"; private static final List INPUT = List.of("input"); @@ -524,4 +528,17 @@ public void testClose() throws IOException { verify(client, only()).close(); } + @Override + public InferenceService createInferenceService() { + when(schemas.supportedTaskTypes()).thenReturn(EnumSet.of(TaskType.RERANK, TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)); + return sageMakerService; + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat( + rerankingInferenceService.rerankerWindowSize("any model"), + is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 72a3b530ab647..69378b899c98a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -21,14 +21,15 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -42,6 +43,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettingsTests; @@ -84,7 +86,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class VoyageAIServiceTests extends ESTestCase { +public class VoyageAIServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -1789,4 +1791,13 @@ private VoyageAIService createVoyageAIService() { return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } + @Override + public InferenceService createInferenceService() { + return createVoyageAIService(); + } + + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize("rerank-lite-1"), is(2800)); + assertThat(rerankingInferenceService.rerankerWindowSize("any other model"), is(5500)); + } } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 233c7278eed98..7c3074d9a9fbc 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -326,6 +326,7 @@ public class Constants { "cluster:admin/xpack/watcher/watch/put", "cluster:internal/remote_cluster/nodes", "cluster:internal/xpack/inference", + "cluster:internal/xpack/inference/rerankwindowsize/get", "cluster:internal/xpack/inference/unified", "cluster:internal/xpack/ml/coordinatedinference", "cluster:internal/xpack/ml/datafeed/isolate",