Skip to content

Commit dd86d48

Browse files
Supply inference context to all inference services and add a constructor for testing purposes
1 parent 6fe0bc1 commit dd86d48

File tree

22 files changed

+212
-63
lines changed

22 files changed

+212
-63
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,8 @@ public Collection<?> createComponents(PluginServices services) {
321321
serviceComponents.get(),
322322
inferenceServiceSettings,
323323
modelRegistry.get(),
324-
authorizationHandler
324+
authorizationHandler,
325+
context
325326
),
326327
context -> new SageMakerService(
327328
new SageMakerModelBuilder(sageMakerSchemas),
@@ -332,7 +333,7 @@ public Collection<?> createComponents(PluginServices services) {
332333
sageMakerSchemas,
333334
services.threadPool(),
334335
sageMakerConfigurations::getOrCompute,
335-
serviceComponents.get()
336+
context
336337
)
337338
)
338339
);
@@ -394,24 +395,24 @@ public void loadExtensions(ExtensionLoader loader) {
394395

395396
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
396397
return List.of(
397-
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
398-
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
399-
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
400-
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
401-
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
402-
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
403-
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()),
404-
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()),
405-
context -> new MistralService(httpFactory.get(), serviceComponents.get()),
406-
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
407-
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
408-
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
409-
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
410-
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
411-
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
412-
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
398+
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context),
399+
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context),
400+
context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context),
401+
context -> new CohereService(httpFactory.get(), serviceComponents.get(), context),
402+
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context),
403+
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context),
404+
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context),
405+
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context),
406+
context -> new MistralService(httpFactory.get(), serviceComponents.get(), context),
407+
context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context),
408+
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context),
409+
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context),
410+
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context),
411+
context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context),
412+
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
413+
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
413414
ElasticsearchInternalService::new,
414-
context -> new CustomService(httpFactory.get(), serviceComponents.get())
415+
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
415416
);
416417
}
417418

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.cluster.service.ClusterService;
1213
import org.elasticsearch.common.ValidationException;
1314
import org.elasticsearch.core.IOUtils;
1415
import org.elasticsearch.core.Nullable;
@@ -17,6 +18,7 @@
1718
import org.elasticsearch.inference.ChunkInferenceInput;
1819
import org.elasticsearch.inference.ChunkedInference;
1920
import org.elasticsearch.inference.InferenceService;
21+
import org.elasticsearch.inference.InferenceServiceExtension;
2022
import org.elasticsearch.inference.InferenceServiceResults;
2123
import org.elasticsearch.inference.InputType;
2224
import org.elasticsearch.inference.Model;
@@ -43,11 +45,21 @@ public abstract class SenderService implements InferenceService {
4345
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION);
4446
private final Sender sender;
4547
private final ServiceComponents serviceComponents;
48+
private final ClusterService clusterService;
4649

47-
public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
50+
public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) {
4851
Objects.requireNonNull(factory);
4952
sender = factory.createSender();
5053
this.serviceComponents = Objects.requireNonNull(serviceComponents);
54+
this.clusterService = Objects.requireNonNull(context.clusterService());
55+
}
56+
57+
// for testing
58+
public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
59+
Objects.requireNonNull(factory);
60+
sender = factory.createSender();
61+
this.serviceComponents = Objects.requireNonNull(serviceComponents);
62+
this.clusterService = clusterService;
5163
}
5264

