Skip to content

Commit 071e7ce

Browse files
authored
[ML] Move code specific to the Elasticsearch in-cluster services to those services (elastic#113749)
Remove the platform architectures argument from parseRequestConfig and move code used only by the internal services out of the transport action and into those services.
1 parent 0fbb3bc commit 071e7ce

File tree

44 files changed

+294
-330
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+294
-330
lines changed

muted-tests.yml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,12 +278,9 @@ tests:
278278
- class: org.elasticsearch.xpack.ml.integration.MlJobIT
279279
method: testCreateJobsWithIndexNameOption
280280
issue: https://github.com/elastic/elasticsearch/issues/113528
281-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
282-
method: testPutE5WithTrainedModelAndInference
283-
issue: https://github.com/elastic/elasticsearch/issues/113565
284-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
285-
method: testPutE5Small_withPlatformAgnosticVariant
286-
issue: https://github.com/elastic/elasticsearch/issues/113577
281+
- class: org.elasticsearch.validation.DotPrefixClientYamlTestSuiteIT
282+
method: test {p0=dot_prefix/10_basic/Deprecated index template with a dot prefix index pattern}
283+
issue: https://github.com/elastic/elasticsearch/issues/113529
287284
- class: org.elasticsearch.xpack.ml.integration.MlJobIT
288285
method: testCantCreateJobWithSameID
289286
issue: https://github.com/elastic/elasticsearch/issues/113581

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,9 @@ default void init(Client client) {}
3939
* @param modelId Model Id
4040
* @param taskType The model task type
4141
* @param config Configuration options including the secrets
42-
* @param platformArchitectures The Set of platform architectures (OS name and hardware architecture)
43-
* the cluster nodes and models are running on.
4442
* @param parsedModelListener A listener which will handle the resulting model or failure
4543
*/
46-
void parseRequestConfig(
47-
String modelId,
48-
TaskType taskType,
49-
Map<String, Object> config,
50-
Set<String> platformArchitectures,
51-
ActionListener<Model> parsedModelListener
52-
);
44+
void parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener);
5345

