diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java index f5815a3bfde23..2481d88078e6e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java @@ -31,7 +31,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; @@ -258,7 +257,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { List.of(query), TextExpansionConfigUpdate.EMPTY_UPDATE, false, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API + null ); inferRequest.setHighPriority(true); inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java index 472cb5e1c7012..e2bb592638966 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java @@ -28,7 +28,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; @@ -162,7 +161,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { List.of(modelText), TextExpansionConfigUpdate.EMPTY_UPDATE, false, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API + null ); inferRequest.setHighPriority(true); inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java index b859f59fc3e0e..b702eb7fc8ff9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java @@ -19,7 +19,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; @@ -116,7 +115,7 @@ public void buildVector(Client client, ActionListener listener) { List.of(modelText), TextEmbeddingConfigUpdate.EMPTY_INSTANCE, false, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API + null ); inferRequest.setHighPriority(true); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index de31f9d6cefc8..839c0cd3a8481 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -147,6 +147,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.function.Predicate; import java.util.function.Supplier; @@ -179,6 +180,13 @@ public class InferencePlugin extends Plugin Setting.Property.NodeScope, Setting.Property.Dynamic ); + public static final TimeValue DEFAULT_QUERY_INFERENCE_TIMEOUT = TimeValue.timeValueSeconds(TimeUnit.SECONDS.toSeconds(10)); + public static final Setting QUERY_INFERENCE_TIMEOUT = Setting.timeSetting( + "xpack.inference.semantic_text.inference_timeout", + DEFAULT_QUERY_INFERENCE_TIMEOUT, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final LicensedFeature.Momentary INFERENCE_API_FEATURE = LicensedFeature.momentary( "inference", @@ -311,7 +319,8 @@ public Collection createComponents(PluginServices services) { serviceComponents.get(), inferenceServiceSettings, modelRegistry.get(), - authorizationHandler + authorizationHandler, + context ), context -> new SageMakerService( new SageMakerModelBuilder(sageMakerSchemas), @@ -321,7 +330,8 @@ public Collection createComponents(PluginServices services) { ), sageMakerSchemas, services.threadPool(), - sageMakerConfigurations::getOrCompute + sageMakerConfigurations::getOrCompute, + context ) ) ); @@ -383,24 +393,24 @@ public void loadExtensions(ExtensionLoader loader) { public List getInferenceServiceFactories() { return List.of( - context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), - context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), - context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), - context -> new CohereService(httpFactory.get(), serviceComponents.get()), - context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()), - context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()), - context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()), - context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()), - context -> new MistralService(httpFactory.get(), serviceComponents.get()), - context -> new AnthropicService(httpFactory.get(), serviceComponents.get()), - context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()), - context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), - context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), - context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), - context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), - context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()), + context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context), + context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context), + context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new CohereService(httpFactory.get(), serviceComponents.get(), context), + context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context), + context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context), + context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new MistralService(httpFactory.get(), serviceComponents.get(), context), + context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context), + context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context), + context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context), + context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context), + context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context), + context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context), + context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context), ElasticsearchInternalService::new, - context -> new CustomService(httpFactory.get(), serviceComponents.get()) + context -> new CustomService(httpFactory.get(), serviceComponents.get(), context) ); } @@ -495,7 +505,7 @@ public List> getSettings() { settings.add(SKIP_VALIDATE_AND_START); settings.add(INDICES_INFERENCE_BATCH_SIZE); settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions()); - + settings.add(QUERY_INFERENCE_TIMEOUT); return settings; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 182c083ef1c26..f12e50674e5f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -31,7 +31,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; @@ -237,7 +236,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu List.of(query), Map.of(), InputType.INTERNAL_SEARCH, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, + null, false ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index ff8ae6fd5aac3..7b80a725e7f7b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; @@ -17,12 +18,14 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -42,11 +45,17 @@ public abstract class SenderService implements InferenceService { protected static final Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION); private final Sender sender; private final ServiceComponents serviceComponents; + private final ClusterService clusterService; - public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + public SenderService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { Objects.requireNonNull(factory); sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.clusterService = Objects.requireNonNull(context.clusterService()); } public Sender getSender() { @@ -73,6 +82,9 @@ public void infer( init(); var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList(); var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream); + if (timeout == null) { + timeout = clusterService.getClusterSettings().get(InferencePlugin.QUERY_INFERENCE_TIMEOUT); + } doInfer(model, inferenceInput, taskSettings, timeout, listener); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 7897317736c72..1f443d4d93291 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -85,8 +86,12 @@ public class AlibabaCloudSearchService extends SenderService { InputType.INTERNAL_SEARCH ); - public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AlibabaCloudSearchService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 591607953ea1a..3b61169612c0a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -93,9 +94,10 @@ public class AmazonBedrockService extends SenderService { public AmazonBedrockService( HttpRequestSender.Factory httpSenderFactory, AmazonBedrockRequestSender.Factory amazonBedrockFactory, - ServiceComponents serviceComponents + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(httpSenderFactory, serviceComponents); + super(httpSenderFactory, serviceComponents, context); this.amazonBedrockSender = amazonBedrockFactory.createSender(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index bec8908ab73f9..ea2a4a41ad0ad 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +59,12 @@ public class AnthropicService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION); - public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AnthropicService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 04883f23b947f..6da5b91f65d6a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -83,8 +84,12 @@ public class AzureAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AzureAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e9ff97c1ba725..1bac68430d87c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -69,8 +70,12 @@ public class AzureOpenAiService extends SenderService { private static final String SERVICE_NAME = "Azure OpenAI"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION); - public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AzureOpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index bf6a0bd03122b..36e4560b310e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +85,12 @@ public class CohereService extends SenderService { // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated // on every request - public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public CohereService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index deb6e17ec5311..65154ae3a3c1d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -74,8 +75,12 @@ public class CustomService extends SenderService { TaskType.COMPLETION ); - public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public CustomService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 56719199e094f..ebf5a8289d725 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +59,12 @@ public class DeepSeekService extends SenderService { ); private static final EnumSet SUPPORTED_TASK_TYPES_FOR_STREAMING = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public DeepSeekService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 8da1229a528ea..d71bf02115bfb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; @@ -117,9 +118,10 @@ public ElasticInferenceService( ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents); + super(factory, serviceComponents, context); this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index a6823d65da107..aaf18141e3665 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -263,6 +263,10 @@ private void preferredVariantFromPlatformArchitecture(ActionListener 0 means scaling should be available for ml nodes diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 4f2674179be67..bc2a728efa9a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -613,6 +613,9 @@ public void infer( TimeValue timeout, ActionListener listener ) { + if (timeout == null) { + timeout = getConfiguredInferenceTimeout(); + } if (model instanceof ElasticsearchInternalModel esModel) { var taskType = model.getConfigurations().getTaskType(); if (TaskType.TEXT_EMBEDDING.equals(taskType)) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 9841ea64370c3..76b4e3ff7f31f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -82,8 +83,12 @@ public class GoogleAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 3b59e999125e5..6c5aec6ddfeb7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -97,8 +98,12 @@ public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); } - public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleVertexAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index b0d40b41914d5..e05adf3773fd4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -44,8 +45,12 @@ public abstract class HuggingFaceBaseService extends SenderService { */ static final int EMBEDDING_MAX_BATCH_SIZE = 20; - public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceBaseService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d10fb77290c6b..411101cb5728b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -71,8 +72,12 @@ public class HuggingFaceService extends HuggingFaceBaseService { OpenAiChatCompletionResponseEntity::fromResponse ); - public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index e61995aac91f3..5bd4f672c59fd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -57,8 +58,12 @@ public class HuggingFaceElserService extends HuggingFaceBaseService { private static final String SERVICE_NAME = "Hugging Face ELSER"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING); - public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceElserService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 7dfb0002bb062..db45c29039271 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -69,8 +70,12 @@ public class IbmWatsonxService extends SenderService { private static final String SERVICE_NAME = "IBM Watsonx"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING); - public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public IbmWatsonxService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index c2e88cb6cdc7c..6b0210fdc4958 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -76,8 +77,12 @@ public class JinaAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public JinaAIService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index b11feb117d761..15b01e2ba9807 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +85,12 @@ public class MistralService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public MistralService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index edff1dfc08cba..3c6984bf37e84 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -91,8 +92,12 @@ public class OpenAiService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public OpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index aafd6c46857fc..fa574d18de80e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -12,6 +12,7 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -20,6 +21,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -28,6 +30,7 @@ import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder; @@ -37,6 +40,7 @@ import java.util.EnumSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import static org.elasticsearch.core.Strings.format; @@ -55,13 +59,15 @@ public class SageMakerService implements InferenceService { private final SageMakerSchemas schemas; private final ThreadPool threadPool; private final LazyInitializable configuration; + private final ClusterService clusterService; public SageMakerService( SageMakerModelBuilder modelBuilder, SageMakerClient client, SageMakerSchemas schemas, ThreadPool threadPool, - CheckedSupplier, RuntimeException> configurationMap + CheckedSupplier, RuntimeException> configurationMap, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { this.modelBuilder = modelBuilder; this.client = client; @@ -74,6 +80,7 @@ public SageMakerService( .setConfigurations(configurationMap.get()) .build() ); + this.clusterService = Objects.requireNonNull(context.clusterService()); } @Override @@ -146,6 +153,10 @@ public void infer( var inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType); + if (timeout == null) { + timeout = clusterService.getClusterSettings().get(InferencePlugin.QUERY_INFERENCE_TIMEOUT); + } + try { var sageMakerModel = ((SageMakerModel) model).override(taskSettings); var regionAndSecrets = regionAndSecrets(sageMakerModel); @@ -156,7 +167,7 @@ public void infer( client.invokeStream( regionAndSecrets, request, - timeout != null ? timeout : DEFAULT_TIMEOUT, + timeout, ActionListener.wrap( response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)), e -> listener.onFailure(schema.error(sageMakerModel, e)) @@ -168,7 +179,7 @@ public void infer( client.invoke( regionAndSecrets, request, - timeout != null ? timeout : DEFAULT_TIMEOUT, + timeout, ActionListener.wrap( response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())), e -> listener.onFailure(schema.error(sageMakerModel, e)) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 0ffec057dc2b4..da4f6071e073e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -96,8 +97,12 @@ public class VoyageAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public VoyageAIService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 5d7a6a149f941..73f672e5464f7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -10,10 +10,13 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -47,10 +50,17 @@ public class SenderServiceTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private ThreadPool threadPool; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { threadPool = createThreadPool(inferenceUtilityPool()); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -64,7 +74,7 @@ public void testStart_InitializesTheSender() throws IOException { var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.start(mock(Model.class), listener); @@ -84,7 +94,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); @@ -102,8 +112,12 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep } private static final class TestSenderService extends SenderService { - TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + TestSenderService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 8fbbd33d569e4..33dc018d9aa7a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -20,6 +21,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -77,11 +79,18 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -91,7 +100,9 @@ public void shutdown() throws IOException { } public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -116,7 +127,9 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -143,7 +156,9 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -169,7 +184,9 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -190,7 +207,9 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -210,7 +229,9 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -235,7 +256,9 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -262,7 +285,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = OpenAiChatCompletionModelTests.createCompletionModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -279,7 +302,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO public void testUpdateModelWithEmbeddingDetails_UpdatesEmbeddingSizeAndSimilarity() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -316,7 +339,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType_TextEmbedding() t taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -360,7 +383,7 @@ public void testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -404,7 +427,7 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -452,7 +475,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = createModelForTaskType(taskType, chunkingSettings); PlainActionFuture> listener = new PlainActionFuture<>(); @@ -482,7 +505,9 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin @SuppressWarnings("checkstyle:LineLength") public void testGetConfiguration() throws Exception { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context) + ) { String content = XContentHelper.stripWhitespace( """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index a014f27e7f0cc..10386c0b9353d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -24,6 +25,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -95,10 +97,17 @@ public class AmazonBedrockServiceTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private ThreadPool threadPool; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { threadPool = createThreadPool(inferenceUtilityPool()); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -959,7 +968,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc ); var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1007,7 +1016,7 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException ); try ( - var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool), context); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); @@ -1042,7 +1051,7 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool), context)) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var results = new TextEmbeddingFloatResults( List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) @@ -1088,7 +1097,7 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool), context)) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var mockResults = new ChatCompletionResults(List.of(new ChatCompletionResults.Result("test result"))); requestSender.enqueue(mockResults); @@ -1132,7 +1141,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool), context)) { var model = AmazonBedrockChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1166,7 +1175,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var provider = randomFrom(AmazonBedrockProvider.values()); var model = AmazonBedrockEmbeddingsModelTests.createModel( @@ -1205,7 +1214,7 @@ public void testInfer_UnauthorizedResponse() throws IOException { ); try ( - var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool), context); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { requestSender.enqueue( @@ -1240,7 +1249,7 @@ public void testInfer_UnauthorizedResponse() throws IOException { } public void testSupportsStreaming() throws IOException { - try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) { + try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()), context)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1284,7 +1293,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool), context)) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { { var mockResults1 = new TextEmbeddingFloatResults( @@ -1345,7 +1354,12 @@ private AmazonBedrockService createAmazonBedrockService() { ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), mockClusterServiceEmpty() ); - return new AmazonBedrockService(mock(HttpRequestSender.Factory.class), amazonBedrockFactory, createWithEmptySettings(threadPool)); + return new AmazonBedrockService( + mock(HttpRequestSender.Factory.class), + amazonBedrockFactory, + createWithEmptySettings(threadPool), + context + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 75ce59b16a763..838d4a12c5d88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -11,12 +11,14 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -79,6 +81,7 @@ public class AnthropicServiceTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; + private InferenceServiceExtension.InferenceServiceFactoryContext context; private HttpClientManager clientManager; @@ -87,6 +90,12 @@ public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -453,7 +462,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -486,7 +495,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", @@ -579,7 +588,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = AnthropicChatCompletionModelTests.createChatCompletionModel( getUrl(webServer), "secret", @@ -670,13 +679,13 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()), context)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AnthropicService createServiceWithMockSender() { - return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 3d7ba7f7436fb..dffb0c11211d2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -23,6 +24,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -96,12 +98,19 @@ public class AzureAiStudioServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -844,7 +853,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesChatCompletionModel() public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = AzureAiStudioChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -869,7 +878,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -895,7 +904,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testUpdateModelWithChatCompletionDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -923,7 +932,7 @@ public void testUpdateModelWithChatCompletionDetails_NonNullSimilarityInOriginal private void testUpdateModelWithChatCompletionDetails_Successful(Integer maxNewTokens) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = AzureAiStudioChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -956,7 +965,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -994,7 +1003,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1064,7 +1073,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1150,7 +1159,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep public void testInfer_WithChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson)); var model = AzureAiStudioChatCompletionModelTests.createModel( @@ -1187,7 +1196,7 @@ public void testInfer_WithChatCompletionModel() throws IOException { public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1264,7 +1273,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = AzureAiStudioChatCompletionModelTests.createModel( "id", getUrl(webServer), @@ -1396,7 +1405,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()), context)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1405,7 +1414,7 @@ public void testSupportsStreaming() throws IOException { // ---------------------------------------------------------------- private AzureAiStudioService createService() { - return new AzureAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index de2e9ae9a21b8..35f4373928354 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; @@ -23,6 +24,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -94,12 +96,19 @@ public class AzureOpenAiServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -752,7 +761,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -785,7 +794,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep public void testInfer_SendsRequest() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -844,7 +853,7 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = AzureOpenAiCompletionModelTests.createModelWithRandomValues(); assertThrows( ElasticsearchStatusException.class, @@ -864,7 +873,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var model = AzureOpenAiEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -891,7 +900,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -952,7 +961,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException, URISyn private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1065,7 +1074,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = AzureOpenAiCompletionModelTests.createCompletionModel( "resource", "deployment", @@ -1209,14 +1218,14 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()), context)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AzureOpenAiService createAzureOpenAiService() { - return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index fabf87151644b..27ec4dbaaaab2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; @@ -24,6 +25,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -97,12 +99,19 @@ public class CohereServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -784,7 +793,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new CohereService(factory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -817,7 +826,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException public void testInfer_SendsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -891,7 +900,7 @@ public void testInfer_SendsRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = CohereCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -911,7 +920,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var embeddingType = randomFrom(CohereEmbeddingType.values()); var model = CohereEmbeddingsModelTests.createModel( @@ -938,7 +947,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -980,7 +989,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsAreEmpty() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1056,7 +1065,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1130,7 +1139,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1234,7 +1243,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { // Batching will call the service with 2 inputs String responseJson = """ @@ -1324,7 +1333,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { // Batching will call the service with 2 inputs String responseJson = """ @@ -1444,7 +1453,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1528,7 +1537,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) { + try (var service = new CohereService(mock(), createWithEmptySettings(mock()), context)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1569,7 +1578,7 @@ private Map getRequestConfigMap(Map serviceSetti } private CohereService createCohereService() { - return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 6ddb4ff71eeb3..720daa0d05092 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -9,14 +9,17 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -65,6 +68,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; public class CustomServiceTests extends AbstractInferenceServiceTests { @@ -149,7 +153,11 @@ private static void assertCompletionModel(Model model) { public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new CustomService(senderFactory, createWithEmptySettings(threadPool)); + return new CustomService( + senderFactory, + createWithEmptySettings(threadPool), + new InferenceServiceExtension.InferenceServiceFactoryContext(mock(), threadPool, mock(ClusterService.class), Settings.EMPTY) + ); } private static Map createServiceSettingsMap(TaskType taskType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index af38ee38e1eff..6a8f58a7c12e3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -10,12 +10,14 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -66,12 +68,19 @@ public class DeepSeekServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -360,7 +369,8 @@ public void testDoChunkedInferAlwaysFails() throws IOException { private DeepSeekService createService() { return new DeepSeekService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + context ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 71a073c02e02b..6bbe0b7e0d6c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -23,6 +24,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; @@ -113,6 +115,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Override protected Collection> getPlugins() { @@ -125,6 +128,12 @@ public void init() throws Exception { modelRegistry = node().injector().getInstance(ModelRegistry.class); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -1459,7 +1468,8 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ createWithEmptySettings(threadPool), new ElasticInferenceServiceSettings(Settings.EMPTY), modelRegistry, - mockAuthHandler + mockAuthHandler, + context ); } @@ -1488,7 +1498,8 @@ private ElasticInferenceService createService( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - mockAuthHandler + mockAuthHandler, + context ); } @@ -1501,7 +1512,8 @@ private ElasticInferenceService createServiceWithAuthHandler( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool), + context ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 41175581df1cf..cf853b70bbb07 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -24,6 +25,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -96,6 +98,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; + private InferenceServiceExtension.InferenceServiceFactoryContext context; private HttpClientManager clientManager; @@ -104,6 +107,12 @@ public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -658,7 +667,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -696,7 +705,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD var model = GoogleAiStudioEmbeddingsModelTests.createModel("model", getUrl(webServer), "secret"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -730,7 +739,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { "candidates": [ @@ -818,7 +827,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { "embeddings": [ @@ -897,7 +906,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { "embeddings": [ @@ -998,7 +1007,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed public void testInfer_ResourceNotFound() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1033,7 +1042,7 @@ public void testInfer_ResourceNotFound() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = GoogleAiStudioCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1052,7 +1061,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var model = GoogleAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1124,7 +1133,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()), context)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1171,6 +1180,6 @@ private Map getRequestConfigMap( } private GoogleAiStudioService createGoogleAiStudioService() { - return new GoogleAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new GoogleAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 99a09b983787d..7342e3b5c280b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; @@ -17,6 +18,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -71,12 +73,19 @@ public class GoogleVertexAiServiceTests extends ESTestCase { private HttpClientManager clientManager; private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -1043,7 +1052,7 @@ public void testGetConfiguration() throws Exception { private GoogleVertexAiService createGoogleVertexAiService() { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool)); + return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool), context); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 3be4b72c1237f..0bfefd66a901a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -10,7 +10,10 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -41,10 +44,17 @@ public class HuggingFaceBaseServiceTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private ThreadPool threadPool; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { threadPool = createThreadPool(inferenceUtilityPool()); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -60,7 +70,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new TestService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -91,8 +101,12 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep } private static final class TestService extends HuggingFaceService { - TestService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + TestService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + super(factory, serviceComponents, context); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index 814d533129439..e80c6c0363f30 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -9,6 +9,7 @@ import org.apache.http.HttpHeaders; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; @@ -17,6 +18,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.WeightedToken; import org.elasticsearch.test.ESTestCase; @@ -63,12 +65,19 @@ public class HuggingFaceElserServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -81,7 +90,7 @@ public void shutdown() throws IOException { public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ [ @@ -137,7 +146,8 @@ public void testGetConfiguration() throws Exception { try ( var service = new HuggingFaceElserService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + context ) ) { String content = XContentHelper.stripWhitespace(""" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index e2850910ac64a..8246cfe4454b2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -24,6 +25,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -103,12 +105,19 @@ public class HuggingFaceServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -258,7 +267,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -328,7 +337,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -357,7 +366,7 @@ public void testUnifiedCompletionNonStreamingError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -486,7 +495,7 @@ public void testUnifiedCompletionMalformedError() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -548,7 +557,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -621,7 +630,7 @@ public void testInfer_StreamRequestRetry() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) { + try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()), context)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1009,7 +1018,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1060,7 +1069,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret"); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1087,7 +1096,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { public void testInfer_SendsElserRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ [ @@ -1139,7 +1148,7 @@ public void testInfer_SendsElserRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = HuggingFaceElserModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1158,7 +1167,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var model = HuggingFaceEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1179,7 +1188,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1233,7 +1242,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th public void testChunkedInfer() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ [ @@ -1340,7 +1349,7 @@ public void testGetConfiguration() throws Exception { } private HuggingFaceService createHuggingFaceService() { - return new HuggingFaceService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new HuggingFaceService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 35dbcdd6aa99f..1567533774c61 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -24,6 +25,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -98,6 +100,7 @@ public class IbmWatsonxServiceTests extends ESTestCase { private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; private static final String apiKey = "apiKey"; private static final String modelId = "model"; @@ -110,6 +113,12 @@ public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -597,7 +606,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) { + try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -635,7 +644,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = IbmWatsonxEmbeddingsModelTests.createModel(modelId, projectId, URI.create(url), apiVersion, apiKey, getUrl(webServer)); - try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) { + try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -1018,12 +1027,12 @@ private Map getRequestConfigMap( } private IbmWatsonxService createIbmWatsonxService() { - return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } private static class IbmWatsonxServiceWithoutAuth extends IbmWatsonxService { IbmWatsonxServiceWithoutAuth(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + super(factory, serviceComponents, mock(ClusterService.class)); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index eca76bc1a702a..eb143e1381feb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; @@ -23,6 +24,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -91,12 +93,19 @@ public class JinaAIServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -778,7 +787,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -819,7 +828,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var model = JinaAIEmbeddingsModelTests.createModel( @@ -846,7 +855,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_Embedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -889,7 +898,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { public void testInfer_Rerank_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -923,7 +932,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { public void testInfer_Embedding_Get_Response_Ingest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -994,7 +1003,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { public void testInfer_Embedding_Get_Response_Search() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1065,7 +1074,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { public void testInfer_Embedding_Get_Response_clustering() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ {"model":"jina-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5}, @@ -1120,7 +1129,7 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1210,7 +1219,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1295,7 +1304,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1392,7 +1401,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1475,7 +1484,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1540,7 +1549,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1637,7 +1646,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), context)) { // Batching will call the service with 2 input String responseJson = """ @@ -1800,7 +1809,7 @@ public void testGetConfiguration() throws Exception { } public void testDoesNotSupportsStreaming() throws IOException { - try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()))) { + try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()), context)) { assertFalse(service.canStream(TaskType.COMPLETION)); assertFalse(service.canStream(TaskType.ANY)); } @@ -1841,7 +1850,7 @@ private Map getRequestConfigMap(Map serviceSetti } private JinaAIService createJinaAIService() { - return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 4ba9b8aa24394..7eac607e934c4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -24,6 +25,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -106,12 +108,19 @@ public class MistralServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -249,7 +258,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -308,7 +317,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -353,7 +362,7 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -421,7 +430,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = MistralChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -459,7 +468,7 @@ public void testInfer_StreamRequest_ErrorResponse() { } public void testSupportsStreaming() throws IOException { - try (var service = new MistralService(mock(), createWithEmptySettings(mock()))) { + try (var service = new MistralService(mock(), createWithEmptySettings(mock()), context)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -942,7 +951,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = new Model(ModelConfigurationsTests.createRandomInstance()); assertThrows( @@ -962,7 +971,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var model = MistralEmbeddingModelTests.createModel( randomAlphaOfLength(10), @@ -990,7 +999,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1028,7 +1037,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -1086,7 +1095,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1173,7 +1182,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1276,7 +1285,7 @@ public void testGetConfiguration() throws Exception { // ---------------------------------------------------------------- private MistralService createService() { - return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index c19eb664e88ac..326a7e135af15 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -25,6 +26,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -106,12 +108,19 @@ public class OpenAiServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -847,7 +856,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -885,7 +894,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user"); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -924,7 +933,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { var mockModel = getInvalidModel("model_id", "service_name", TaskType.SPARSE_EMBEDDING); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -965,7 +974,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1003,7 +1012,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws public void testInfer_SendsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1099,7 +1108,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1132,7 +1141,7 @@ public void testUnifiedCompletionError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -1189,7 +1198,7 @@ public void testMidStreamUnifiedCompletionError() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1267,7 +1276,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { var model = OpenAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1344,7 +1353,7 @@ public void testInfer_StreamRequestRetry() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) { + try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()), context)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1400,7 +1409,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1485,7 +1494,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), context)) { // response with 2 embeddings String responseJson = """ @@ -1656,6 +1665,6 @@ public void testGetConfiguration() throws Exception { } private OpenAiService createOpenAiService() { - return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index d7d9473f18084..00c9b9d11dca7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -12,11 +12,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -84,7 +86,14 @@ public void init() { ThreadPool threadPool = mock(); when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of); + sageMakerService = new SageMakerService( + modelBuilder, + client, + schemas, + threadPool, + Map::of, + new InferenceServiceExtension.InferenceServiceFactoryContext(mock(), threadPool, mock(ClusterService.class), Settings.EMPTY) + ); } public void testSupportedTaskTypes() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 8602621e9eb78..ad80dde708d84 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -22,6 +23,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -89,12 +91,19 @@ public class VoyageAIServiceTests extends ESTestCase { private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; + private InferenceServiceExtension.InferenceServiceFactoryContext context; @Before public void init() throws Exception { webServer.start(); threadPool = createThreadPool(inferenceUtilityPool()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + mock(ClusterService.class), + Settings.EMPTY + ); } @After @@ -718,7 +727,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -763,7 +772,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept "voyage-3-large" ); - try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), context)) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -806,7 +815,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { var embeddingSize = randomNonNegativeInt(); var model = VoyageAIEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -831,7 +840,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_Embedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -873,7 +882,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { public void testInfer_Rerank_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -907,7 +916,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { public void testInfer_Embedding_Get_Response_Ingest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -989,7 +998,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { public void testInfer_Embedding_Get_Response_Search() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1071,7 +1080,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1163,7 +1172,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1251,7 +1260,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1345,7 +1354,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null, null); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1423,7 +1432,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true, true); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1490,7 +1499,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { String responseJson = """ { @@ -1599,7 +1608,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), context)) { // Batching will call the service with 2 input String responseJson = """ @@ -1745,7 +1754,7 @@ public void testGetConfiguration() throws Exception { } public void testDoesNotSupportsStreaming() throws IOException { - try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()))) { + try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()), context)) { assertFalse(service.canStream(TaskType.COMPLETION)); assertFalse(service.canStream(TaskType.ANY)); } @@ -1786,7 +1795,7 @@ private Map getRequestConfigMap(Map serviceSetti } private VoyageAIService createVoyageAIService() { - return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), context); } }