diff --git a/docs/changelog/121231.yaml b/docs/changelog/121231.yaml
new file mode 100644
index 0000000000000..bd9eb934c8d08
--- /dev/null
+++ b/docs/changelog/121231.yaml
@@ -0,0 +1,6 @@
+pr: 121231
+summary: Fix inference update API calls with `task_type` in body or `deployment_id`
+  defined
+area: Machine Learning
+type: bug
+issues: []
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java
index 0a2200ff912ac..e5eda9a71b472 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java
@@ -43,6 +43,24 @@ public void testAttachToDeployment() throws IOException {
         var results = infer(inferenceId, List.of("washing machine"));
         assertNotNull(results.get("sparse_embedding"));
 
+        var updatedNumAllocations = randomIntBetween(1, 10);
+        var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
+        assertThat(
+            updatedEndpointConfig.get("service_settings"),
+            is(
+                Map.of(
+                    "num_allocations",
+                    updatedNumAllocations,
+                    "num_threads",
+                    1,
+                    "model_id",
+                    "attach_to_deployment",
+                    "deployment_id",
+                    "existing_deployment"
+                )
+            )
+        );
+
         deleteModel(inferenceId);
         // assert deployment not stopped
         var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
@@ -83,6 +101,24 @@ public void testAttachWithModelId() throws IOException {
         var results = infer(inferenceId, List.of("washing machine"));
         assertNotNull(results.get("sparse_embedding"));
 
+        var updatedNumAllocations = randomIntBetween(1, 10);
+        var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
+        assertThat(
+            updatedEndpointConfig.get("service_settings"),
+            is(
+                Map.of(
+                    "num_allocations",
+                    updatedNumAllocations,
+                    "num_threads",
+                    1,
+                    "model_id",
+                    "attach_with_model_id",
+                    "deployment_id",
+                    "existing_deployment_with_model_id"
+                )
+            )
+        );
+
         stopMlNodeDeployment(deploymentId);
     }
 
@@ -189,6 +225,16 @@ private String endpointConfig(String modelId, String deploymentId) {
             """, modelId, deploymentId);
     }
 
+    private String updatedEndpointConfig(int numAllocations) {
+        return Strings.format("""
+            {
+              "service_settings": {
+                "num_allocations": %d
+              }
+            }
+            """, numAllocations);
+    }
+
     private Response startMlNodeDeploymemnt(String modelId, String deploymentId) throws IOException {
         String endPoint = "/_ml/trained_models/"
             + modelId
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
index bb3f3e9b46c4d..950ff196e5136 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
@@ -238,6 +238,11 @@ static Map<String, Object> updateEndpoint(String inferenceID, String modelConfig
         return putRequest(endpoint, modelConfig);
     }
 
+    static Map<String, Object> updateEndpoint(String inferenceID, String modelConfig) throws IOException {
+        String endpoint = Strings.format("_inference/%s/_update", inferenceID);
+        return putRequest(endpoint, modelConfig);
+    }
+
     protected Map<String, Object> putPipeline(String pipelineId, String modelId) throws IOException {
         String endpoint = Strings.format("_ingest/pipeline/%s", pipelineId);
         String body = """
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
index b786cd1298495..793b3f7a9a349 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
@@ -369,6 +369,61 @@ public void testUnifiedCompletionInference() throws Exception {
         }
     }
 
+    public void testUpdateEndpointWithWrongTaskTypeInURL() throws IOException {
+        putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        var e = expectThrows(
+            ResponseException.class,
+            () -> updateEndpoint(
+                "sparse_embedding_model",
+                updateConfig(null, randomAlphaOfLength(10), randomIntBetween(1, 10)),
+                TaskType.TEXT_EMBEDDING
+            )
+        );
+        assertThat(e.getMessage(), containsString("Task type must match the task type of the existing endpoint"));
+    }
+
+    public void testUpdateEndpointWithWrongTaskTypeInBody() throws IOException {
+        putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        var e = expectThrows(
+            ResponseException.class,
+            () -> updateEndpoint(
+                "sparse_embedding_model",
+                updateConfig(TaskType.TEXT_EMBEDDING, randomAlphaOfLength(10), randomIntBetween(1, 10))
+            )
+        );
+        assertThat(e.getMessage(), containsString("Task type must match the task type of the existing endpoint"));
+    }
+
+    public void testUpdateEndpointWithTaskTypeInURL() throws IOException {
+        testUpdateEndpoint(false, true);
+    }
+
+    public void testUpdateEndpointWithTaskTypeInBody() throws IOException {
+        testUpdateEndpoint(true, false);
+    }
+
+    public void testUpdateEndpointWithTaskTypeInBodyAndURL() throws IOException {
+        testUpdateEndpoint(true, true);
+    }
+
+    @SuppressWarnings("unchecked")
+    private void testUpdateEndpoint(boolean taskTypeInBody, boolean taskTypeInURL) throws IOException {
+        String endpointId = "sparse_embedding_model";
+        putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+
+        int temperature = randomIntBetween(1, 10);
+        var expectedConfig = updateConfig(taskTypeInBody ? TaskType.SPARSE_EMBEDDING : null, randomAlphaOfLength(1), temperature);
+        Map<String, Object> updatedEndpoint;
+        if (taskTypeInURL) {
+            updatedEndpoint = updateEndpoint(endpointId, expectedConfig, TaskType.SPARSE_EMBEDDING);
+        } else {
+            updatedEndpoint = updateEndpoint(endpointId, expectedConfig);
+        }
+
+        Map<String, Object> updatedTaskSettings = (Map<String, Object>) updatedEndpoint.get("task_settings");
+        assertEquals(temperature, updatedTaskSettings.get("temperature"));
+    }
+
     private static Iterator<String> expectedResultsIterator(List<String> input) {
         // The Locale needs to be ROOT to match what the test service is going to respond with
         return Stream.concat(input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(InferenceCrudIT::expectedResult), Stream.of("[DONE]"))
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java
index b857ef3068835..ed005a86d66b5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java
@@ -21,6 +21,7 @@
 import org.elasticsearch.cluster.block.ClusterBlockLevel;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.inference.InferenceService;
@@ -50,6 +51,7 @@
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
 
@@ -255,14 +257,13 @@ private void updateInClusterEndpoint(
         ActionListener<UpdateInferenceModelAction.Response> listener
     ) throws IOException {
         // The model we are trying to update must have a trained model associated with it if it is an in-cluster deployment
-        throwIfTrainedModelDoesntExist(request);
+        var deploymentId = getDeploymentIdForInClusterEndpoint(existingParsedModel);
+        throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);
 
         Map<String, Object> serviceSettings = request.getContentAsSettings().serviceSettings();
 
         if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
-            UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(
-                request.getInferenceEntityId()
-            );
+            UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
             updateRequest.setNumberOfAllocations(numAllocations);
 
             var delegate = listener.delegateFailure((l2, response) -> {
@@ -270,7 +271,8 @@ private void updateInClusterEndpoint(
             });
 
             logger.info(
-                "Updating trained model deployment for inference entity [{}] with [{}] num_allocations",
+                "Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations",
+                deploymentId,
                 request.getInferenceEntityId(),
                 numAllocations
             );
@@ -293,12 +295,26 @@ private boolean isInClusterService(String name) {
         return List.of(ElasticsearchInternalService.NAME, ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME).contains(name);
     }
 
-    private void throwIfTrainedModelDoesntExist(UpdateInferenceModelAction.Request request) throws ElasticsearchStatusException {
-        var assignments = TrainedModelAssignmentUtils.modelAssignments(request.getInferenceEntityId(), clusterService.state());
+    private String getDeploymentIdForInClusterEndpoint(Model model) {
+        if (model instanceof ElasticsearchInternalModel esModel) {
+            return esModel.mlNodeDeploymentId();
+        } else {
+            throw new IllegalStateException(
+                Strings.format(
+                    "Cannot update inference endpoint [%s]. Class [%s] is not an Elasticsearch internal model",
+                    model.getInferenceEntityId(),
+                    model.getClass().getSimpleName()
+                )
+            );
+        }
+    }
+
+    private void throwIfTrainedModelDoesntExist(String inferenceEntityId, String deploymentId) throws ElasticsearchStatusException {
+        var assignments = TrainedModelAssignmentUtils.modelAssignments(deploymentId, clusterService.state());
         if ((assignments == null || assignments.isEmpty())) {
             throw ExceptionsHelper.entityNotFoundException(
                 Messages.MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE,
-                request.getInferenceEntityId()
+                inferenceEntityId
             );
         }
 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUpdateInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUpdateInferenceModelAction.java
index 120731a4f8e66..7b3c54c60cdcc 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUpdateInferenceModelAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUpdateInferenceModelAction.java
@@ -7,13 +7,11 @@
 
 package org.elasticsearch.xpack.inference.rest;
 
-import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestRequest;
-import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.rest.RestUtils;
 import org.elasticsearch.rest.Scope;
 import org.elasticsearch.rest.ServerlessScope;
@@ -48,7 +46,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
             inferenceEntityId = restRequest.param(INFERENCE_ID);
             taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID));
         } else {
-            throw new ElasticsearchStatusException("Inference ID must be provided in the path", RestStatus.BAD_REQUEST);
+            inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID);
+            taskType = TaskType.ANY;
         }
 
         var content = restRequest.requiredContent();
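
Reviewer note: below is a minimal sketch (not part of the change) of the two update request shapes the new tests exercise, written against the Elasticsearch low-level REST client. The endpoint name `my-elser-endpoint`, the `RestClient` construction, and the `num_allocations` value are illustrative assumptions only; the PUT method, the two URL forms, and the task-type matching rule come from the tests and the REST handler change above.

```java
import java.io.IOException;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class UpdateInferenceEndpointExample {

    // Illustrative only: assumes an existing sparse_embedding endpoint named "my-elser-endpoint"
    // and an already-built low-level RestClient pointing at the cluster.
    static Response updateAllocations(RestClient client, boolean taskTypeInUrl) throws IOException {
        // With this fix both URL forms are accepted; when the path omits the task type, the REST
        // handler falls back to TaskType.ANY and the task type is resolved from the existing
        // endpoint instead of the request being rejected.
        String path = taskTypeInUrl
            ? "_inference/sparse_embedding/my-elser-endpoint/_update"
            : "_inference/my-elser-endpoint/_update";

        Request request = new Request("PUT", path);
        // A "task_type" field may also be sent in the body, but it must match the task type of the
        // existing endpoint; a mismatch is rejected with
        // "Task type must match the task type of the existing endpoint".
        request.setJsonEntity("""
            {
              "service_settings": {
                "num_allocations": 2
              }
            }
            """);
        return client.performRequest(request);
    }
}
```

For endpoints attached to an existing ML-node deployment, the update is now applied to the underlying trained model deployment resolved via `ElasticsearchInternalModel#mlNodeDeploymentId()` rather than to the inference entity id.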