Skip to content

Commit 560ffb9

Browse files
[Inference Timeout] Supply inference context to all third party services (#131251)
* Refactoring inference services to accept context * fix linting issues * adding mock cluster service to fix IT test * refactoring to remove duplication in constructors * remove unnecessary blank line * refactor to have uniform constructor call * refactor to have uniform constructor call for sagemaker * fix linting issues * fix failed unit tests --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent 6ed50e1 commit 560ffb9

File tree

44 files changed

+585
-224
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+585
-224
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ private ElasticInferenceService createElasticInferenceService() {
350350
createWithEmptySettings(threadPool),
351351
ElasticInferenceServiceSettingsTests.create(gatewayUrl),
352352
modelRegistry,
353-
new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool)
353+
new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool),
354+
mockClusterServiceEmpty()
354355
);
355356
}
356357
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,8 @@ public Collection<?> createComponents(PluginServices services) {
311311
serviceComponents.get(),
312312
inferenceServiceSettings,
313313
modelRegistry.get(),
314-
authorizationHandler
314+
authorizationHandler,
315+
context
315316
),
316317
context -> new SageMakerService(
317318
new SageMakerModelBuilder(sageMakerSchemas),
@@ -321,7 +322,8 @@ public Collection<?> createComponents(PluginServices services) {
321322
),
322323
sageMakerSchemas,
323324
services.threadPool(),
324-
sageMakerConfigurations::getOrCompute
325+
sageMakerConfigurations::getOrCompute,
326+
context
325327
)
326328
)
327329
);
@@ -383,24 +385,24 @@ public void loadExtensions(ExtensionLoader loader) {
383385

384386
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
385387
return List.of(
386-
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
387-
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
388-
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
389-
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
390-
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
391-
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
392-
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()),
393-
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()),
394-
context -> new MistralService(httpFactory.get(), serviceComponents.get()),
395-
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
396-
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
397-
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
398-
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
399-
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
400-
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
401-
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
388+
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context),
389+
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context),
390+
context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context),
391+
context -> new CohereService(httpFactory.get(), serviceComponents.get(), context),
392+
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context),
393+
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context),
394+
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context),
395+
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context),
396+
context -> new MistralService(httpFactory.get(), serviceComponents.get(), context),
397+
context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context),
398+
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context),
399+
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context),
400+
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context),
401+
context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context),
402+
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
403+
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
402404
ElasticsearchInternalService::new,
403-
context -> new CustomService(httpFactory.get(), serviceComponents.get())
405+
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
404406
);
405407
}
406408

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.cluster.service.ClusterService;
1213
import org.elasticsearch.common.ValidationException;
1314
import org.elasticsearch.core.IOUtils;
1415
import org.elasticsearch.core.Nullable;
@@ -42,11 +43,13 @@ public abstract class SenderService implements InferenceService {
4243
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION);
4344
private final Sender sender;
4445
private final ServiceComponents serviceComponents;
46+
private final ClusterService clusterService;
4547

46-
public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
48+
public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
4749
Objects.requireNonNull(factory);
4850
sender = factory.createSender();
4951
this.serviceComponents = Objects.requireNonNull(serviceComponents);
52+
this.clusterService = Objects.requireNonNull(clusterService);
5053
}
5154

5255
public Sender getSender() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
@@ -19,6 +20,7 @@
1920
import org.elasticsearch.inference.ChunkedInference;
2021
import org.elasticsearch.inference.ChunkingSettings;
2122
import org.elasticsearch.inference.InferenceServiceConfiguration;
23+
import org.elasticsearch.inference.InferenceServiceExtension;
2224
import org.elasticsearch.inference.InferenceServiceResults;
2325
import org.elasticsearch.inference.InputType;
2426
import org.elasticsearch.inference.Model;
@@ -85,8 +87,20 @@ public class AlibabaCloudSearchService extends SenderService {
8587
InputType.INTERNAL_SEARCH
8688
);
8789

