Skip to content

Commit 630af38

Browse files
authored
[ML] Create an ml node inference endpoint referencing an existing deployment (#114750)
1 parent 5e59ab5 commit 630af38

File tree

14 files changed

+482
-63
lines changed

14 files changed

+482
-63
lines changed

docs/changelog/114750.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114750
2+
summary: Create an ml node inference endpoint referencing an existing deployment
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ static TransportVersion def(int id) {
244244
public static final TransportVersion OPT_IN_ESQL_CCS_EXECUTION_INFO = def(8_768_00_0);
245245
public static final TransportVersion QUERY_RULE_TEST_API = def(8_769_00_0);
246246
public static final TransportVersion ESQL_PER_AGGREGATE_FILTER = def(8_770_00_0);
247+
public static final TransportVersion ML_INFERENCE_ATTACH_TO_EXISTSING_DEPLOYMENT = def(8_771_00_0);
247248

248249
/*
249250
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,10 @@ void chunkedInfer(
129129
/**
 * Stop the model deployment.
 * The default action does nothing except acknowledge the request (true);
 * services that manage ML node deployments override this to stop them.
 * @param unparsedModel The unparsed model configuration of the endpoint being stopped
 * @param listener The listener, notified with {@code true} once the stop is acknowledged
 */
default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
    listener.onResponse(true);
}
138138

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.client.Request;
11+
import org.elasticsearch.client.Response;
12+
import org.elasticsearch.client.ResponseException;
13+
import org.elasticsearch.core.Strings;
14+
import org.elasticsearch.inference.TaskType;
15+
16+
import java.io.IOException;
17+
import java.util.List;
18+
import java.util.Map;
19+
20+
import static org.hamcrest.Matchers.containsString;
21+
import static org.hamcrest.Matchers.is;
22+
23+
public class CreateFromDeploymentIT extends InferenceBaseRestTest {
24+
25+
@SuppressWarnings("unchecked")
26+
public void testAttachToDeployment() throws IOException {
27+
var modelId = "attach_to_deployment";
28+
var deploymentId = "existing_deployment";
29+
30+
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
31+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
32+
assertOkOrCreated(response);
33+
34+
var inferenceId = "inference_on_existing_deployment";
35+
var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
36+
var serviceSettings = putModel.get("service_settings");
37+
assertThat(
38+
putModel.toString(),
39+
serviceSettings,
40+
is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", "attach_to_deployment", "deployment_id", "existing_deployment"))
41+
);
42+
43+
var results = infer(inferenceId, List.of("washing machine"));
44+
assertNotNull(results.get("sparse_embedding"));
45+
46+
deleteModel(inferenceId);
47+
// assert deployment not stopped
48+
var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
49+
var deploymentStats = stats.get(0).get("deployment_stats");
50+
assertNotNull(stats.toString(), deploymentStats);
51+
52+
stopMlNodeDeployment(deploymentId);
53+
}
54+
55+
public void testAttachWithModelId() throws IOException {
56+
var modelId = "attach_with_model_id";
57+
var deploymentId = "existing_deployment_with_model_id";
58+
59+
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
60+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
61+
assertOkOrCreated(response);
62+
63+
var inferenceId = "inference_on_existing_deployment";
64+
var putModel = putModel(inferenceId, endpointConfig(modelId, deploymentId), TaskType.SPARSE_EMBEDDING);
65+
var serviceSettings = putModel.get("service_settings");
66+
assertThat(
67+
putModel.toString(),
68+
serviceSettings,
69+
is(
70+
Map.of(
71+
"num_allocations",
72+
1,
73+
"num_threads",
74+
1,
75+
"model_id",
76+
"attach_with_model_id",
77+
"deployment_id",
78+
"existing_deployment_with_model_id"
79+
)
80+
)
81+
);
82+
83+
var results = infer(inferenceId, List.of("washing machine"));
84+
assertNotNull(results.get("sparse_embedding"));
85+
86+
stopMlNodeDeployment(deploymentId);
87+
}
88+
89+
public void testModelIdDoesNotMatch() throws IOException {
90+
var modelId = "attach_with_model_id";
91+
var deploymentId = "existing_deployment_with_model_id";
92+
var aDifferentModelId = "not_the_same_as_the_one_used_in_the_deployment";
93+
94+
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
95+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
96+
assertOkOrCreated(response);
97+
98+
var inferenceId = "inference_on_existing_deployment";
99+
var e = expectThrows(
100+
ResponseException.class,
101+
() -> putModel(inferenceId, endpointConfig(aDifferentModelId, deploymentId), TaskType.SPARSE_EMBEDDING)
102+
);
103+
assertThat(
104+
e.getMessage(),
105+
containsString(
106+
"Deployment [existing_deployment_with_model_id] uses model [attach_with_model_id] "
107+
+ "which does not match the model [not_the_same_as_the_one_used_in_the_deployment] in the request."
108+
)
109+
);
110+
}
111+
112+
private String endpointConfig(String deploymentId) {
113+
return Strings.format("""
114+
{
115+
"service": "elasticsearch",
116+
"service_settings": {
117+
"deployment_id": "%s"
118+
}
119+
}
120+
""", deploymentId);
121+
}
122+
123+
private String endpointConfig(String modelId, String deploymentId) {
124+
return Strings.format("""
125+
{
126+
"service": "elasticsearch",
127+
"service_settings": {
128+
"model_id": "%s",
129+
"deployment_id": "%s"
130+
}
131+
}
132+
""", modelId, deploymentId);
133+
}
134+
135+
private Response startMlNodeDeploymemnt(String modelId, String deploymentId) throws IOException {
136+
String endPoint = "/_ml/trained_models/"
137+
+ modelId
138+
+ "/deployment/_start?timeout=10s&wait_for=started"
139+
+ "&threads_per_allocation=1"
140+
+ "&number_of_allocations=1";
141+
142+
if (deploymentId != null) {
143+
endPoint = endPoint + "&deployment_id=" + deploymentId;
144+
}
145+
146+
Request request = new Request("POST", endPoint);
147+
return client().performRequest(request);
148+
}
149+
150+
protected void stopMlNodeDeployment(String deploymentId) throws IOException {
151+
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
152+
Request request = new Request("POST", endpoint);
153+
request.addParameter("force", "true");
154+
client().performRequest(request);
155+
}
156+
157+
protected Map<String, Object> getTrainedModelStats(String modelId) throws IOException {
158+
Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/_stats");
159+
return entityAsMap(client().performRequest(request));
160+
}
161+
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference;
99

1010
import org.elasticsearch.client.Request;
11+
import org.elasticsearch.client.RestClient;
1112
import org.elasticsearch.core.Strings;
1213
import org.elasticsearch.inference.TaskType;
1314

@@ -65,11 +66,12 @@ public class CustomElandModelIT extends InferenceBaseRestTest {
6566
public void testSparse() throws IOException {
6667
String modelId = "custom-text-expansion-model";
6768

68-
createTextExpansionModel(modelId);
69-
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
69+
createTextExpansionModel(modelId, client());
70+
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client());
7071
putVocabulary(
7172
List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
72-
modelId
73+
modelId,
74+
client()
7375
);
7476

7577
var inferenceConfig = """
@@ -90,7 +92,7 @@ public void testSparse() throws IOException {
9092
assertNotNull(results.get("sparse_embedding"));
9193
}
9294

93-
protected void createTextExpansionModel(String modelId) throws IOException {
95+
static void createTextExpansionModel(String modelId, RestClient client) throws IOException {
9496
// with_special_tokens: false for this test with limited vocab
9597
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
9698
request.setJsonEntity("""
@@ -107,10 +109,10 @@ protected void createTextExpansionModel(String modelId) throws IOException {
107109
}
108110
}
109111
}""");
110-
client().performRequest(request);
112+
client.performRequest(request);
111113
}
112114

113-
protected void putVocabulary(List<String> vocabulary, String modelId) throws IOException {
115+
static void putVocabulary(List<String> vocabulary, String modelId, RestClient client) throws IOException {
114116
List<String> vocabularyWithPad = new ArrayList<>();
115117
vocabularyWithPad.add("[PAD]");
116118
vocabularyWithPad.add("[UNK]");
@@ -121,14 +123,27 @@ protected void putVocabulary(List<String> vocabulary, String modelId) throws IOE
121123
request.setJsonEntity(Strings.format("""
122124
{ "vocabulary": [%s] }
123125
""", quotedWords));
124-
client().performRequest(request);
126+
client.performRequest(request);
125127
}
126128

127-
protected void putModelDefinition(String modelId, String base64EncodedModel, long unencodedModelSize) throws IOException {
129+
static void putModelDefinition(String modelId, String base64EncodedModel, long unencodedModelSize, RestClient client)
130+
throws IOException {
128131
Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
129132
String body = Strings.format("""
130133
{"total_definition_length":%s,"definition": "%s","total_parts": 1}""", unencodedModelSize, base64EncodedModel);
131134
request.setJsonEntity(body);
132-
client().performRequest(request);
135+
client.performRequest(request);
133136
}
137+
138+
// Create the model including definition and vocab.
// Convenience helper bundling the three steps required for a usable ML node
// text expansion model: put the model configuration, upload the model
// definition, and put the (limited test) vocabulary.
static void createMlNodeTextExpansionModel(String modelId, RestClient client) throws IOException {
    createTextExpansionModel(modelId, client);
    putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client);
    putVocabulary(
        List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
        modelId,
        client
    );
}
148+
134149
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ protected void putSemanticText(String endpointId, String searchEndpointId, Strin
207207
}
208208

209209
protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
    // error_trace asks the server to include full stack traces in error
    // responses, making test failures easier to diagnose.
    String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
    return putRequest(endpoint, modelConfig);
}
213213

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
package org.elasticsearch.xpack.inference.action;
1111

12-
import org.apache.logging.log4j.LogManager;
13-
import org.apache.logging.log4j.Logger;
1412
import org.elasticsearch.ElasticsearchStatusException;
1513
import org.elasticsearch.action.ActionListener;
1614
import org.elasticsearch.action.ActionRunnable;
@@ -47,7 +45,6 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
4745

4846
private final ModelRegistry modelRegistry;
4947
private final InferenceServiceRegistry serviceRegistry;
50-
private static final Logger logger = LogManager.getLogger(TransportDeleteInferenceEndpointAction.class);
5148
private final Executor executor;
5249

5350
@Inject
@@ -118,7 +115,7 @@ private void doExecuteForked(
118115

119116
var service = serviceRegistry.getService(unparsedModel.service());
120117
if (service.isPresent()) {
121-
service.get().stop(request.getInferenceEndpointId(), listener);
118+
service.get().stop(unparsedModel, listener);
122119
} else {
123120
listener.onFailure(
124121
new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.inference.InputType;
2323
import org.elasticsearch.inference.Model;
2424
import org.elasticsearch.inference.TaskType;
25+
import org.elasticsearch.inference.UnparsedModel;
2526
import org.elasticsearch.xpack.core.ClientHelper;
2627
import org.elasticsearch.xpack.core.ml.MachineLearningField;
2728
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@@ -98,6 +99,12 @@ public void start(Model model, ActionListener<Boolean> finalListener) {
9899
return;
99100
}
100101

102+
if (esModel.usesExistingDeployment()) {
103+
// don't start a deployment
104+
finalListener.onResponse(Boolean.TRUE);
105+
return;
106+
}
107+
101108
SubscribableListener.<Boolean>newForked(forkedListener -> { isBuiltinModelPut(model, forkedListener); })
102109
.<Boolean>andThen((l, modelConfigExists) -> {
103110
if (modelConfigExists == false) {
@@ -119,14 +126,28 @@ public void start(Model model, ActionListener<Boolean> finalListener) {
119126
}
120127

121128
@Override
/**
 * Stops the ML node deployment backing this endpoint, unless the endpoint was
 * configured with an existing deployment_id, in which case the deployment is
 * owned elsewhere and is left running.
 * @param unparsedModel The persisted endpoint configuration to parse
 * @param listener Notified with {@code true} when the stop (or no-op) completes
 */
@Override
public void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {

    var model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
    if (model instanceof ElasticsearchInternalModel esModel) {

        var serviceSettings = esModel.getServiceSettings();
        if (serviceSettings.getDeploymentId() != null) {
            // configured with an existing deployment so do not stop it
            listener.onResponse(Boolean.TRUE);
            return;
        }

        // force the stop so a lingering reference cannot block it
        var request = new StopTrainedModelDeploymentAction.Request(esModel.mlNodeDeploymentId());
        request.setForce(true);
        client.execute(
            StopTrainedModelDeploymentAction.INSTANCE,
            request,
            listener.delegateFailureAndWrap((delegatedResponseListener, response) -> delegatedResponseListener.onResponse(Boolean.TRUE))
        );
    } else {
        listener.onFailure(notElasticsearchModelException(model));
    }
}
131152

132153
protected static IllegalStateException notElasticsearchModelException(Model model) {

0 commit comments

Comments
 (0)