5365
public Sender getSender() {
@@ -75,7 +87,7 @@ public void infer(
7587
var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
7688
var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
7789
if (timeout == null) {
78-
timeout = serviceComponents.clusterService().getClusterSettings().get(InferencePlugin.QUERY_INFERENCE_TIMEOUT);
90+
timeout = clusterService.getClusterSettings().get(InferencePlugin.QUERY_INFERENCE_TIMEOUT);
7991
}
8092
doInfer(model, inferenceInput, taskSettings, timeout, listener);
8193
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
@@ -19,6 +20,7 @@
1920
import org.elasticsearch.inference.ChunkedInference;
2021
import org.elasticsearch.inference.ChunkingSettings;
2122
import org.elasticsearch.inference.InferenceServiceConfiguration;
23+
import org.elasticsearch.inference.InferenceServiceExtension;
2224
import org.elasticsearch.inference.InferenceServiceResults;
2325
import org.elasticsearch.inference.InputType;
2426
import org.elasticsearch.inference.Model;
@@ -85,8 +87,13 @@ public class AlibabaCloudSearchService extends SenderService {
8587
InputType.INTERNAL_SEARCH
8688
);
8789

88-
public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
89-
super(factory, serviceComponents);
90+
// for testing
91+
public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
92+
super(factory, serviceComponents, clusterService);
93+
}
94+
95+
public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) {
96+
super(factory, serviceComponents, context);
9097
}
9198

9299
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.common.ValidationException;
1617
import org.elasticsearch.common.util.LazyInitializable;
@@ -20,6 +21,7 @@
2021
import org.elasticsearch.inference.ChunkedInference;
2122
import org.elasticsearch.inference.ChunkingSettings;
2223
import org.elasticsearch.inference.InferenceServiceConfiguration;
24+
import org.elasticsearch.inference.InferenceServiceExtension;
2325
import org.elasticsearch.inference.InferenceServiceResults;
2426
import org.elasticsearch.inference.InputType;
2527
import org.elasticsearch.inference.Model;
@@ -90,12 +92,24 @@ public class AmazonBedrockService extends SenderService {
9092
InputType.UNSPECIFIED
9193
);
9294

95+
// for testing
9396
public AmazonBedrockService(
9497
HttpRequestSender.Factory httpSenderFactory,
9598
AmazonBedrockRequestSender.Factory amazonBedrockFactory,
96-
ServiceComponents serviceComponents
99+
ServiceComponents serviceComponents,
100+
ClusterService clusterService
97101
) {
98-
super(httpSenderFactory, serviceComponents);
102+
super(httpSenderFactory, serviceComponents, clusterService);
103+
this.amazonBedrockSender = amazonBedrockFactory.createSender();
104+
}
105+
106+
public AmazonBedrockService(
107+
HttpRequestSender.Factory httpSenderFactory,
108+
AmazonBedrockRequestSender.Factory amazonBedrockFactory,
109+
ServiceComponents serviceComponents,
110+
InferenceServiceExtension.InferenceServiceFactoryContext context
111+
) {
112+
super(httpSenderFactory, serviceComponents, context);
99113
this.amazonBedrockSender = amazonBedrockFactory.createSender();
100114
}
101115

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
1718
import org.elasticsearch.core.TimeValue;
1819
import org.elasticsearch.inference.ChunkedInference;
1920
import org.elasticsearch.inference.InferenceServiceConfiguration;
21+
import org.elasticsearch.inference.InferenceServiceExtension;
2022
import org.elasticsearch.inference.InferenceServiceResults;
2123
import org.elasticsearch.inference.InputType;
2224
import org.elasticsearch.inference.Model;
@@ -58,8 +60,13 @@ public class AnthropicService extends SenderService {
5860

5961
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION);
6062

61-
public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
62-
super(factory, serviceComponents);
63+
// for testing
64+
public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
65+
super(factory, serviceComponents, clusterService);
66+
}
67+
68+
public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) {
69+
super(factory, serviceComponents, context);
6370
}
6471

6572
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.common.ValidationException;
1617
import org.elasticsearch.common.util.LazyInitializable;
@@ -19,6 +20,7 @@
1920
import org.elasticsearch.inference.ChunkedInference;
2021
import org.elasticsearch.inference.ChunkingSettings;
2122
import org.elasticsearch.inference.InferenceServiceConfiguration;
23+
import org.elasticsearch.inference.InferenceServiceExtension;
2224
import org.elasticsearch.inference.InferenceServiceResults;
2325
import org.elasticsearch.inference.InputType;
2426
import org.elasticsearch.inference.Model;
@@ -83,8 +85,13 @@ public class AzureAiStudioService extends SenderService {
8385
InputType.INTERNAL_SEARCH
8486
);
8587