88-
public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
89-
super(factory, serviceComponents);
90+
public AlibabaCloudSearchService(
91+
HttpRequestSender.Factory factory,
92+
ServiceComponents serviceComponents,
93+
InferenceServiceExtension.InferenceServiceFactoryContext context
94+
) {
95+
this(factory, serviceComponents, context.clusterService());
96+
}
97+
98+
public AlibabaCloudSearchService(
99+
HttpRequestSender.Factory factory,
100+
ServiceComponents serviceComponents,
101+
ClusterService clusterService
102+
) {
103+
super(factory, serviceComponents, clusterService);
90104
}
91105

92106
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.common.ValidationException;
1617
import org.elasticsearch.common.util.LazyInitializable;
@@ -20,6 +21,7 @@
2021
import org.elasticsearch.inference.ChunkedInference;
2122
import org.elasticsearch.inference.ChunkingSettings;
2223
import org.elasticsearch.inference.InferenceServiceConfiguration;
24+
import org.elasticsearch.inference.InferenceServiceExtension;
2325
import org.elasticsearch.inference.InferenceServiceResults;
2426
import org.elasticsearch.inference.InputType;
2527
import org.elasticsearch.inference.Model;
@@ -93,9 +95,19 @@ public class AmazonBedrockService extends SenderService {
9395
public AmazonBedrockService(
9496
HttpRequestSender.Factory httpSenderFactory,
9597
AmazonBedrockRequestSender.Factory amazonBedrockFactory,
96-
ServiceComponents serviceComponents
98+
ServiceComponents serviceComponents,
99+
InferenceServiceExtension.InferenceServiceFactoryContext context
97100
) {
98-
super(httpSenderFactory, serviceComponents);
101+
this(httpSenderFactory, amazonBedrockFactory, serviceComponents, context.clusterService());
102+
}
103+
104+
public AmazonBedrockService(
105+
HttpRequestSender.Factory httpSenderFactory,
106+
AmazonBedrockRequestSender.Factory amazonBedrockFactory,
107+
ServiceComponents serviceComponents,
108+
ClusterService clusterService
109+
) {
110+
super(httpSenderFactory, serviceComponents, clusterService);
99111
this.amazonBedrockSender = amazonBedrockFactory.createSender();
100112
}
101113

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
1718
import org.elasticsearch.core.TimeValue;
1819
import org.elasticsearch.inference.ChunkedInference;
1920
import org.elasticsearch.inference.InferenceServiceConfiguration;
21+
import org.elasticsearch.inference.InferenceServiceExtension;
2022
import org.elasticsearch.inference.InferenceServiceResults;
2123
import org.elasticsearch.inference.InputType;
2224
import org.elasticsearch.inference.Model;
@@ -58,8 +60,16 @@ public class AnthropicService extends SenderService {
5860

5961
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION);
6062

61-
public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
62-
super(factory, serviceComponents);
63+
public AnthropicService(
64+
HttpRequestSender.Factory factory,
65+
ServiceComponents serviceComponents,
66+
InferenceServiceExtension.InferenceServiceFactoryContext context
67+
) {
68+
this(factory, serviceComponents, context.clusterService());
69+
}
70+
71+
public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
72+
super(factory, serviceComponents, clusterService);
6373
}
6474

6575
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.common.ValidationException;
1617
import org.elasticsearch.common.util.LazyInitializable;
@@ -19,6 +20,7 @@
1920
import org.elasticsearch.inference.ChunkedInference;
2021
import org.elasticsearch.inference.ChunkingSettings;
2122
import org.elasticsearch.inference.InferenceServiceConfiguration;
23+
import org.elasticsearch.inference.InferenceServiceExtension;
2224
import org.elasticsearch.inference.InferenceServiceResults;
2325
import org.elasticsearch.inference.InputType;
2426
import org.elasticsearch.inference.Model;
@@ -84,8 +86,16 @@ public class AzureAiStudioService extends SenderService {
8486
InputType.INTERNAL_SEARCH
8587
);
8688