5446
/**
5547
* Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. This requires that
@@ -155,17 +147,6 @@ default void putModel(Model modelVariant, ActionListener<Boolean> listener) {
155147
listener.onResponse(true);
156148
}
157149

158-
/**
159-
* Checks if the modelId has been downloaded to the local Elasticsearch cluster using the trained models API
160-
* The default action does nothing except acknowledge the request (false).
161-
* Any internal services should Override this method.
162-
* @param model
163-
* @param listener The listener
164-
*/
165-
default void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
166-
listener.onResponse(false);
167-
};
168-
169150
/**
170151
* Optionally test the new model configuration in the inference service.
171152
* This function should be called when the model is first created, the
@@ -188,14 +169,6 @@ default Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
188169
return model;
189170
}
190171

191-
/**
192-
* Return true if this model is hosted in the local Elasticsearch cluster
193-
* @return True if in cluster
194-
*/
195-
default boolean isInClusterService() {
196-
return false;
197-
}
198-
199172
/**
200173
* Defines the version required across all clusters to use this service
201174
* @return {@link TransportVersion} specifying the version

server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.inference;
1111

1212
import org.elasticsearch.client.internal.Client;
13+
import org.elasticsearch.threadpool.ThreadPool;
1314

1415
import java.util.List;
1516

@@ -20,7 +21,7 @@ public interface InferenceServiceExtension {
2021

2122
List<Factory> getInferenceServiceFactories();
2223

23-
record InferenceServiceFactoryContext(Client client) {}
24+
record InferenceServiceFactoryContext(Client client, ThreadPool threadPool) {}
2425

2526
interface Factory {
2627
/**

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ protected String getTestRestCluster() {
5353
@Override
5454
protected Settings restClientSettings() {
5555
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
56-
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
56+
return Settings.builder()
57+
.put(ThreadContext.PREFIX + ".Authorization", token)
58+
.put(CLIENT_SOCKET_TIMEOUT, "120s") // Long timeout for model download
59+
.build();
5760
}
5861

5962
static String mockSparseServiceModelConfig() {

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import java.util.ArrayList;
3939
import java.util.List;
4040
import java.util.Map;
41-
import java.util.Set;
4241

4342
public class TestDenseInferenceServiceExtension implements InferenceServiceExtension {
4443
@Override
@@ -76,7 +75,6 @@ public void parseRequestConfig(
7675
String modelId,
7776
TaskType taskType,
7877
Map<String, Object> config,
79-
Set<String> platformArchitectures,
8078
ActionListener<Model> parsedModelListener
8179
) {
8280
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import java.util.ArrayList;
3535
import java.util.List;
3636
import java.util.Map;
37-
import java.util.Set;
3837

3938
public class TestRerankingServiceExtension implements InferenceServiceExtension {
4039
@Override
@@ -67,7 +66,6 @@ public void parseRequestConfig(
6766
String modelId,
6867
TaskType taskType,
6968
Map<String, Object> config,
70-
Set<String> platformArchitectures,
7169
ActionListener<Model> parsedModelListener
7270
) {
7371
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import java.util.ArrayList;
3838
import java.util.List;
3939
import java.util.Map;
40-
import java.util.Set;
4140

4241
public class TestSparseInferenceServiceExtension implements InferenceServiceExtension {
4342
@Override
@@ -70,7 +69,6 @@ public void parseRequestConfig(
7069
String modelId,
7170
TaskType taskType,
7271
Map<String, Object> config,
73-
Set<String> platformArchitectures,
7472
ActionListener<Model> parsedModelListener
7573
) {
7674
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ public void parseRequestConfig(
6767
String modelId,
6868
TaskType taskType,
6969
Map<String, Object> config,
70-
Set<String> platformArchitectures,
7170
ActionListener<Model> parsedModelListener
7271
) {
7372
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.plugins.Plugin;
2424
import org.elasticsearch.reindex.ReindexPlugin;
2525
import org.elasticsearch.test.ESSingleNodeTestCase;
26+
import org.elasticsearch.threadpool.ThreadPool;
2627
import org.elasticsearch.xcontent.ToXContentObject;
2728
import org.elasticsearch.xcontent.XContentBuilder;
2829
import org.elasticsearch.xpack.inference.InferencePlugin;
@@ -117,7 +118,9 @@ public void testGetModel() throws Exception {
117118

118119
assertEquals(model.getConfigurations().getService(), modelHolder.get().service());
119120

120-
var elserService = new ElserInternalService(new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class)));
121+
var elserService = new ElserInternalService(
122+
new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class))
123+
);
121124
ElserInternalModel roundTripModel = elserService.parsePersistedConfigWithSecrets(
122125
modelHolder.get().inferenceEntityId(),
123126
modelHolder.get().taskType(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ public Collection<?> createComponents(PluginServices services) {
206206
);
207207
}
208208

209-
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client());
209+
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client(), services.threadPool());
210210
// This must be done after the HttpRequestSenderFactory is created so that the services can get the
211211
// reference correctly
212212
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
@@ -299,15 +299,17 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
299299

300300
@Override
301301
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settingsToUse) {
302-
return List.of(
303-
new ScalingExecutorBuilder(
304-
UTILITY_THREAD_POOL_NAME,
305-
0,
306-
10,
307-
TimeValue.timeValueMinutes(10),
308-
false,
309-
"xpack.inference.utility_thread_pool"
310-
)
302+
return List.of(inferenceUtilityExecutor(settings));
303+
}
304+
305+
public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {
306+
return new ScalingExecutorBuilder(
307+
UTILITY_THREAD_POOL_NAME,
308+
0,
309+
10,
310+
TimeValue.timeValueMinutes(10),
311+
false,
312+
"xpack.inference.utility_thread_pool"
311313
);
312314
}
313315

0 commit comments

Comments
 (0)