Skip to content

Commit bb7e477

Browse files
Fix inference update API calls with task_type in body or deployment_id defined (#121231) (#121564)
* Fix inference update API calls with task_type in body or deployment_id defined * Update docs/changelog/121231.yaml * Fixing test * Reuse existing deployment ID retrieval logic --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent c10b6ce commit bb7e477

File tree

6 files changed

+138
-11
lines changed

6 files changed

+138
-11
lines changed

docs/changelog/121231.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 121231
2+
summary: Fix inference update API calls with `task_type` in body or `deployment_id`
3+
defined
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,24 @@ public void testAttachToDeployment() throws IOException {
5151
var results = infer(inferenceId, List.of("washing machine"));
5252
assertNotNull(results.get("sparse_embedding"));
5353

54+
var updatedNumAllocations = randomIntBetween(1, 10);
55+
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
56+
assertThat(
57+
updatedEndpointConfig.get("service_settings"),
58+
is(
59+
Map.of(
60+
"num_allocations",
61+
updatedNumAllocations,
62+
"num_threads",
63+
1,
64+
"model_id",
65+
"attach_to_deployment",
66+
"deployment_id",
67+
"existing_deployment"
68+
)
69+
)
70+
);
71+
5472
deleteModel(inferenceId);
5573
// assert deployment not stopped
5674
var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
@@ -110,6 +128,24 @@ public void testAttachWithModelId() throws IOException {
110128
var results = infer(inferenceId, List.of("washing machine"));
111129
assertNotNull(results.get("sparse_embedding"));
112130

131+
var updatedNumAllocations = randomIntBetween(1, 10);
132+
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
133+
assertThat(
134+
updatedEndpointConfig.get("service_settings"),
135+
is(
136+
Map.of(
137+
"num_allocations",
138+
updatedNumAllocations,
139+
"num_threads",
140+
1,
141+
"model_id",
142+
"attach_with_model_id",
143+
"deployment_id",
144+
"existing_deployment_with_model_id"
145+
)
146+
)
147+
);
148+
113149
stopMlNodeDeployment(deploymentId);
114150
}
115151

@@ -216,6 +252,16 @@ private String endpointConfig(String modelId, String deploymentId) {
216252
""", modelId, deploymentId);
217253
}
218254

255+
private String updatedEndpointConfig(int numAllocations) {
256+
return Strings.format("""
257+
{
258+
"service_settings": {
259+
"num_allocations": %d
260+
}
261+
}
262+
""", numAllocations);
263+
}
264+
219265
private Response startMlNodeDeploymemnt(String modelId, String deploymentId) throws IOException {
220266
String endPoint = "/_ml/trained_models/"
221267
+ modelId

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,11 @@ static Map<String, Object> updateEndpoint(String inferenceID, String modelConfig
238238
return putRequest(endpoint, modelConfig);
239239
}
240240

241+
static Map<String, Object> updateEndpoint(String inferenceID, String modelConfig) throws IOException {
242+
String endpoint = Strings.format("_inference/%s/_update", inferenceID);
243+
return putRequest(endpoint, modelConfig);
244+
}
245+
241246
protected Map<String, Object> putPipeline(String pipelineId, String modelId) throws IOException {
242247
String endpoint = Strings.format("_ingest/pipeline/%s", pipelineId);
243248
String body = """

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,61 @@ public void testUnifiedCompletionInference() throws Exception {
369369
}
370370
}
371371

372+
public void testUpdateEndpointWithWrongTaskTypeInURL() throws IOException {
373+
putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
374+
var e = expectThrows(
375+
ResponseException.class,
376+
() -> updateEndpoint(
377+
"sparse_embedding_model",
378+
updateConfig(null, randomAlphaOfLength(10), randomIntBetween(1, 10)),
379+
TaskType.TEXT_EMBEDDING
380+
)
381+
);
382+
assertThat(e.getMessage(), containsString("Task type must match the task type of the existing endpoint"));
383+
}
384+
385+
public void testUpdateEndpointWithWrongTaskTypeInBody() throws IOException {
386+
putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
387+
var e = expectThrows(
388+
ResponseException.class,
389+
() -> updateEndpoint(
390+
"sparse_embedding_model",
391+
updateConfig(TaskType.TEXT_EMBEDDING, randomAlphaOfLength(10), randomIntBetween(1, 10))
392+
)
393+
);
394+
assertThat(e.getMessage(), containsString("Task type must match the task type of the existing endpoint"));
395+
}
396+
397+
public void testUpdateEndpointWithTaskTypeInURL() throws IOException {
398+
testUpdateEndpoint(false, true);
399+
}
400+
401+
public void testUpdateEndpointWithTaskTypeInBody() throws IOException {
402+
testUpdateEndpoint(true, false);
403+
}
404+
405+
public void testUpdateEndpointWithTaskTypeInBodyAndURL() throws IOException {
406+
testUpdateEndpoint(true, true);
407+
}
408+
409+
@SuppressWarnings("unchecked")
410+
private void testUpdateEndpoint(boolean taskTypeInBody, boolean taskTypeInURL) throws IOException {
411+
String endpointId = "sparse_embedding_model";
412+
putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
413+
414+
int temperature = randomIntBetween(1, 10);
415+
var expectedConfig = updateConfig(taskTypeInBody ? TaskType.SPARSE_EMBEDDING : null, randomAlphaOfLength(1), temperature);
416+
Map<String, Object> updatedEndpoint;
417+
if (taskTypeInURL) {
418+
updatedEndpoint = updateEndpoint(endpointId, expectedConfig, TaskType.SPARSE_EMBEDDING);
419+
} else {
420+
updatedEndpoint = updateEndpoint(endpointId, expectedConfig);
421+
}
422+
423+
Map<String, Objects> updatedTaskSettings = (Map<String, Objects>) updatedEndpoint.get("task_settings");
424+
assertEquals(temperature, updatedTaskSettings.get("temperature"));
425+
}
426+
372427
private static Iterator<String> expectedResultsIterator(List<String> input) {
373428
// The Locale needs to be ROOT to match what the test service is going to respond with
374429
return Stream.concat(input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(InferenceCrudIT::expectedResult), Stream.of("[DONE]"))

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.cluster.block.ClusterBlockLevel;
2222
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
2323
import org.elasticsearch.cluster.service.ClusterService;
24+
import org.elasticsearch.common.Strings;
2425
import org.elasticsearch.common.settings.Settings;
2526
import org.elasticsearch.common.util.concurrent.EsExecutors;
2627
import org.elasticsearch.common.xcontent.XContentHelper;
@@ -51,6 +52,7 @@
5152
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
5253
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
5354
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
55+
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
5456
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
5557
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
5658

@@ -257,22 +259,22 @@ private void updateInClusterEndpoint(
257259
ActionListener<Boolean> listener
258260
) throws IOException {
259261
// The model we are trying to update must have a trained model associated with it if it is an in-cluster deployment
260-
throwIfTrainedModelDoesntExist(request);
262+
var deploymentId = getDeploymentIdForInClusterEndpoint(existingParsedModel);
263+
throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);
261264

262265
Map<String, Object> serviceSettings = request.getContentAsSettings().serviceSettings();
263266
if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
264267

265-
UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(
266-
request.getInferenceEntityId()
267-
);
268+
UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
268269
updateRequest.setNumberOfAllocations(numAllocations);
269270

270271
var delegate = listener.<CreateTrainedModelAssignmentAction.Response>delegateFailure((l2, response) -> {
271272
modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2);
272273
});
273274

274275
logger.info(
275-
"Updating trained model deployment for inference entity [{}] with [{}] num_allocations",
276+
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations",
277+
deploymentId,
276278
request.getInferenceEntityId(),
277279
numAllocations
278280
);
@@ -295,12 +297,26 @@ private boolean isInClusterService(String name) {
295297
return List.of(ElasticsearchInternalService.NAME, ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME).contains(name);
296298
}
297299

298-
private void throwIfTrainedModelDoesntExist(UpdateInferenceModelAction.Request request) throws ElasticsearchStatusException {
299-
var assignments = TrainedModelAssignmentUtils.modelAssignments(request.getInferenceEntityId(), clusterService.state());
300+
private String getDeploymentIdForInClusterEndpoint(Model model) {
301+
if (model instanceof ElasticsearchInternalModel esModel) {
302+
return esModel.mlNodeDeploymentId();
303+
} else {
304+
throw new IllegalStateException(
305+
Strings.format(
306+
"Cannot update inference endpoint [%s]. Class [%s] is not an Elasticsearch internal model",
307+
model.getInferenceEntityId(),
308+
model.getClass().getSimpleName()
309+
)
310+
);
311+
}
312+
}
313+
314+
private void throwIfTrainedModelDoesntExist(String inferenceEntityId, String deploymentId) throws ElasticsearchStatusException {
315+
var assignments = TrainedModelAssignmentUtils.modelAssignments(deploymentId, clusterService.state());
300316
if ((assignments == null || assignments.isEmpty())) {
301317
throw ExceptionsHelper.entityNotFoundException(
302318
Messages.MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE,
303-
request.getInferenceEntityId()
319+
inferenceEntityId
304320

305321
);
306322
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUpdateInferenceModelAction.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@
77

88
package org.elasticsearch.xpack.inference.rest;
99

10-
import org.elasticsearch.ElasticsearchStatusException;
1110
import org.elasticsearch.action.ActionListener;
1211
import org.elasticsearch.client.internal.node.NodeClient;
1312
import org.elasticsearch.inference.TaskType;
1413
import org.elasticsearch.rest.BaseRestHandler;
1514
import org.elasticsearch.rest.RestRequest;
16-
import org.elasticsearch.rest.RestStatus;
1715
import org.elasticsearch.rest.RestUtils;
1816
import org.elasticsearch.rest.Scope;
1917
import org.elasticsearch.rest.ServerlessScope;
@@ -48,7 +46,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
4846
inferenceEntityId = restRequest.param(INFERENCE_ID);
4947
taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID));
5048
} else {
51-
throw new ElasticsearchStatusException("Inference ID must be provided in the path", RestStatus.BAD_REQUEST);
49+
inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID);
50+
taskType = TaskType.ANY;
5251
}
5352

5453
var content = restRequest.requiredContent();

0 commit comments

Comments
 (0)