87-
public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
88-
super(factory, serviceComponents);
89+
public AzureAiStudioService(
90+
HttpRequestSender.Factory factory,
91+
ServiceComponents serviceComponents,
92+
InferenceServiceExtension.InferenceServiceFactoryContext context
93+
) {
94+
this(factory, serviceComponents, context.clusterService());
95+
}
96+
97+
public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
98+
super(factory, serviceComponents, clusterService);
8999
}
90100

91101
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
1718
import org.elasticsearch.core.TimeValue;
1819
import org.elasticsearch.inference.ChunkedInference;
1920
import org.elasticsearch.inference.ChunkingSettings;
2021
import org.elasticsearch.inference.InferenceServiceConfiguration;
22+
import org.elasticsearch.inference.InferenceServiceExtension;
2123
import org.elasticsearch.inference.InferenceServiceResults;
2224
import org.elasticsearch.inference.InputType;
2325
import org.elasticsearch.inference.Model;
@@ -69,8 +71,16 @@ public class AzureOpenAiService extends SenderService {
6971
private static final String SERVICE_NAME = "Azure OpenAI";
7072
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
7173

72-
public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
73-
super(factory, serviceComponents);
74+
public AzureOpenAiService(
75+
HttpRequestSender.Factory factory,
76+
ServiceComponents serviceComponents,
77+
InferenceServiceExtension.InferenceServiceFactoryContext context
78+
) {
79+
this(factory, serviceComponents, context.clusterService());
80+
}
81+
82+
public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
83+
super(factory, serviceComponents, clusterService);
7484
}
7585

7686
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
@@ -19,6 +20,7 @@
1920
import org.elasticsearch.inference.ChunkedInference;
2021
import org.elasticsearch.inference.ChunkingSettings;
2122
import org.elasticsearch.inference.InferenceServiceConfiguration;
23+
import org.elasticsearch.inference.InferenceServiceExtension;
2224
import org.elasticsearch.inference.InferenceServiceResults;
2325
import org.elasticsearch.inference.InputType;
2426
import org.elasticsearch.inference.Model;
@@ -84,8 +86,16 @@ public class CohereService extends SenderService {
8486
// The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated
8587
// on every request
8688

87-
public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
88-
super(factory, serviceComponents);
89+
public CohereService(
90+
HttpRequestSender.Factory factory,
91+
ServiceComponents serviceComponents,
92+
InferenceServiceExtension.InferenceServiceFactoryContext context
93+
) {
94+
this(factory, serviceComponents, context.clusterService());
95+
}
96+
97+
public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
98+
super(factory, serviceComponents, clusterService);
8999
}
90100

91101
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.ValidationException;
1516
import org.elasticsearch.common.util.LazyInitializable;
1617
import org.elasticsearch.core.Nullable;
@@ -19,6 +20,7 @@
1920
import org.elasticsearch.inference.ChunkedInference;
2021
import org.elasticsearch.inference.ChunkingSettings;
2122
import org.elasticsearch.inference.InferenceServiceConfiguration;
23+
import org.elasticsearch.inference.InferenceServiceExtension;
2224
import org.elasticsearch.inference.InferenceServiceResults;
2325
import org.elasticsearch.inference.InputType;
2426
import org.elasticsearch.inference.Model;
@@ -74,8 +76,16 @@ public class CustomService extends SenderService {
7476
TaskType.COMPLETION
7577
);
7678

77-
public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
78-
super(factory, serviceComponents);
79+
public CustomService(
80+
HttpRequestSender.Factory factory,
81+
ServiceComponents serviceComponents,
82+
InferenceServiceExtension.InferenceServiceFactoryContext context
83+
) {
84+
this(factory, serviceComponents, context.clusterService());
85+
}
86+
87+
public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
88+
super(factory, serviceComponents, clusterService);
7989
}
8090

8191
@Override

0 commit comments

Comments
 (0)