86-
public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
87-
super(factory, serviceComponents);
88+
// for testing
89+
public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
90+
super(factory, serviceComponents, clusterService);
91+
}
92+
93+
public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) {
94+
super(factory, serviceComponents, context);
8895
}
8996

9097
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
1718
import org.elasticsearch.core.TimeValue;
1819
import org.elasticsearch.inference.ChunkedInference;
1920
import org.elasticsearch.inference.ChunkingSettings;
2021
import org.elasticsearch.inference.InferenceServiceConfiguration;
22+
import org.elasticsearch.inference.InferenceServiceExtension;
2123
import org.elasticsearch.inference.InferenceServiceResults;
2224
import org.elasticsearch.inference.InputType;
2325
import org.elasticsearch.inference.Model;
@@ -69,8 +71,13 @@ public class AzureOpenAiService extends SenderService {
6971
private static final String SERVICE_NAME = "Azure OpenAI";
7072
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
7173

72-
public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
73-
super(factory, serviceComponents);
74+
// for testing
75+
public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
76+
super(factory, serviceComponents, clusterService);
77+
}
78+
79+
public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) {
80+
super(factory, serviceComponents, context);
7481
}
7582

7683
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
@@ -19,6 +20,7 @@
1920
import org.elasticsearch.inference.ChunkedInference;
2021
import org.elasticsearch.inference.ChunkingSettings;
2122
import org.elasticsearch.inference.InferenceServiceConfiguration;
23+
import org.elasticsearch.inference.InferenceServiceExtension;
2224
import org.elasticsearch.inference.InferenceServiceResults;
2325
import org.elasticsearch.inference.InputType;
2426
import org.elasticsearch.inference.Model;
@@ -84,8 +86,13 @@ public class CohereService extends SenderService {
8486
// The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated
8587
// on every request
8688

87-
public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
88-
super(factory, serviceComponents);
89+
// for testing
90+
public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
91+
super(factory, serviceComponents, clusterService);
92+
}
93+
94+
public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) {
95+
super(factory, serviceComponents, context);
8996
}
9097

9198
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
@@ -19,6 +20,7 @@
1920
import org.elasticsearch.inference.ChunkedInference;
2021
import org.elasticsearch.inference.ChunkingSettings;
2122
import org.elasticsearch.inference.InferenceServiceConfiguration;
23+
import org.elasticsearch.inference.InferenceServiceExtension;
2224
import org.elasticsearch.inference.InferenceServiceResults;
2325
import org.elasticsearch.inference.InputType;
2426
import org.elasticsearch.inference.Model;
@@ -74,8 +76,12 @@ public class CustomService extends SenderService {
7476
TaskType.COMPLETION
7577
);
7678

77-
public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
78-
super(factory, serviceComponents);
79+
public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
80+
super(factory, serviceComponents, clusterService);
81+
}
82+
83+
public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) {
84+
super(factory, serviceComponents, context);
7985
}
8086

8187
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
1212
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.cluster.service.ClusterService;
1314
import org.elasticsearch.common.ValidationException;
1415
import org.elasticsearch.common.util.LazyInitializable;
1516
import org.elasticsearch.core.Strings;
1617
import org.elasticsearch.core.TimeValue;
1718
import org.elasticsearch.inference.ChunkedInference;
1819
import org.elasticsearch.inference.InferenceServiceConfiguration;
20+
import org.elasticsearch.inference.InferenceServiceExtension;
1921
import org.elasticsearch.inference.InferenceServiceResults;
2022
import org.elasticsearch.inference.InputType;
2123
import org.elasticsearch.inference.Model;
@@ -58,8 +60,13 @@ public class DeepSeekService extends SenderService {
5860
);
5961
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_STREAMING = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
6062

61-
public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
62-
super(factory, serviceComponents);
63+
// for testing
64+
public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
65+
super(factory, serviceComponents, clusterService);
66+
}
67+
68+
public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) {
69+
super(factory, serviceComponents, context);
6370
}
6471

6572
@Override

0 commit comments

Comments
 (0)