Skip to content

Commit afd5b9f

Browse files
Fix inference update API calls with task_type in body or deployment_id defined (#121231) (#121320)
* Fix inference update API calls with task_type in body or deployment_id defined
* Update docs/changelog/121231.yaml
* Fixing test
* Reuse existing deployment ID retrieval logic

---------

Co-authored-by: Elastic Machine <[email protected]>
1 parent c34afe0 commit afd5b9f

File tree

6 files changed

+139
-12
lines changed

6 files changed

+139
-12
lines changed

docs/changelog/121231.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 121231
2+
summary: Fix inference update API calls with `task_type` in body or `deployment_id`
3+
defined
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,24 @@ public void testAttachToDeployment() throws IOException {
4343
var results = infer(inferenceId, List.of("washing machine"));
4444
assertNotNull(results.get("sparse_embedding"));
4545

46+
var updatedNumAllocations = randomIntBetween(1, 10);
47+
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
48+
assertThat(
49+
updatedEndpointConfig.get("service_settings"),
50+
is(
51+
Map.of(
52+
"num_allocations",
53+
updatedNumAllocations,
54+
"num_threads",
55+
1,
56+
"model_id",
57+
"attach_to_deployment",
58+
"deployment_id",
59+
"existing_deployment"
60+
)
61+
)
62+
);
63+
4664
deleteModel(inferenceId);
4765
// assert deployment not stopped
4866
var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
@@ -83,6 +101,24 @@ public void testAttachWithModelId() throws IOException {
83101
var results = infer(inferenceId, List.of("washing machine"));
84102
assertNotNull(results.get("sparse_embedding"));
85103

104+
var updatedNumAllocations = randomIntBetween(1, 10);
105+
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
106+
assertThat(
107+
updatedEndpointConfig.get("service_settings"),
108+
is(
109+
Map.of(
110+
"num_allocations",
111+
updatedNumAllocations,
112+
"num_threads",
113+
1,
114+
"model_id",
115+
"attach_with_model_id",
116+
"deployment_id",
117+
"existing_deployment_with_model_id"
118+
)
119+
)
120+
);
121+
86122
stopMlNodeDeployment(deploymentId);
87123
}
88124

@@ -189,6 +225,16 @@ private String endpointConfig(String modelId, String deploymentId) {
189225
""", modelId, deploymentId);
190226
}
191227

228+
private String updatedEndpointConfig(int numAllocations) {
229+
return Strings.format("""
230+
{
231+
"service_settings": {
232+
"num_allocations": %d
233+
}
234+
}
235+
""", numAllocations);
236+
}
237+
192238
private Response startMlNodeDeploymemnt(String modelId, String deploymentId) throws IOException {
193239
String endPoint = "/_ml/trained_models/"
194240
+ modelId

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ protected Map<String, Object> updateEndpoint(String inferenceID, String modelCon
235235
return putRequest(endpoint, modelConfig);
236236
}
237237

238+
static Map<String, Object> updateEndpoint(String inferenceID, String modelConfig) throws IOException {
239+
String endpoint = Strings.format("_inference/%s/_update", inferenceID);
240+
return putRequest(endpoint, modelConfig);
241+
}
242+
238243
protected Map<String, Object> putPipeline(String pipelineId, String modelId) throws IOException {
239244
String endpoint = Strings.format("_ingest/pipeline/%s", pipelineId);
240245
String body = """
@@ -266,7 +271,7 @@ protected Map<String, Object> putModel(String modelId, String modelConfig) throw
266271
return putRequest(endpoint, modelConfig);
267272
}
268273

269-
Map<String, Object> putRequest(String endpoint, String body) throws IOException {
274+
static Map<String, Object> putRequest(String endpoint, String body) throws IOException {
270275
var request = new Request("PUT", endpoint);
271276
request.setJsonEntity(body);
272277
var response = client().performRequest(request);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,61 @@ public void testSupportedStream() throws Exception {
335335
}
336336
}
337337

338+
public void testUpdateEndpointWithWrongTaskTypeInURL() throws IOException {
339+
putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
340+
var e = expectThrows(
341+
ResponseException.class,
342+
() -> updateEndpoint(
343+
"sparse_embedding_model",
344+
updateConfig(null, randomAlphaOfLength(10), randomIntBetween(1, 10)),
345+
TaskType.TEXT_EMBEDDING
346+
)
347+
);
348+
assertThat(e.getMessage(), containsString("Task type must match the task type of the existing endpoint"));
349+
}
350+
351+
public void testUpdateEndpointWithWrongTaskTypeInBody() throws IOException {
352+
putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
353+
var e = expectThrows(
354+
ResponseException.class,
355+
() -> updateEndpoint(
356+
"sparse_embedding_model",
357+
updateConfig(TaskType.TEXT_EMBEDDING, randomAlphaOfLength(10), randomIntBetween(1, 10))
358+
)
359+
);
360+
assertThat(e.getMessage(), containsString("Task type must match the task type of the existing endpoint"));
361+
}
362+
363+
public void testUpdateEndpointWithTaskTypeInURL() throws IOException {
364+
testUpdateEndpoint(false, true);
365+
}
366+
367+
public void testUpdateEndpointWithTaskTypeInBody() throws IOException {
368+
testUpdateEndpoint(true, false);
369+
}
370+
371+
public void testUpdateEndpointWithTaskTypeInBodyAndURL() throws IOException {
372+
testUpdateEndpoint(true, true);
373+
}
374+
375+
@SuppressWarnings("unchecked")
376+
private void testUpdateEndpoint(boolean taskTypeInBody, boolean taskTypeInURL) throws IOException {
377+
String endpointId = "sparse_embedding_model";
378+
putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
379+
380+
int temperature = randomIntBetween(1, 10);
381+
var expectedConfig = updateConfig(taskTypeInBody ? TaskType.SPARSE_EMBEDDING : null, randomAlphaOfLength(1), temperature);
382+
Map<String, Object> updatedEndpoint;
383+
if (taskTypeInURL) {
384+
updatedEndpoint = updateEndpoint(endpointId, expectedConfig, TaskType.SPARSE_EMBEDDING);
385+
} else {
386+
updatedEndpoint = updateEndpoint(endpointId, expectedConfig);
387+
}
388+
389+
Map<String, Objects> updatedTaskSettings = (Map<String, Objects>) updatedEndpoint.get("task_settings");
390+
assertEquals(temperature, updatedTaskSettings.get("temperature"));
391+
}
392+
338393
public void testGetZeroModels() throws IOException {
339394
var models = getModels("_all", TaskType.RERANK);
340395
assertThat(models, empty());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.cluster.block.ClusterBlockLevel;
2222
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
2323
import org.elasticsearch.cluster.service.ClusterService;
24+
import org.elasticsearch.common.Strings;
2425
import org.elasticsearch.common.settings.Settings;
2526
import org.elasticsearch.common.util.concurrent.EsExecutors;
2627
import org.elasticsearch.common.xcontent.XContentHelper;
@@ -48,6 +49,7 @@
4849
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
4950
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
5051
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
52+
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
5153
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
5254
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
5355

@@ -246,22 +248,22 @@ private void updateInClusterEndpoint(
246248
ActionListener<Boolean> listener
247249
) throws IOException {
248250
// The model we are trying to update must have a trained model associated with it if it is an in-cluster deployment
249-
throwIfTrainedModelDoesntExist(request);
251+
var deploymentId = getDeploymentIdForInClusterEndpoint(existingParsedModel);
252+
throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);
250253

251254
Map<String, Object> serviceSettings = request.getContentAsSettings().serviceSettings();
252255
if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
253256

254-
UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(
255-
request.getInferenceEntityId()
256-
);
257+
UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
257258
updateRequest.setNumberOfAllocations(numAllocations);
258259

259260
var delegate = listener.<CreateTrainedModelAssignmentAction.Response>delegateFailure((l2, response) -> {
260261
modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2);
261262
});
262263

263264
logger.info(
264-
"Updating trained model deployment for inference entity [{}] with [{}] num_allocations",
265+
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations",
266+
deploymentId,
265267
request.getInferenceEntityId(),
266268
numAllocations
267269
);
@@ -284,12 +286,26 @@ private boolean isInClusterService(String name) {
284286
return List.of(ElasticsearchInternalService.NAME, ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME).contains(name);
285287
}
286288

287-
private void throwIfTrainedModelDoesntExist(UpdateInferenceModelAction.Request request) throws ElasticsearchStatusException {
288-
var assignments = TrainedModelAssignmentUtils.modelAssignments(request.getInferenceEntityId(), clusterService.state());
289+
private String getDeploymentIdForInClusterEndpoint(Model model) {
290+
if (model instanceof ElasticsearchInternalModel esModel) {
291+
return esModel.mlNodeDeploymentId();
292+
} else {
293+
throw new IllegalStateException(
294+
Strings.format(
295+
"Cannot update inference endpoint [%s]. Class [%s] is not an Elasticsearch internal model",
296+
model.getInferenceEntityId(),
297+
model.getClass().getSimpleName()
298+
)
299+
);
300+
}
301+
}
302+
303+
private void throwIfTrainedModelDoesntExist(String inferenceEntityId, String deploymentId) throws ElasticsearchStatusException {
304+
var assignments = TrainedModelAssignmentUtils.modelAssignments(deploymentId, clusterService.state());
289305
if ((assignments == null || assignments.isEmpty())) {
290306
throw ExceptionsHelper.entityNotFoundException(
291307
Messages.MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE,
292-
request.getInferenceEntityId()
308+
inferenceEntityId
293309

294310
);
295311
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUpdateInferenceModelAction.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@
77

88
package org.elasticsearch.xpack.inference.rest;
99

10-
import org.elasticsearch.ElasticsearchStatusException;
1110
import org.elasticsearch.client.internal.node.NodeClient;
1211
import org.elasticsearch.inference.TaskType;
1312
import org.elasticsearch.rest.BaseRestHandler;
1413
import org.elasticsearch.rest.RestRequest;
15-
import org.elasticsearch.rest.RestStatus;
1614
import org.elasticsearch.rest.RestUtils;
1715
import org.elasticsearch.rest.Scope;
1816
import org.elasticsearch.rest.ServerlessScope;
@@ -47,7 +45,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
4745
inferenceEntityId = restRequest.param(INFERENCE_ID);
4846
taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID));
4947
} else {
50-
throw new ElasticsearchStatusException("Inference ID must be provided in the path", RestStatus.BAD_REQUEST);
48+
inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID);
49+
taskType = TaskType.ANY;
5150
}
5251

5352
var request = new UpdateInferenceModelAction.Request(

0 commit comments

Comments
 (0)