diff --git a/docs/changelog/114457.yaml b/docs/changelog/114457.yaml new file mode 100644 index 0000000000000..9558c41852f69 --- /dev/null +++ b/docs/changelog/114457.yaml @@ -0,0 +1,6 @@ +pr: 114457 +summary: "[Inference API] Introduce Update API to change some aspects of existing\ + \ inference endpoints" +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java b/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java index 0e5b3a555b800..9c666bd4a35f5 100644 --- a/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java @@ -16,6 +16,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.Map; /** * This class defines an empty secret settings object. This is useful for services that do not have any secret settings. @@ -48,4 +49,9 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException {} + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return INSTANCE; + } } diff --git a/server/src/main/java/org/elasticsearch/inference/EmptyTaskSettings.java b/server/src/main/java/org/elasticsearch/inference/EmptyTaskSettings.java index 0c863932c6afe..cba0282f7fed8 100644 --- a/server/src/main/java/org/elasticsearch/inference/EmptyTaskSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/EmptyTaskSettings.java @@ -16,6 +16,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.Map; /** * This class defines an empty task settings object. This is useful for services that do not have any task settings. 
@@ -53,4 +54,9 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException {} + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + return INSTANCE; + } } diff --git a/server/src/main/java/org/elasticsearch/inference/SecretSettings.java b/server/src/main/java/org/elasticsearch/inference/SecretSettings.java index e2c0c8b58c69b..90ca92bb0e2ef 100644 --- a/server/src/main/java/org/elasticsearch/inference/SecretSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/SecretSettings.java @@ -12,6 +12,9 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; +import java.util.Map; + public interface SecretSettings extends ToXContentObject, VersionedNamedWriteable { + SecretSettings newSecretSettings(Map newSecrets); } diff --git a/server/src/main/java/org/elasticsearch/inference/TaskSettings.java b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java index 9862abce2332c..7dd20688245ba 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java @@ -12,6 +12,11 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; +import java.util.Map; + public interface TaskSettings extends ToXContentObject, VersionedNamedWriteable { + boolean isEmpty(); + + TaskSettings updatedTaskSettings(Map newSettings); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UpdateInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UpdateInferenceModelAction.java new file mode 100644 index 0000000000000..cc59ae890467b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UpdateInferenceModelAction.java @@ -0,0 +1,278 @@ +/* + * 
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.action;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.MlStrings;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.inference.ModelConfigurations.SERVICE_SETTINGS;
+import static org.elasticsearch.inference.ModelConfigurations.TASK_SETTINGS;
+
+public class UpdateInferenceModelAction extends ActionType<UpdateInferenceModelAction.Response> {
+
+    public static final UpdateInferenceModelAction INSTANCE = new UpdateInferenceModelAction();
+    public static final String NAME = "cluster:admin/xpack/inference/update";
+
+    public UpdateInferenceModelAction() {
+        super(NAME);
+    }
+
+    public record Settings( // validated view of an update body; a component is null when that section was absent or empty
+        @Nullable Map<String, Object> serviceSettings,
+        @Nullable Map<String, Object> taskSettings,
+        @Nullable TaskType taskType
+    ) {}
+
+    public static class Request extends AcknowledgedRequest<UpdateInferenceModelAction.Request> {
+
+        private final String inferenceEntityId;
+        private final BytesReference content;
+        private final XContentType contentType;
+        private final TaskType taskType;
+        private Settings settings; // lazily parsed from content by getContentAsSettings(); not serialized
+
+        public Request(String inferenceEntityId, BytesReference content, XContentType contentType, TaskType taskType, TimeValue timeout) {
+            super(timeout, DEFAULT_ACK_TIMEOUT);
+            this.inferenceEntityId = inferenceEntityId;
+            this.content = content;
+            this.contentType = contentType;
+            this.taskType = taskType;
+        }
+
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.inferenceEntityId = in.readString();
+            this.taskType = TaskType.fromStream(in); // read order must mirror writeTo(): id, task type, content, content type
+            this.content = in.readBytesReference();
+            this.contentType = in.readEnum(XContentType.class);
+        }
+
+        public String getInferenceEntityId() {
+            return inferenceEntityId;
+        }
+
+        public TaskType getTaskType() {
+            return taskType;
+        }
+
+        /**
+         * The body of the request.
+         * For in-cluster models, this is expected to contain some of the following:
+         * "number_of_allocations": `an integer`
+         *
+         * For third-party services, this is expected to contain:
+         * "service_settings": {
+         * "api_key": `a string` // service settings can only contain an api key
+         * }
+         * "task_settings": { a map of settings }
+         *
+         */
+        public BytesReference getContent() {
+            return content;
+        }
+
+        /**
+         * The body of the request parsed into a {@link Settings}; the result is cached after the first call.
+         * Only "task_type", "service_settings" and "task_settings" are accepted as top-level fields.
+         * @throws ElasticsearchStatusException (400) if the body is empty, a field has the wrong shape, or any other field is present
+         */
+        public Settings getContentAsSettings() {
+            if (settings == null) { // settings is deterministic on content, so we only need to compute it once
+                Map<String, Object> unvalidatedMap = XContentHelper.convertToMap(content, false, contentType).v2();
+                Map<String, Object> serviceSettings = new HashMap<>();
+                Map<String, Object> taskSettings = new HashMap<>();
+                TaskType taskType = null;
+
+                if (unvalidatedMap.isEmpty()) {
+                    throw new ElasticsearchStatusException("Request body is empty", RestStatus.BAD_REQUEST);
+                }
+
+                if (unvalidatedMap.containsKey("task_type")) {
+                    if (unvalidatedMap.get("task_type") instanceof String taskTypeString) {
+                        taskType = TaskType.fromStringOrStatusException(taskTypeString);
+                    } else {
+                        throw new ElasticsearchStatusException(
+                            "Failed to parse [task_type] in update request [{}]",
+                            RestStatus.BAD_REQUEST, // a non-string task_type is malformed client input, not a server fault
+                            unvalidatedMap.toString()
+                        );
+                    }
+                    unvalidatedMap.remove("task_type");
+                }
+
+                if (unvalidatedMap.containsKey(SERVICE_SETTINGS)) {
+                    if (unvalidatedMap.get(SERVICE_SETTINGS) instanceof Map<?, ?> tempMap) {
+                        for (Map.Entry<?, ?> entry : (tempMap).entrySet()) {
+                            if (entry.getKey() instanceof String key && entry.getValue() instanceof Object value) {
+                                serviceSettings.put(key, value);
+                            } else {
+                                throw new ElasticsearchStatusException(
+                                    "Failed to parse update request [{}]",
+                                    RestStatus.BAD_REQUEST, // reached when a value is null: reject the request as invalid
+                                    unvalidatedMap.toString()
+                                );
+                            }
+                        }
+                        unvalidatedMap.remove(SERVICE_SETTINGS);
+                    } else {
+                        throw new ElasticsearchStatusException(
+                            "Unable to parse service settings in the request [{}]",
+                            RestStatus.BAD_REQUEST,
+                            unvalidatedMap.toString()
+                        );
+                    }
+                }
+
+                if (unvalidatedMap.containsKey(TASK_SETTINGS)) {
+                    if (unvalidatedMap.get(TASK_SETTINGS) instanceof Map<?, ?> tempMap) {
+                        for (Map.Entry<?, ?> entry : (tempMap).entrySet()) {
+                            if (entry.getKey() instanceof String key && entry.getValue() instanceof Object value) {
+                                taskSettings.put(key, value);
+                            } else {
+                                throw new ElasticsearchStatusException(
+                                    "Failed to parse update request [{}]",
+                                    RestStatus.BAD_REQUEST, // reached when a value is null: reject the request as invalid
+                                    unvalidatedMap.toString()
+                                );
+                            }
+                        }
+                        unvalidatedMap.remove(TASK_SETTINGS);
+                    } else {
+                        throw new ElasticsearchStatusException(
+                            "Unable to parse task settings in the request [{}]",
+                            RestStatus.BAD_REQUEST,
+                            unvalidatedMap.toString()
+                        );
+                    }
+                }
+
+                if (unvalidatedMap.isEmpty() == false) { // recognised keys were removed above, so anything left is not updatable
+                    throw new ElasticsearchStatusException(
+                        "Request contained fields which cannot be updated, remove these fields and try again [{}]",
+                        RestStatus.BAD_REQUEST,
+                        unvalidatedMap.toString()
+                    );
+                }
+
+                this.settings = new Settings(
+                    serviceSettings.isEmpty() == false ? Collections.unmodifiableMap(serviceSettings) : null,
+                    taskSettings.isEmpty() == false ? Collections.unmodifiableMap(taskSettings) : null,
+                    taskType
+                );
+            }
+            return this.settings;
+        }
+
+        public XContentType getContentType() {
+            return contentType;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(inferenceEntityId); // field order here defines the wire format read by Request(StreamInput)
+            taskType.writeTo(out);
+            out.writeBytesReference(content);
+            XContentHelper.writeTo(out, contentType);
+        }
+
+        @Override
+        public ActionRequestValidationException validate() {
+            ActionRequestValidationException validationException = new ActionRequestValidationException();
+            if (MlStrings.isValidId(this.inferenceEntityId) == false) {
+                validationException.addValidationError(Messages.getMessage(Messages.INVALID_ID, "inference_id", this.inferenceEntityId));
+            }
+
+            if (validationException.validationErrors().isEmpty() == false) {
+                return validationException;
+            } else {
+                return null;
+            }
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request request = (Request) o;
+            return Objects.equals(inferenceEntityId, request.inferenceEntityId)
+                && Objects.equals(content, request.content)
+                && contentType == request.contentType
+                && taskType == request.taskType;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(inferenceEntityId, content, contentType, taskType);
+        }
+    }
+
+    public static class Response extends ActionResponse implements ToXContentObject {
+
+        private final ModelConfigurations model;
+
+        public Response(ModelConfigurations model) {
+            this.model = model;
+        }
+
+        public Response(StreamInput in) throws IOException {
+            super(in);
+            model = new ModelConfigurations(in);
+        }
+
+        public ModelConfigurations getModel() {
+            return model;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            model.writeTo(out);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return model.toFilteredXContent(builder, params);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Response response = (Response) o;
+            return Objects.equals(model, response.model);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(model);
+        }
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 6ebed55451ae7..9f9def6a0678d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -281,6 +281,9 @@ public final class Messages { public static final String FIELD_CANNOT_BE_NULL = "Field [{0}] cannot be null"; public static final String MODEL_ID_MATCHES_EXISTING_MODEL_IDS_BUT_MUST_NOT = "Model IDs must be unique. 
Requested model ID [{}] matches existing model IDs but must not."; + public static final String MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE = + "Requested model ID [{}] does not have a matching trained model and thus cannot be updated."; + public static final String INFERENCE_ENTITY_NON_EXISTANT_NO_UPDATE = "The inference endpoint [{}] does not exist and cannot be updated"; private Messages() {} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index fb75d95aeed1b..73e3c31297fbf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -93,6 +93,10 @@ public static ElasticsearchStatusException badRequestException(String msg, Objec return new ElasticsearchStatusException(msg, RestStatus.BAD_REQUEST, args); } + public static ElasticsearchStatusException entityNotFoundException(String msg, Object... 
args) { + return new ElasticsearchStatusException(msg, RestStatus.NOT_FOUND, args); + } + public static ElasticsearchStatusException taskOperationFailureToStatusException(TaskOperationFailure failure) { return new ElasticsearchStatusException(failure.getCause().getMessage(), failure.getStatus(), failure.getCause()); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index f82b6f155c0a0..3ca6b45c2948e 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -81,6 +81,21 @@ static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) { """, taskType); } + static String updateConfig(@Nullable TaskType taskTypeInBody, String apiKey, int temperature) { + var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\","; + return Strings.format(""" + { + %s + "service_settings": { + "api_key": "%s" + }, + "task_settings": { + "temperature": %d + } + } + """, taskType, apiKey, temperature); + } + static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) { var taskType = taskTypeInBody == null ? 
"" : "\"task_type\": \"" + taskTypeInBody + "\","; return Strings.format(""" @@ -196,6 +211,11 @@ protected Map putModel(String modelId, String modelConfig, TaskT return putRequest(endpoint, modelConfig); } + protected Map updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException { + String endpoint = Strings.format("_inference/%s/%s/_update", taskType, inferenceID); + return putRequest(endpoint, modelConfig); + } + protected Map putPipeline(String pipelineId, String modelId) throws IOException { String endpoint = Strings.format("_ingest/pipeline/%s", pipelineId); String body = """ diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 5a84fd8985504..98c8d43707219 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -16,6 +16,8 @@ import java.io.IOException; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.function.Function; import java.util.stream.IntStream; @@ -29,7 +31,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest { @SuppressWarnings("unchecked") - public void testGet() throws IOException { + public void testCRUD() throws IOException { for (int i = 0; i < 5; i++) { putModel("se_model_" + i, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); } @@ -53,11 +55,29 @@ public void testGet() throws IOException { for (var denseModel : getDenseModels) { assertEquals("text_embedding", denseModel.get("task_type")); } - - var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING); - 
assertThat(singleModel, hasSize(1)); - assertEquals("se_model_1", singleModel.get(0).get("inference_id")); - + String oldApiKey; + { + var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING); + assertThat(singleModel, hasSize(1)); + assertEquals("se_model_1", singleModel.get(0).get("inference_id")); + oldApiKey = (String) singleModel.get(0).get("api_key"); + } + var newApiKey = randomAlphaOfLength(10); + int temperature = randomIntBetween(1, 10); + Map updatedEndpoint = updateEndpoint( + "se_model_1", + updateConfig(TaskType.SPARSE_EMBEDDING, newApiKey, temperature), + TaskType.SPARSE_EMBEDDING + ); + Map updatedTaskSettings = (Map) updatedEndpoint.get("task_settings"); + assertEquals(temperature, updatedTaskSettings.get("temperature")); + { + var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING); + assertThat(singleModel, hasSize(1)); + assertEquals("se_model_1", singleModel.get(0).get("inference_id")); + assertNotEquals(oldApiKey, newApiKey); + assertEquals(updatedEndpoint, singleModel.get(0)); + } for (int i = 0; i < 5; i++) { deleteModel("se_model_" + i, TaskType.SPARSE_EMBEDDING); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 02dfff1b5c2e6..6496bcdd89f21 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -163,6 +163,11 @@ public String getWriteableName() { public TransportVersion getMinimalSupportedVersion() { return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests } + + @Override + public TaskSettings 
updatedTaskSettings(Map newSettings) { + return fromMap(new HashMap<>(newSettings)); + } } public record TestSecretSettings(String apiKey) implements SecretSettings { @@ -211,5 +216,10 @@ public String getWriteableName() { public TransportVersion getMinimalSupportedVersion() { return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return TestSecretSettings.fromMap(new HashMap<>(newSecrets)); + } } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 1a2f0fb6a1137..a76c4303268e4 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -600,6 +600,10 @@ public void writeTo(StreamOutput out) throws IOException { public boolean isEmpty() { return true; } + + public TaskSettings updatedTaskSettings(Map newSettings) { + return this; + } } record TestSecretSettings(String key) implements SecretSettings { @@ -625,6 +629,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return new TestSecretSettings(newSecrets.get("secret").toString()); + } } TestModelOfAnyKind(String inferenceEntityId, TaskType taskType, String service) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 927fd94809886..d251120980e0b 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -47,12 +47,14 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceEndpointAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; @@ -76,6 +78,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction; +import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockService; @@ -149,6 +152,7 @@ public InferencePlugin(Settings settings) { new 
ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class), new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class), new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), + new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class), new ActionHandler<>(DeleteInferenceEndpointAction.INSTANCE, TransportDeleteInferenceEndpointAction.class), new ActionHandler<>(XPackUsageFeatureAction.INFERENCE, TransportInferenceUsageAction.class), new ActionHandler<>(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class) @@ -172,6 +176,7 @@ public List getRestHandlers( new RestStreamInferenceAction(), new RestGetInferenceModelAction(), new RestPutInferenceModelAction(), + new RestUpdateInferenceModelAction(), new RestDeleteInferenceEndpointAction(), new RestGetInferenceDiagnosticsAction() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 49d65b6e0dc59..64eeed82ee1b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -41,6 +41,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import java.io.IOException; @@ -100,7 +101,7 @@ protected void masterOperation( ActionListener listener ) throws Exception { var requestAsMap = 
requestToMap(request); - var resolvedTaskType = resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME)); + var resolvedTaskType = ServiceUtils.resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME)); String serviceName = (String) requestAsMap.remove(ModelConfigurations.SERVICE); if (serviceName == null) { @@ -227,37 +228,4 @@ protected ClusterBlockException checkBlock(PutInferenceModelAction.Request reque return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); } - /** - * task_type can be specified as either a URL parameter or in the - * request body. Resolve which to use or throw if the settings are - * inconsistent - * @param urlTaskType Taken from the URL parameter. ANY means not specified. - * @param bodyTaskType Taken from the request body. Maybe null - * @return The resolved task type - */ - static TaskType resolveTaskType(TaskType urlTaskType, String bodyTaskType) { - if (bodyTaskType == null) { - if (urlTaskType == TaskType.ANY) { - throw new ElasticsearchStatusException("model is missing required setting [task_type]", RestStatus.BAD_REQUEST); - } else { - return urlTaskType; - } - } - - TaskType parsedBodyTask = TaskType.fromStringOrStatusException(bodyTaskType); - if (parsedBodyTask == TaskType.ANY) { - throw new ElasticsearchStatusException("task_type [any] is not valid type for inference", RestStatus.BAD_REQUEST); - } - - if (parsedBodyTask.isAnyOrSame(urlTaskType) == false) { - throw new ElasticsearchStatusException( - "Cannot resolve conflicting task_type parameter in the request URL [{}] and the request body [{}]", - RestStatus.BAD_REQUEST, - urlTaskType.toString(), - bodyTaskType - ); - } - - return parsedBodyTask; - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java 
new file mode 100644 index 0000000000000..03a88e5228fa8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java @@ -0,0 +1,328 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import 
org.elasticsearch.inference.UnparsedModel;
+import org.elasticsearch.injection.guice.Inject;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
+import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.resolveTaskType;
+import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS;
+
+/**
+ * Master-node transport action that applies an update request to an existing inference endpoint.
+ * The existing endpoint is loaded (with secrets), merged with the requested settings, and stored
+ * through {@link ModelRegistry#updateModelTransaction}. For the in-cluster services the trained
+ * model deployment is updated first.
+ */
+public class TransportUpdateInferenceModelAction extends TransportMasterNodeAction<
+    UpdateInferenceModelAction.Request,
+    UpdateInferenceModelAction.Response> {
+
+    private static final Logger logger = LogManager.getLogger(TransportUpdateInferenceModelAction.class);
+
+    private final ModelRegistry modelRegistry;
+    private final InferenceServiceRegistry serviceRegistry;
+    private final Client client;
+
+    @Inject
+    public TransportUpdateInferenceModelAction(
+        TransportService transportService,
+        ClusterService clusterService,
+        ThreadPool threadPool,
+        ActionFilters actionFilters,
+        IndexNameExpressionResolver indexNameExpressionResolver,
+        ModelRegistry modelRegistry,
+        InferenceServiceRegistry serviceRegistry,
+        Client client,
+        Settings settings
+    ) {
+        super(
+            UpdateInferenceModelAction.NAME,
+            transportService,
+            clusterService,
+            threadPool,
+            actionFilters,
+            UpdateInferenceModelAction.Request::new,
+            indexNameExpressionResolver,
+            UpdateInferenceModelAction.Response::new,
+            EsExecutors.DIRECT_EXECUTOR_SERVICE
+        );
+        this.modelRegistry = modelRegistry;
+        this.serviceRegistry = serviceRegistry;
+        this.client = client;
+    }
+
+    /**
+     * Runs the update as a chain of asynchronous steps: load the stored endpoint, look up its service,
+     * build the merged model and persist it, then re-read the stored configuration and return it.
+     * Any step that fails short-circuits the chain and fails {@code masterListener}.
+     */
+    @Override
+    protected void masterOperation(
+        Task task,
+        UpdateInferenceModelAction.Request request,
+        ClusterState state,
+        ActionListener masterListener
+    ) {
+        var bodyTaskType = request.getContentAsSettings().taskType();
+        var resolvedTaskType = resolveTaskType(request.getTaskType(), bodyTaskType != null ? bodyTaskType.toString() : null);
+
+        // Set by the service-lookup step and read by the later steps of the chain
+        AtomicReference service = new AtomicReference<>();
+
+        var inferenceEntityId = request.getInferenceEntityId();
+
+        SubscribableListener.newForked(listener -> { checkEndpointExists(inferenceEntityId, listener); })
+            .andThen((listener, unparsedModel) -> {
+
+                Optional optionalService = serviceRegistry.getService(unparsedModel.service());
+                if (optionalService.isEmpty()) {
+                    listener.onFailure(
+                        new ElasticsearchStatusException(
+                            "Service [{}] not found",
+                            RestStatus.INTERNAL_SERVER_ERROR,
+                            unparsedModel.service()
+                        )
+                    );
+                } else {
+                    service.set(optionalService.get());
+                    listener.onResponse(unparsedModel);
+                }
+            })
+            .andThen((listener, existingUnparsedModel) -> {
+
+                // Copies of the stored maps are handed to the parser rather than the originals
+                Model existingParsedModel = service.get()
+                    .parsePersistedConfigWithSecrets(
+                        request.getInferenceEntityId(),
+                        existingUnparsedModel.taskType(),
+                        new HashMap<>(existingUnparsedModel.settings()),
+                        new HashMap<>(existingUnparsedModel.secrets())
+                    );
+
+                Model newModel = combineExistingModelWithNewSettings(
+                    existingParsedModel,
+                    request.getContentAsSettings(),
+                    service.get().name(),
+                    resolvedTaskType
+                );
+
+                if (isInClusterService(service.get().name())) {
+                    updateInClusterEndpoint(request, newModel, existingParsedModel, listener);
+                } else {
+                    modelRegistry.updateModelTransaction(newModel, existingParsedModel, listener);
+                }
+            })
+            .andThen((listener, didUpdate) -> {
+                if (didUpdate) {
+                    // Re-read the stored configuration so the response reflects what was persisted
+                    modelRegistry.getModel(inferenceEntityId, ActionListener.wrap((unparsedModel) -> {
+                        if (unparsedModel == null) {
+                            listener.onFailure(
+                                new ElasticsearchStatusException(
+                                    "Failed to update model, updated model not found",
+                                    RestStatus.INTERNAL_SERVER_ERROR
+                                )
+                            );
+                        } else {
+                            listener.onResponse(
+                                service.get()
+                                    .parsePersistedConfig(
+                                        request.getInferenceEntityId(),
+                                        resolvedTaskType,
+                                        new HashMap<>(unparsedModel.settings())
+                                    )
+                                    .getConfigurations()
+                            );
+                        }
+                    }, listener::onFailure));
+                } else {
+                    listener.onFailure(new ElasticsearchStatusException("Failed to update model", RestStatus.INTERNAL_SERVER_ERROR));
+                }
+
+            }).andThen((listener, modelConfig) -> {
+                listener.onResponse(new UpdateInferenceModelAction.Response(modelConfig));
+            })
+            .addListener(masterListener);
+    }
+
+    /**
+     * Combines the existing model with the new settings to create a new model using the
+     * SecretSettings and TaskSettings implementations for each service, as well as specifically handling NUM_ALLOCATIONS.
+     *
+     * @param existingParsedModel the Model representing a third-party service endpoint
+     * @param settingsToUpdate new settings
+     * @param serviceName name of the service, recorded in the new model's configurations
+     * @param resolvedTaskType task type resolved from the request; must equal the existing endpoint's task type
+     * @return a new object representing the updated model
+     * @throws ElasticsearchStatusException with BAD_REQUEST if the task types differ, or if NUM_ALLOCATIONS is supplied for
+     *                                      ElasticsearchInternalServiceSettings with a value that is not an Integer
+     */
+    private Model combineExistingModelWithNewSettings(
+        Model existingParsedModel,
+        UpdateInferenceModelAction.Settings settingsToUpdate,
+        String serviceName,
+        TaskType resolvedTaskType
+    ) {
+        ModelConfigurations existingConfigs = existingParsedModel.getConfigurations();
+        TaskSettings existingTaskSettings = existingConfigs.getTaskSettings();
+        SecretSettings existingSecretSettings = existingParsedModel.getSecretSettings();
+
+        // Start from the existing values; each is replaced only if the request touches it
+        SecretSettings newSecretSettings = existingSecretSettings;
+        TaskSettings newTaskSettings = existingTaskSettings;
+        ServiceSettings newServiceSettings = existingConfigs.getServiceSettings();
+
+        // The request's service_settings map is what gets handed to the secret settings implementation
+        if (settingsToUpdate.serviceSettings() != null && existingSecretSettings != null) {
+            newSecretSettings = existingSecretSettings.newSecretSettings(settingsToUpdate.serviceSettings());
+        }
+        if (settingsToUpdate.serviceSettings() != null && settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)) {
+            // In cluster services can only have their num_allocations updated, so this is a special case
+            if (newServiceSettings instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {
+                // Check the type instead of casting: a non-integer (or null) value is a client error and
+                // must not surface as a ClassCastException from inside the listener chain
+                if (settingsToUpdate.serviceSettings().get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
+                    newServiceSettings = new ElasticsearchInternalServiceSettings(elasticServiceSettings, numAllocations);
+                } else {
+                    throw new ElasticsearchStatusException("[{}] must be an integer", RestStatus.BAD_REQUEST, NUM_ALLOCATIONS);
+                }
+            }
+        }
+        if (settingsToUpdate.taskSettings() != null && existingTaskSettings != null) {
+            newTaskSettings = existingTaskSettings.updatedTaskSettings(settingsToUpdate.taskSettings());
+        }
+
+        if (existingParsedModel.getTaskType().equals(resolvedTaskType) == false) {
+            throw new ElasticsearchStatusException("Task type must match the task type of the existing endpoint", RestStatus.BAD_REQUEST);
+        }
+
+        ModelConfigurations newModelConfigs = new ModelConfigurations(
+            existingParsedModel.getInferenceEntityId(),
+            existingParsedModel.getTaskType(),
+            serviceName,
+            newServiceSettings,
+            newTaskSettings
+        );
+
+        return new Model(newModelConfigs, new ModelSecrets(newSecretSettings));
+    }
+
+    /**
+     * Updates an endpoint backed by an in-cluster service: first asks the trained model deployment to
+     * change its number of allocations, and only if that succeeds stores the new model via the registry.
+     * Fails the listener with BAD_REQUEST when the request carries no Integer NUM_ALLOCATIONS service setting.
+     */
+    private void updateInClusterEndpoint(
+        UpdateInferenceModelAction.Request request,
+        Model newModel,
+        Model existingParsedModel,
+        ActionListener listener
+    ) throws IOException {
+        // The model we are trying to update must have a trained model associated with it if it is an in-cluster deployment
+        throwIfTrainedModelDoesntExist(request);
+
+        Map serviceSettings = request.getContentAsSettings().serviceSettings();
+        if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
+
+            UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(
+                request.getInferenceEntityId()
+            );
+            updateRequest.setNumberOfAllocations(numAllocations);
+
+            // Persist the new model only after the deployment update has succeeded
+            var delegate = listener.delegateFailure((l2, response) -> {
+                modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2);
+            });
+
+            logger.info(
+                "Updating trained model deployment for inference entity [{}] with [{}] num_allocations",
+                request.getInferenceEntityId(),
+                numAllocations
+            );
+            client.execute(UpdateTrainedModelDeploymentAction.INSTANCE, updateRequest, delegate);
+
+        } else {
+            listener.onFailure(
+                new ElasticsearchStatusException(
+                    "Failed to parse [{}] of update request [{}]",
+                    RestStatus.BAD_REQUEST,
+                    NUM_ALLOCATIONS,
+                    request.getContent().utf8ToString()
+                )
+            );
+        }
+
+    }
+
+    /** Returns true when {@code name} is one of the two ElasticsearchInternalService service names. */
+    private boolean isInClusterService(String name) {
+        return List.of(ElasticsearchInternalService.NAME, ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME).contains(name);
+    }
+
+    /**
+     * Throws an entity-not-found exception when the current cluster state holds no trained model
+     * assignments for the request's inference entity id.
+     */
+    private void throwIfTrainedModelDoesntExist(UpdateInferenceModelAction.Request request) throws ElasticsearchStatusException {
+        var assignments = TrainedModelAssignmentUtils.modelAssignments(request.getInferenceEntityId(), clusterService.state());
+        if ((assignments == null || assignments.isEmpty())) {
+            throw ExceptionsHelper.entityNotFoundException(
+                Messages.MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE,
+                request.getInferenceEntityId()
+
+            );
+        }
+    }
+
+    /**
+     * Loads the endpoint (with secrets) from the registry and passes it to the listener. A null result
+     * or a ResourceNotFoundException is reported as an entity-not-found failure; other failures pass through.
+     */
+    private void checkEndpointExists(String inferenceEntityId, ActionListener listener) {
+        modelRegistry.getModelWithSecrets(inferenceEntityId, ActionListener.wrap((model) -> {
+            if (model == null) {
+                listener.onFailure(
+                    ExceptionsHelper.entityNotFoundException(Messages.INFERENCE_ENTITY_NON_EXISTANT_NO_UPDATE, inferenceEntityId)
+                );
+            } else {
+                listener.onResponse(model);
+            }
+        }, e -> {
+            if (e instanceof ResourceNotFoundException) {
+                listener.onFailure(
+                    // provide a more specific error message if the inference entity does not exist
+                    ExceptionsHelper.entityNotFoundException(Messages.INFERENCE_ENTITY_NON_EXISTANT_NO_UPDATE, inferenceEntityId)
+                );
+            } else {
+                listener.onFailure(e);
+            }
+        }));
+    }
+
+    // NOTE(review): not referenced anywhere in this class - confirm whether it is still needed
+    private static XContentParser getParser(UpdateInferenceModelAction.Request request) throws IOException {
+        return XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType());
+    }
+
+    /** Rejects the request while a global METADATA_WRITE cluster block is in place. */
+    @Override
+    protected ClusterBlockException checkBlock(UpdateInferenceModelAction.Request request, ClusterState state) {
+        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
+    }
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index d756c0ef26f14..62571c13aebf4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -3,6 +3,8 @@ * or more contributor license agreements.
Licensed under the Elastic License * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. + * + * this file contains code contributed by a generative AI */ package org.elasticsearch.xpack.inference.registry; @@ -21,6 +23,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; @@ -49,10 +52,13 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.stream.Collectors; @@ -83,6 +89,8 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private final OriginSettingClient client; private Map defaultConfigs; + private final Set preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>()); + public ModelRegistry(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); this.defaultConfigs = new HashMap<>(); @@ -306,7 +314,139 @@ private ModelConfigMap createModelConfigMap(SearchHits hits, String inferenceEnt ); }
+    /**
+     * Replaces the stored configuration and secrets of an existing endpoint.
+     * <p>
+     * The new configuration is indexed first, then the new secrets. If the secrets cannot be stored, the
+     * configuration of {@code existingModel} is indexed again to roll the change back. While the update is
+     * running the endpoint id is held in {@code preventDeletionLock}, so a concurrent update is rejected
+     * with CONFLICT; {@code deleteModel} checks the same set.
+     *
+     * @param newModel      the model to store
+     * @param existingModel the currently stored model, used for the rollback
+     * @param finalListener receives {@code true} when both documents were stored, {@code false} when the
+     *                      secrets failed but the rollback succeeded, and a failure in every other case
+     */
+    public void updateModelTransaction(Model newModel, Model existingModel, ActionListener finalListener) {
+
+        String inferenceEntityId = newModel.getConfigurations().getInferenceEntityId();
+        logger.info("Attempting to store update to inference endpoint [{}]", inferenceEntityId);
+
+        // add() on the concurrent set is a single atomic check-and-insert; a separate contains() followed by
+        // add() would let two concurrent updates of the same endpoint both pass the check
+        if (preventDeletionLock.add(inferenceEntityId) == false) {
+            logger.warn("Attempted to update endpoint [{}] that is already being updated", inferenceEntityId);
+            finalListener.onFailure(
+                new ElasticsearchStatusException(
+                    "Endpoint [{}] is currently being updated. Try again once the update completes",
+                    RestStatus.CONFLICT,
+                    inferenceEntityId
+                )
+            );
+            return;
+        }
+
+        SubscribableListener.newForked((subListener) -> {
+            // in this block, we try to update the stored model configurations
+            IndexRequest configRequest = createIndexRequest(
+                Model.documentId(inferenceEntityId),
+                InferenceIndex.INDEX_NAME,
+                newModel.getConfigurations(),
+                true
+            );
+
+            // an exception from the bulk request fails the chain and is handled by the terminal listener below
+            client.prepareBulk().add(configRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(subListener);
+
+        }).andThen((subListener, configResponse) -> {
+            // in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets
+            if (configResponse.hasFailures()) {
+                // if storing the model configurations failed, it won't throw an exception, we need to check the BulkResponse and handle the
+                // exceptions ourselves.
+                logger.error(
+                    format("Failed to update inference endpoint [%s] due to [%s]", inferenceEntityId, configResponse.buildFailureMessage())
+                );
+                // Since none of our updates succeeded at this point, we can release the lock and return.
+                preventDeletionLock.remove(inferenceEntityId);
+                finalListener.onFailure(
+                    new ElasticsearchStatusException(
+                        "Failed to update inference endpoint [{}] due to [{}]",
+                        RestStatus.INTERNAL_SERVER_ERROR,
+                        inferenceEntityId,
+                        configResponse.buildFailureMessage()
+                    )
+                );
+            } else {
+                // Since the model configurations were successfully updated, we can now try to store the new secrets
+                IndexRequest secretsRequest = createIndexRequest(
+                    Model.documentId(newModel.getConfigurations().getInferenceEntityId()),
+                    InferenceSecretsIndex.INDEX_NAME,
+                    newModel.getSecrets(),
+                    true
+                );
+
+                client.prepareBulk()
+                    .add(secretsRequest)
+                    .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+                    .execute(subListener);
+            }
+        }).andThen((subListener, secretsResponse) -> {
+            // in this block, we respond to the success or failure of updating the model secrets
+            if (secretsResponse.hasFailures()) {
+                // since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations
+                IndexRequest configRequest = createIndexRequest(
+                    Model.documentId(inferenceEntityId),
+                    InferenceIndex.INDEX_NAME,
+                    existingModel.getConfigurations(),
+                    true
+                );
+                logger.error(
+                    "Failed to update inference endpoint secrets [{}], attempting rolling back to previous state",
+                    inferenceEntityId
+                );
+
+                client.prepareBulk()
+                    .add(configRequest)
+                    .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+                    .execute(subListener);
+            } else {
+                // since updating the secrets was successful, we can remove the lock and respond to the final listener
+                preventDeletionLock.remove(inferenceEntityId);
+                finalListener.onResponse(true);
+            }
+        }).andThen((subListener, configResponse) -> {
+            // this block will be called if the secrets response failed, and the rollback didn't throw an exception.
+            // The rollback still could have failed though, so we need to check for that.
+            preventDeletionLock.remove(inferenceEntityId);
+            if (configResponse.hasFailures()) {
+                logger.error(
+                    format("Failed to update inference endpoint [%s] due to [%s]", inferenceEntityId, configResponse.buildFailureMessage())
+                );
+                finalListener.onFailure(
+                    new ElasticsearchStatusException(
+                        "Failed to rollback while handling failure to update inference endpoint [{}]. "
+                            + "Endpoint may be in an inconsistent state due to [{}]",
+                        RestStatus.INTERNAL_SERVER_ERROR,
+                        inferenceEntityId,
+                        configResponse.buildFailureMessage()
+                    )
+                );
+            } else {
+                logger.warn("Failed to update inference endpoint [{}], successfully rolled back to previous state", inferenceEntityId);
+                finalListener.onResponse(false);
+            }
+        }).addListener(ActionListener.wrap(ignored -> {
+            // every non-exceptional outcome has already been reported to finalListener by the steps above
+        }, e -> {
+            // a bulk request failed with an exception, or a step threw: without this terminal listener the
+            // failure would never reach the caller and the endpoint would stay locked
+            preventDeletionLock.remove(inferenceEntityId);
+            finalListener.onFailure(e);
+        }));
+
+    }
+
+ /** + * Note: storeModel does not overwrite existing models and thus does not need to check the lock + */ public void storeModel(Model model, ActionListener listener) { + ActionListener bulkResponseActionListener = getStoreModelListener(model, listener); IndexRequest configRequest = createIndexRequest( @@ -405,6 +545,16 @@ private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkRes } public void deleteModel(String inferenceEntityId, ActionListener listener) { + if (preventDeletionLock.contains(inferenceEntityId)) { + listener.onFailure( + new ElasticsearchStatusException( + "Model is currently being updated, you may delete the model once the update completes", + RestStatus.CONFLICT + ) + ); + return; + } + DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN);
request.setQuery(documentIdQuery(inferenceEntityId)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index 9f64b58e48b55..2dec72e6692a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -14,6 +14,12 @@ public final class Paths { static final String INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}"; static final String TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/{" + INFERENCE_ID + "}"; static final String INFERENCE_DIAGNOSTICS_PATH = "_inference/.diagnostics"; + static final String TASK_TYPE_INFERENCE_ID_UPDATE_PATH = "_inference/{" + + TASK_TYPE_OR_INFERENCE_ID + + "}/{" + + INFERENCE_ID + + "}/_update"; + static final String INFERENCE_ID_UPDATE_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_update"; static final String STREAM_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_stream"; static final String STREAM_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUpdateInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUpdateInferenceModelAction.java new file mode 100644 index 0000000000000..9405a6752538c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUpdateInferenceModelAction.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.rest;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.rest.RestUtils;
+import org.elasticsearch.rest.Scope;
+import org.elasticsearch.rest.ServerlessScope;
+import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
+
+import java.util.List;
+
+import static org.elasticsearch.rest.RestRequest.Method.PUT;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_UPDATE_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_UPDATE_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID;
+
+/**
+ * REST handler for the PUT {@code _update} routes of the inference API. It turns the path parameters
+ * and request body into an {@link UpdateInferenceModelAction.Request} and executes it.
+ */
+@ServerlessScope(Scope.PUBLIC)
+public class RestUpdateInferenceModelAction extends BaseRestHandler {
+    @Override
+    public String getName() {
+        return "update_inference_model_action";
+    }
+
+    @Override
+    public List routes() {
+        return List.of(new Route(PUT, INFERENCE_ID_UPDATE_PATH), new Route(PUT, TASK_TYPE_INFERENCE_ID_UPDATE_PATH));
+    }
+
+    /**
+     * Builds the update request from the path and body.
+     *
+     * @throws ElasticsearchStatusException with BAD_REQUEST when the INFERENCE_ID path parameter is absent
+     */
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) {
+        // NOTE(review): the INFERENCE_ID_UPDATE_PATH route has no INFERENCE_ID parameter, so requests on it
+        // always end up here - confirm that rejecting them (and this wording) is intended
+        if (restRequest.hasParam(INFERENCE_ID) == false) {
+            throw new ElasticsearchStatusException("Inference ID must be provided in the path", RestStatus.BAD_REQUEST);
+        }
+
+        // with both path segments present, the first one is the task type and the second the endpoint id
+        final String inferenceEntityId = restRequest.param(INFERENCE_ID);
+        final TaskType taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID));
+
+        var updateRequest = new UpdateInferenceModelAction.Request(
+            inferenceEntityId,
+            restRequest.requiredContent(),
+            restRequest.getXContentType(),
+            taskType,
+            RestUtils.getMasterNodeTimeout(restRequest)
+        );
+        return channel -> client.execute(UpdateInferenceModelAction.INSTANCE, updateRequest, new RestToXContentListener<>(channel));
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 32c1d17373e53..c0e3c78b12f13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -625,6 +625,40 @@ public static String mustBeAPositiveLongErrorMessage(String settingName, String return format("[%s] Invalid value [%s]. [%s] must be a positive long", scope, value, settingName); } + /** + * task_type can be specified as either a URL parameter or in the + * request body. Resolve which to use or throw if the settings are + * inconsistent + * @param urlTaskType Taken from the URL parameter. ANY means not specified. + * @param bodyTaskType Taken from the request body.
Maybe null
+     * @return The resolved task type
+     * @throws ElasticsearchStatusException with BAD_REQUEST when neither source names a concrete task type,
+     *         when the body names ANY, or when the body and the URL name different task types
+     */
+    public static TaskType resolveTaskType(TaskType urlTaskType, String bodyTaskType) {
+        if (bodyTaskType != null) {
+            TaskType taskTypeFromBody = TaskType.fromStringOrStatusException(bodyTaskType);
+
+            // the body must name a concrete task type
+            if (taskTypeFromBody == TaskType.ANY) {
+                throw new ElasticsearchStatusException("task_type [any] is not valid type for inference", RestStatus.BAD_REQUEST);
+            }
+
+            // the URL value is acceptable only if isAnyOrSame accepts it against the body value
+            if (taskTypeFromBody.isAnyOrSame(urlTaskType) == false) {
+                throw new ElasticsearchStatusException(
+                    "Cannot resolve conflicting task_type parameter in the request URL [{}] and the request body [{}]",
+                    RestStatus.BAD_REQUEST,
+                    urlTaskType.toString(),
+                    bodyTaskType
+                );
+            }
+
+            return taskTypeFromBody;
+        }
+
+        // nothing in the body: the URL is the only source and must name a concrete task type
+        if (urlTaskType == TaskType.ANY) {
+            throw new ElasticsearchStatusException("model is missing required setting [task_type]", RestStatus.BAD_REQUEST);
+        }
+        return urlTaskType;
+    }
+
/** * Functional interface for creating an enum from a string. * @param diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java index 63f82a8eceb98..05b5873a81d8d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java @@ -139,4 +139,12 @@ public int hashCode() { public Map getParameters() { return parameters; } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AlibabaCloudSearchCompletionTaskSettings updatedSettings = AlibabaCloudSearchCompletionTaskSettings.fromMap( + new HashMap<>(newSettings) + );
+ return of(this, updatedSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java index c908c219e4053..9a431717d9fb9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.EnumSet; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -174,4 +175,10 @@ public int hashCode() { public static String invalidInputTypeMessage(InputType inputType) { return Strings.format("received invalid input type value [%s]", inputType.toString()); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AlibabaCloudSearchEmbeddingsTaskSettings newSettingsOnly = fromMap(new HashMap<>(newSettings)); + return of(this, newSettingsOnly, newSettingsOnly.inputType != null ? 
newSettingsOnly.inputType : this.getInputType()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java index 97e7ecd41223d..40c3dee00d6c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java @@ -102,4 +102,10 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AlibabaCloudSearchRerankTaskSettings updatedSettings = new AlibabaCloudSearchRerankTaskSettings(); + return of(this, updatedSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java index 873cdf31fbe9d..0f4ebce920167 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.EnumSet; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -188,4 +189,10 @@ public int hashCode() { public static String invalidInputTypeMessage(InputType inputType) { return Strings.format("received 
invalid input type value [%s]", inputType.toString()); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AlibabaCloudSearchSparseTaskSettings updatedSettings = fromMap(new HashMap<>(newSettings)); + return of(this, updatedSettings, updatedSettings.getInputType() != null ? updatedSettings.getInputType() : this.inputType); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java index 9e6328ce1c358..30a7dc9ad5a2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -107,4 +108,9 @@ public boolean equals(Object object) { public int hashCode() { return Objects.hash(accessKey, secretKey); } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return fromMap(new HashMap<>(newSecrets)); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java index 13787ed8cb6a4..c3db1465863e4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java @@ -17,6 +17,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -192,4 +193,12 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(temperature, topP, topK, maxNewTokens); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AmazonBedrockChatCompletionRequestTaskSettings requestSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return of(this, requestSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettings.java index bb2c027127371..e8a6ca638c916 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettings.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -59,6 +60,11 @@ private static AnthropicChatCompletionTaskSettings fromPersistedMap(Map newSettings) { + return fromRequestMap(new HashMap<>(newSettings)); + } + private record CommonFields(int maxTokens, Double temperature, Double topP, Integer topK) {} private static CommonFields fromMap(Map map, ValidationException validationException) { diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettings.java index b8e33bac410fe..544c52f59a3c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettings.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -178,6 +179,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + @Override + public String toString() { + return "AzureAiStudioChatCompletionTaskSettings{" + + "temperature=" + + temperature + + ", topP=" + + topP + + ", doSample=" + + doSample + + ", maxNewTokens=" + + maxNewTokens + + '}'; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -194,4 +209,11 @@ public int hashCode() { return Objects.hash(temperature, topP, doSample, maxNewTokens); } + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AzureAiStudioChatCompletionRequestTaskSettings requestSettings = AzureAiStudioChatCompletionRequestTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return of(this, requestSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettings.java index bdb6ae74e5ab3..340ee95cd7b0c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettings.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -111,4 +112,12 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hashCode(user); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AzureAiStudioEmbeddingsRequestTaskSettings requestSettings = AzureAiStudioEmbeddingsRequestTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return AzureAiStudioEmbeddingsTaskSettings.of(this, requestSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java index 06217e8079b06..a2bd4f6175989 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java @@ -19,6 +19,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -125,4 +126,9 @@ public boolean equals(Object object) { public int hashCode() { return Objects.hash(entraId, apiKey); } 
+ + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return AzureOpenAiSecretSettings.fromMap(new HashMap<>(newSecrets)); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java index de0a0897a93c5..3008a543b8fea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -107,4 +108,12 @@ public boolean equals(Object object) { public int hashCode() { return Objects.hash(user); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AzureOpenAiCompletionRequestTaskSettings updatedSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return of(this, updatedSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java index 28ccade0a06b0..4157d7748d789 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java @@ -18,6 +18,7 @@ 
import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -116,4 +117,12 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(user); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AzureOpenAiEmbeddingsRequestTaskSettings requestSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return of(this, requestSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index 34d37d0003adf..b789d1578290a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.EnumSet; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -204,4 +205,10 @@ public int hashCode() { public static String invalidInputTypeMessage(InputType inputType) { return Strings.format("received invalid input type value [%s]", inputType.toString()); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + CohereEmbeddingsTaskSettings updatedSettings = CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(newSettings)); + return of(this, updatedSettings, updatedSettings.inputType != null ? 
updatedSettings.inputType : this.inputType); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankTaskSettings.java index f5893c825efcf..479000f840502 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankTaskSettings.java @@ -20,6 +20,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -186,4 +187,9 @@ public Integer getMaxChunksPerDoc() { return maxChunksPerDoc; } + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + CohereRerankTaskSettings updatedSettings = CohereRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return CohereRerankTaskSettings.of(this, updatedSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java index 70d787152121f..a0be1661b860d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java @@ -17,6 +17,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -87,7 +88,11 @@ public CustomElandRerankTaskSettings(StreamInput in) throws IOException { } public 
CustomElandRerankTaskSettings(@Nullable Boolean doReturnDocuments) { - this.returnDocuments = doReturnDocuments; + if (doReturnDocuments == null) { + this.returnDocuments = true; + } else { + this.returnDocuments = doReturnDocuments; + } } @Override @@ -136,4 +141,10 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(returnDocuments); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + CustomElandRerankTaskSettings updatedSettings = CustomElandRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return of(this, updatedSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index f8b5837ef387e..37e0f28dfb3fe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -122,6 +122,18 @@ protected ElasticsearchInternalServiceSettings(ElasticsearchInternalServiceSetti this.adaptiveAllocationsSettings = other.adaptiveAllocationsSettings; } + /** + * Copy constructor with the ability to set the number of allocations. Used for Update API. 
+ * @param other the existing settings + * @param numAllocations the new number of allocations + */ + public ElasticsearchInternalServiceSettings(ElasticsearchInternalServiceSettings other, int numAllocations) { + this.numAllocations = numAllocations; + this.numThreads = other.numThreads; + this.modelId = other.modelId; + this.adaptiveAllocationsSettings = other.adaptiveAllocationsSettings; + } + public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { this.numAllocations = in.readOptionalVInt(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java index 33696231668a5..3bcaa57827fdb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java @@ -15,6 +15,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.Map; import java.util.Objects; public class ElserMlNodeTaskSettings implements TaskSettings { @@ -65,4 +66,9 @@ public int hashCode() { // Return the hash of NAME to make the serialization tests pass return Objects.hash(NAME); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + return DEFAULT; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java index 57c8d61f9f9a5..20dbadb9b3eae 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java @@ -19,6 +19,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -30,7 +31,7 @@ public class GoogleVertexAiSecretSettings implements SecretSettings { public static final String SERVICE_ACCOUNT_JSON = "service_account_json"; - private final SecureString serviceAccountJson; + final SecureString serviceAccountJson; public static GoogleVertexAiSecretSettings fromMap(@Nullable Map map) { if (map == null) { @@ -101,4 +102,9 @@ public boolean equals(Object object) { public int hashCode() { return Objects.hash(serviceAccountJson); } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return GoogleVertexAiSecretSettings.fromMap(new HashMap<>(newSecrets)); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettings.java index 5e0185a7abb36..b7242100178a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettings.java @@ -17,6 +17,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -107,4 +108,12 @@ public boolean equals(Object object) { public int hashCode() { return 
Objects.hash(autoTruncate); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + GoogleVertexAiEmbeddingsRequestTaskSettings requestSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return of(this, requestSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankTaskSettings.java index 8256eed7a5cba..64bec7e6cfeef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankTaskSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -107,4 +108,12 @@ public boolean equals(Object object) { public int hashCode() { return Objects.hash(topN); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + GoogleVertexAiRerankRequestTaskSettings requestSettings = GoogleVertexAiRerankRequestTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return of(this, requestSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java index 3c2586fb5a264..44064f61f5180 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -107,4 +108,12 @@ public boolean equals(Object object) { public int hashCode() { return Objects.hash(user); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + OpenAiChatCompletionRequestTaskSettings updatedSettings = OpenAiChatCompletionRequestTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return of(this, updatedSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java index c7cc60043ef47..64f852822703c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -127,4 +128,10 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(user); } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + OpenAiEmbeddingsRequestTaskSettings requestSettings = OpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(newSettings)); + return of(this, requestSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java index 6affa998c089d..c68d4bc801724 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java @@ -19,6 +19,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -78,4 +79,9 @@ public TransportVersion getMinimalSupportedVersion() { public void writeTo(StreamOutput out) throws IOException { out.writeSecureString(apiKey); } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return fromMap(new HashMap<>(newSecrets)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptySecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptySecretSettingsTests.java index b50ea9e5ee224..d27a326d5fa1e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptySecretSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptySecretSettingsTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.test.AbstractWireSerializingTestCase; @@ -32,4 +33,13 @@ protected EmptySecretSettings mutateInstance(EmptySecretSettings instance) { // All instances are the same and have no fields, nothing to mutate return null; } + + public void testNewSecretSettings() { + + EmptySecretSettings newSecretSettings = (EmptySecretSettings) EmptySecretSettings.INSTANCE.newSecretSettings( + randomMap(0, 3, () -> new 
Tuple<>(randomAlphaOfLengthBetween(1, 10), randomAlphaOfLengthBetween(1, 10))) + ); + + assertSame(EmptySecretSettings.INSTANCE, newSecretSettings); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptyTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptyTaskSettingsTests.java index 060dc23b935cc..7bc0cc57e31ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptyTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptyTaskSettingsTests.java @@ -11,12 +11,20 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.util.Map; + public class EmptyTaskSettingsTests extends AbstractWireSerializingTestCase { public static EmptyTaskSettings createRandom() { return EmptyTaskSettings.INSTANCE; // no options to randomise } + public void testUpdatedTaskSettings() { + EmptyTaskSettings initialSettings = createRandom(); + EmptyTaskSettings updatedSettings = (EmptyTaskSettings) initialSettings.updatedTaskSettings(Map.of()); + assertEquals(EmptyTaskSettings.INSTANCE, updatedSettings); + } + @Override protected Writeable.Reader instanceReader() { return EmptyTaskSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelSecretsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelSecretsTests.java index d6d139190c12c..ea2f41bf5c6cf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelSecretsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelSecretsTests.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.List; +import java.util.Map; public class ModelSecretsTests extends AbstractWireSerializingTestCase { @@ -83,5 +84,10 @@ public String getWriteableName() { 
public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_11_X; } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return new FakeSecretSettings(newSecrets.get(API_KEY).toString()); + } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelActionTests.java index 27e56c1bd973d..991c5a581eb35 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelActionTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import static org.hamcrest.Matchers.containsString; @@ -17,27 +18,18 @@ public class TransportPutInferenceModelActionTests extends ESTestCase { public void testResolveTaskType() { - assertEquals(TaskType.SPARSE_EMBEDDING, TransportPutInferenceModelAction.resolveTaskType(TaskType.SPARSE_EMBEDDING, null)); - assertEquals( - TaskType.SPARSE_EMBEDDING, - TransportPutInferenceModelAction.resolveTaskType(TaskType.ANY, TaskType.SPARSE_EMBEDDING.toString()) - ); + assertEquals(TaskType.SPARSE_EMBEDDING, ServiceUtils.resolveTaskType(TaskType.SPARSE_EMBEDDING, null)); + assertEquals(TaskType.SPARSE_EMBEDDING, ServiceUtils.resolveTaskType(TaskType.ANY, TaskType.SPARSE_EMBEDDING.toString())); - var e = expectThrows( - ElasticsearchStatusException.class, - () -> TransportPutInferenceModelAction.resolveTaskType(TaskType.ANY, null) - ); + var e = expectThrows(ElasticsearchStatusException.class, () -> ServiceUtils.resolveTaskType(TaskType.ANY, null)); 
assertThat(e.getMessage(), containsString("model is missing required setting [task_type]")); - e = expectThrows( - ElasticsearchStatusException.class, - () -> TransportPutInferenceModelAction.resolveTaskType(TaskType.ANY, TaskType.ANY.toString()) - ); + e = expectThrows(ElasticsearchStatusException.class, () -> ServiceUtils.resolveTaskType(TaskType.ANY, TaskType.ANY.toString())); assertThat(e.getMessage(), containsString("task_type [any] is not valid type for inference")); e = expectThrows( ElasticsearchStatusException.class, - () -> TransportPutInferenceModelAction.resolveTaskType(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING.toString()) + () -> ServiceUtils.resolveTaskType(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING.toString()) ); assertThat( e.getMessage(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index d8c25fb5a6d88..779a98e023455 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; @@ -217,6 +218,11 @@ public String getWriteableName() { public TransportVersion getMinimalSupportedVersion() { return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + return TestTaskSettings.fromMap(new HashMap<>(newSettings)); + } } public record TestSecretSettings(String apiKey) implements SecretSettings { @@ -265,5 +271,10 @@ public String getWriteableName() { public TransportVersion getMinimalSupportedVersion() { 
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return new TestSecretSettings(newSecrets.get("api_key").toString()); + } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettingsTests.java index c48d57cf3e03b..7acba78b3066b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettingsTests.java @@ -7,16 +7,17 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.hamcrest.MatcherAssert; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings.PARAMETERS; import static org.hamcrest.Matchers.is; public class AlibabaCloudSearchCompletionTaskSettingsTests extends AbstractWireSerializingTestCase< @@ -34,10 +35,20 @@ public void testFromMap() { ); } - public void testIsEmpty() { - var randomSettings = createRandom(); - var stringRep = Strings.toString(randomSettings); - assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); + public void testUpdatedTaskSettings() { + var 
initialSettings = createRandom(); + var newSettings = createRandom(); + Map newSettingsMap = new HashMap<>(); + if (newSettings.getParameters() != null) { + newSettingsMap.put(PARAMETERS, newSettings.getParameters()); + } + AlibabaCloudSearchCompletionTaskSettings updatedSettings = (AlibabaCloudSearchCompletionTaskSettings) initialSettings + .updatedTaskSettings(Collections.unmodifiableMap(newSettingsMap)); + if (newSettings.getParameters() == null) { + assertEquals(initialSettings.getParameters(), updatedSettings.getParameters()); + } else { + assertEquals(newSettings.getParameters(), updatedSettings.getParameters()); + } } @Override @@ -60,7 +71,7 @@ public static Map getTaskSettingsMap(@Nullable Map(); if (params != null) { - map.put(AlibabaCloudSearchCompletionTaskSettings.PARAMETERS, params); + map.put(PARAMETERS, params); } return map; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java index 9e75a2f475051..4b558949fdc4a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java @@ -15,10 +15,12 @@ import org.hamcrest.MatcherAssert; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch; +import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings.INPUT_TYPE; import static 
org.hamcrest.Matchers.is; public class AlibabaCloudSearchEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase< @@ -31,13 +33,27 @@ public static AlibabaCloudSearchEmbeddingsTaskSettings createRandom() { public void testFromMap() { MatcherAssert.assertThat( - AlibabaCloudSearchEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(AlibabaCloudSearchEmbeddingsTaskSettings.INPUT_TYPE, "ingest")) - ), + AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(INPUT_TYPE, "ingest"))), is(new AlibabaCloudSearchEmbeddingsTaskSettings(InputType.INGEST)) ); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + Map newSettingsMap = new HashMap<>(); + if (newSettings.getInputType() != null) { + newSettingsMap.put(INPUT_TYPE, newSettings.getInputType().toString()); + } + AlibabaCloudSearchEmbeddingsTaskSettings updatedSettings = (AlibabaCloudSearchEmbeddingsTaskSettings) initialSettings + .updatedTaskSettings(Collections.unmodifiableMap(newSettingsMap)); + if (newSettings.getInputType() == null) { + assertEquals(initialSettings.getInputType(), updatedSettings.getInputType()); + } else { + assertEquals(newSettings.getInputType(), updatedSettings.getInputType()); + } + } + public void testFromMap_WhenInputTypeIsNull() { InputType inputType = null; MatcherAssert.assertThat( @@ -72,7 +88,7 @@ public static Map getTaskSettingsMap(@Nullable InputType inputTy var map = new HashMap(); if (inputType != null) { - map.put(AlibabaCloudSearchEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString()); + map.put(INPUT_TYPE, inputType.toString()); } return map; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java index 
2c134c6765078..fa78b24d1a4bb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java @@ -19,6 +19,8 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch; +import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings.INPUT_TYPE; +import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings.RETURN_TOKEN; import static org.hamcrest.Matchers.is; public class AlibabaCloudSearchSparseTaskSettingsTests extends AbstractWireSerializingTestCase { @@ -31,11 +33,33 @@ public static AlibabaCloudSearchSparseTaskSettings createRandom() { public void testFromMap() { MatcherAssert.assertThat( - AlibabaCloudSearchSparseTaskSettings.fromMap(new HashMap<>(Map.of(AlibabaCloudSearchSparseTaskSettings.INPUT_TYPE, "ingest"))), + AlibabaCloudSearchSparseTaskSettings.fromMap(new HashMap<>(Map.of(INPUT_TYPE, "ingest"))), is(new AlibabaCloudSearchSparseTaskSettings(InputType.INGEST, null)) ); } + public void testUpdatedTaskSettings() { + { + var initialSettings = createRandom(); + var newSettings = createRandom(); + AlibabaCloudSearchSparseTaskSettings updatedSettings = (AlibabaCloudSearchSparseTaskSettings) initialSettings + .updatedTaskSettings(Map.of(RETURN_TOKEN, newSettings.isReturnToken())); + } + { + var initialSettings = createRandom(); + var newSettings = createRandom(); + AlibabaCloudSearchSparseTaskSettings updatedSettings = (AlibabaCloudSearchSparseTaskSettings) initialSettings + .updatedTaskSettings( + Map.of( + INPUT_TYPE, + newSettings.getInputType() == null ? 
InputType.SEARCH.toString() : newSettings.getInputType().toString(), + RETURN_TOKEN, + newSettings.isReturnToken() + ) + ); + } + } + public void testIsEmpty() { var randomSettings = createRandom(); var stringRep = Strings.toString(randomSettings); @@ -69,11 +93,11 @@ public static Map getTaskSettingsMap(@Nullable InputType inputTy var map = new HashMap(); if (inputType != null) { - map.put(AlibabaCloudSearchSparseTaskSettings.INPUT_TYPE, inputType.toString()); + map.put(INPUT_TYPE, inputType.toString()); } if (returnToken != null) { - map.put(AlibabaCloudSearchSparseTaskSettings.RETURN_TOKEN, returnToken); + map.put(RETURN_TOKEN, returnToken); } return map; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java index 904851842a6c8..88aebd2d9d42b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java @@ -29,6 +29,17 @@ public class AmazonBedrockSecretSettingsTests extends AbstractBWCWireSerializationTestCase { + public void testNewSecretSettings() { + AmazonBedrockSecretSettings initialSettings = createRandom(); + AmazonBedrockSecretSettings newSettings = createRandom(); + + AmazonBedrockSecretSettings finalSettings = (AmazonBedrockSecretSettings) initialSettings.newSecretSettings( + Map.of(ACCESS_KEY_FIELD, newSettings.accessKey.toString(), SECRET_KEY_FIELD, newSettings.secretKey.toString()) + ); + + assertEquals(newSettings, finalSettings); + } + public void testIt_CreatesSettings_ReturnsNullFromMap_null() { var secrets = AmazonBedrockSecretSettings.fromMap(null); assertNull(secrets); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java index 69dd3b1e6257b..adbf2c66ca0e2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java @@ -38,6 +38,68 @@ public void testIsEmpty() { assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + var initialSettings = createRandom(); + AmazonBedrockChatCompletionTaskSettings updatedSettings = (AmazonBedrockChatCompletionTaskSettings) initialSettings + .updatedTaskSettings(Map.of()); + assertEquals(initialSettings, updatedSettings); + } + + public void testUpdatedTaskSettings_WithNewTemperature_ReturnsUpdatedSettings() { + var initialSettings = createRandom(); + Map newSettings = Map.of(TEMPERATURE_FIELD, 0.7); + AmazonBedrockChatCompletionTaskSettings updatedSettings = (AmazonBedrockChatCompletionTaskSettings) initialSettings + .updatedTaskSettings(newSettings); + assertEquals(0.7, (double) updatedSettings.temperature(), 0.001); + assertEquals(initialSettings.topP(), updatedSettings.topP()); + assertEquals(initialSettings.topK(), updatedSettings.topK()); + assertEquals(initialSettings.maxNewTokens(), updatedSettings.maxNewTokens()); + } + + public void testUpdatedTaskSettings_WithNewTopP_ReturnsUpdatedSettings() { + var initialSettings = createRandom(); + Map newSettings = Map.of(TOP_P_FIELD, 0.8); + AmazonBedrockChatCompletionTaskSettings updatedSettings = (AmazonBedrockChatCompletionTaskSettings)
initialSettings + .updatedTaskSettings(newSettings); + assertEquals(0.8, (double) updatedSettings.topP(), 0.001); + assertEquals(initialSettings.temperature(), updatedSettings.temperature()); + assertEquals(initialSettings.topK(), updatedSettings.topK()); + assertEquals(initialSettings.maxNewTokens(), updatedSettings.maxNewTokens()); + } + + public void testUpdatedTaskSettings_WithNewTopK_ReturnsUpdatedSettings() { + var initialSettings = createRandom(); + Map newSettings = Map.of(TOP_K_FIELD, 0.9); + AmazonBedrockChatCompletionTaskSettings updatedSettings = (AmazonBedrockChatCompletionTaskSettings) initialSettings + .updatedTaskSettings(newSettings); + assertEquals(0.9, (double) updatedSettings.topK(), 0.001); + assertEquals(initialSettings.temperature(), updatedSettings.temperature()); + assertEquals(initialSettings.topP(), updatedSettings.topP()); + assertEquals(initialSettings.maxNewTokens(), updatedSettings.maxNewTokens()); + } + + public void testUpdatedTaskSettings_WithNewMaxNewTokens_ReturnsUpdatedSettings() { + var initialSettings = createRandom(); + Map newSettings = Map.of(MAX_NEW_TOKENS_FIELD, 256); + AmazonBedrockChatCompletionTaskSettings updatedSettings = (AmazonBedrockChatCompletionTaskSettings) initialSettings + .updatedTaskSettings(newSettings); + assertEquals(256, (int) updatedSettings.maxNewTokens()); + assertEquals(initialSettings.temperature(), updatedSettings.temperature()); + assertEquals(initialSettings.topP(), updatedSettings.topP()); + assertEquals(initialSettings.topK(), updatedSettings.topK()); + } + + public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { + var initialSettings = createRandom(); + Map newSettings = Map.of(TEMPERATURE_FIELD, 0.7, TOP_P_FIELD, 0.8, TOP_K_FIELD, 0.9, MAX_NEW_TOKENS_FIELD, 256); + AmazonBedrockChatCompletionTaskSettings updatedSettings = (AmazonBedrockChatCompletionTaskSettings) initialSettings + .updatedTaskSettings(newSettings); + assertEquals(0.7, (double)
updatedSettings.temperature(), 0.001); + assertEquals(0.8, (double) updatedSettings.topP(), 0.001); + assertEquals(0.9, (double) updatedSettings.topK(), 0.001); + assertEquals(256, (int) updatedSettings.maxNewTokens()); + } + public void testFromMap_AllValues() { var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); assertEquals( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettingsTests.java index e00de80e8709e..5f6823770345f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettingsTests.java @@ -24,6 +24,25 @@ public class AnthropicChatCompletionTaskSettingsTests extends AbstractBWCWireSerializationTestCase { + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + AnthropicChatCompletionTaskSettings updatedSettings = (AnthropicChatCompletionTaskSettings) initialSettings.updatedTaskSettings( + Map.of( + AnthropicServiceFields.MAX_TOKENS, + newSettings.maxTokens(), + AnthropicServiceFields.TEMPERATURE_FIELD, + newSettings.temperature(), + AnthropicServiceFields.TOP_P_FIELD, + newSettings.topP(), + AnthropicServiceFields.TOP_K_FIELD, + newSettings.topK() + ) + ); + + assertEquals(newSettings, updatedSettings); + } + public static Map getChatCompletionTaskSettingsMap( @Nullable Integer maxTokens, @Nullable Double temperature, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettingsTests.java
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettingsTests.java index 8d7dcf1ef5170..21c1a233348fe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettingsTests.java @@ -19,6 +19,7 @@ import org.hamcrest.MatcherAssert; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -38,6 +39,30 @@ public void testIsEmpty() { assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + var settingsMap = new HashMap(); + if (newSettings.doSample() != null) settingsMap.put(DO_SAMPLE_FIELD, newSettings.doSample()); + if (newSettings.temperature() != null) settingsMap.put(TEMPERATURE_FIELD, newSettings.temperature()); + if (newSettings.topP() != null) settingsMap.put(TOP_P_FIELD, newSettings.topP()); + if (newSettings.maxNewTokens() != null) settingsMap.put(MAX_NEW_TOKENS_FIELD, newSettings.maxNewTokens()); + + AzureAiStudioChatCompletionTaskSettings updatedSettings = (AzureAiStudioChatCompletionTaskSettings) initialSettings + .updatedTaskSettings(Collections.unmodifiableMap(settingsMap)); + + assertEquals( + newSettings.temperature() == null ? initialSettings.temperature() : newSettings.temperature(), + updatedSettings.temperature() + ); + assertEquals(newSettings.topP() == null ? initialSettings.topP() : newSettings.topP(), updatedSettings.topP()); + assertEquals(newSettings.doSample() == null ? initialSettings.doSample() : newSettings.doSample(), updatedSettings.doSample()); + assertEquals( + newSettings.maxNewTokens() == null ? 
initialSettings.maxNewTokens() : newSettings.maxNewTokens(), + updatedSettings.maxNewTokens() + ); + } + public void testFromMap_AllValues() { var taskMap = getTaskSettingsMap(1.0, 2.0, true, 512); assertEquals( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettingsTests.java index 4b6b38bd15c0d..cdfde5fcb09c9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettingsTests.java @@ -20,6 +20,7 @@ import org.hamcrest.MatcherAssert; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -32,6 +33,23 @@ public void testIsEmpty() { assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + Map newSettingsMap = new HashMap<>(); + if (newSettings.user() != null) { + newSettingsMap.put(AzureAiStudioConstants.USER_FIELD, newSettings.user()); + } + AzureAiStudioEmbeddingsTaskSettings updatedSettings = (AzureAiStudioEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + if (newSettings.user() == null) { + assertEquals(initialSettings.user(), updatedSettings.user()); + } else { + assertEquals(newSettings.user(), updatedSettings.user()); + } + } + public void testFromMap_WithUser() { assertEquals( new AzureAiStudioEmbeddingsTaskSettings("user"), diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java index e08365e7ca3bf..dbbf90054a55b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java @@ -30,7 +30,31 @@ public class AzureOpenAiSecretSettingsTests extends AbstractBWCWireSerializationTestCase { public static AzureOpenAiSecretSettings createRandom() { - return new AzureOpenAiSecretSettings(randomSecureStringOfLength(15), randomSecureStringOfLength(15)); + boolean isApiKeyNotEntraId = randomBoolean(); + return new AzureOpenAiSecretSettings( + isApiKeyNotEntraId ? randomSecureStringOfLength(15) : null, + isApiKeyNotEntraId == false ? 
randomSecureStringOfLength(15) : null + ); + } + + public void testNewSecretSettingsApiKey() { + AzureOpenAiSecretSettings initialSettings = createRandom(); + AzureOpenAiSecretSettings newSettings = new AzureOpenAiSecretSettings(randomSecureStringOfLength(15), null); + AzureOpenAiSecretSettings finalSettings = (AzureOpenAiSecretSettings) initialSettings.newSecretSettings( + Map.of(API_KEY, newSettings.apiKey().toString()) + ); + + assertEquals(newSettings, finalSettings); + } + + public void testNewSecretSettingsEntraId() { + AzureOpenAiSecretSettings initialSettings = createRandom(); + AzureOpenAiSecretSettings newSettings = new AzureOpenAiSecretSettings(null, randomSecureStringOfLength(15)); + AzureOpenAiSecretSettings finalSettings = (AzureOpenAiSecretSettings) initialSettings.newSecretSettings( + Map.of(ENTRA_ID, newSettings.entraId().toString()) + ); + + assertEquals(newSettings, finalSettings); } public void testFromMap_ApiKey_Only() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java index 8e8d9c4f92800..9d77abfe6d512 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java @@ -38,6 +38,16 @@ public void testIsEmpty() { assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + AzureOpenAiCompletionTaskSettings updatedSettings = (AzureOpenAiCompletionTaskSettings) initialSettings.updatedTaskSettings( + 
newSettings.user() == null ? Map.of() : Map.of(AzureOpenAiServiceFields.USER, newSettings.user()) + ); + + assertEquals(newSettings.user() == null ? initialSettings.user() : newSettings.user(), updatedSettings.user()); + } + public void testFromMap_WithUser() { var user = "user"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java index 72a063af37b90..4df9f2f6bcce0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java @@ -12,13 +12,13 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; import org.hamcrest.MatcherAssert; import java.io.IOException; import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; import static org.hamcrest.Matchers.is; public class AzureOpenAiEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { @@ -41,17 +41,31 @@ public static AzureOpenAiEmbeddingsTaskSettings createRandom() { return new AzureOpenAiEmbeddingsTaskSettings(user); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + AzureOpenAiEmbeddingsTaskSettings updatedSettings = (AzureOpenAiEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + newSettings.user() == null ? 
Map.of() : Map.of(USER, newSettings.user()) + ); + + if (newSettings.user() == null) { + assertEquals(initialSettings.user(), updatedSettings.user()); + } else { + assertEquals(newSettings.user(), updatedSettings.user()); + } + } + public void testFromMap_WithUser() { assertEquals( new AzureOpenAiEmbeddingsTaskSettings("user"), - AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))) + AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(USER, "user"))) ); } public void testFromMap_UserIsEmptyString() { var thrownException = expectThrows( ValidationException.class, - () -> AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, ""))) + () -> AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(USER, ""))) ); MatcherAssert.assertThat( @@ -66,7 +80,7 @@ public void testFromMap_MissingUser_DoesNotThrowException() { } public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { - var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); + var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(USER, "user"))); var overriddenTaskSettings = AzureOpenAiEmbeddingsTaskSettings.of( taskSettings, @@ -76,11 +90,9 @@ public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { } public void testOverrideWith_UsesOverriddenSettings() { - var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); + var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(USER, "user"))); - var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap( - new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user2")) - ); + var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(USER, "user2"))); var overriddenTaskSettings = 
AzureOpenAiEmbeddingsTaskSettings.of(taskSettings, requestTaskSettings); MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureOpenAiEmbeddingsTaskSettings("user2"))); @@ -105,7 +117,7 @@ public static Map getAzureOpenAiRequestTaskSettingsMap(@Nullable var map = new HashMap(); if (user != null) { - map.put(AzureOpenAiServiceFields.USER, user); + map.put(USER, user); } return map; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java index 90c9b032465c6..3df8fcaf5d6b8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; import java.util.Locale; @@ -43,6 +44,31 @@ public void testIsEmpty() { assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + Map newSettingsMap = new HashMap<>(); + if (newSettings.getInputType() != null) { + newSettingsMap.put(CohereEmbeddingsTaskSettings.INPUT_TYPE, newSettings.getInputType().toString()); + } + if (newSettings.getTruncation() != null) { + newSettingsMap.put(CohereServiceFields.TRUNCATE, newSettings.getTruncation().toString()); + } + CohereEmbeddingsTaskSettings updatedSettings = (CohereEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + if (newSettings.getInputType() == null) { + 
assertEquals(initialSettings.getInputType(), updatedSettings.getInputType()); + } else { + assertEquals(newSettings.getInputType(), updatedSettings.getInputType()); + } + if (newSettings.getTruncation() == null) { + assertEquals(initialSettings.getTruncation(), updatedSettings.getTruncation()); + } else { + assertEquals(newSettings.getTruncation(), updatedSettings.getTruncation()); + } + } + public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() { MatcherAssert.assertThat( CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankTaskSettingsTests.java new file mode 100644 index 0000000000000..6924ee05ecbb8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankTaskSettingsTests.java @@ -0,0 +1,154 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.rerank; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class CohereRerankTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static CohereRerankTaskSettings createRandom() { + var returnDocuments = randomBoolean() ? randomBoolean() : null; + var topNDocsOnly = randomBoolean() ? 
randomIntBetween(1, 10) : null; + var maxChunksPerDoc = randomBoolean() ? randomIntBetween(1, 20) : null; + + return new CohereRerankTaskSettings(topNDocsOnly, returnDocuments, maxChunksPerDoc); + } + + public void testFromMap_WithValidValues_ReturnsSettings() { + Map taskMap = Map.of( + CohereRerankTaskSettings.RETURN_DOCUMENTS, + true, + CohereRerankTaskSettings.TOP_N_DOCS_ONLY, + 5, + CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, + 10 + ); + var settings = CohereRerankTaskSettings.fromMap(new HashMap<>(taskMap)); + assertTrue(settings.getReturnDocuments()); + assertEquals(5, settings.getTopNDocumentsOnly().intValue()); + assertEquals(10, settings.getMaxChunksPerDoc().intValue()); + } + + public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { + var settings = CohereRerankTaskSettings.fromMap(Map.of()); + assertNull(settings.getReturnDocuments()); + assertNull(settings.getTopNDocumentsOnly()); + assertNull(settings.getMaxChunksPerDoc()); + } + + public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { + Map taskMap = Map.of( + CohereRerankTaskSettings.RETURN_DOCUMENTS, + "invalid", + CohereRerankTaskSettings.TOP_N_DOCS_ONLY, + 5, + CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, + 10 + ); + var thrownException = expectThrows(ValidationException.class, () -> CohereRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type")); + } + + public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { + Map taskMap = Map.of( + CohereRerankTaskSettings.RETURN_DOCUMENTS, + true, + CohereRerankTaskSettings.TOP_N_DOCS_ONLY, + "invalid", + CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, + 10 + ); + var thrownException = expectThrows(ValidationException.class, () -> CohereRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type")); + 
} + + public void testFromMap_WithInvalidMaxChunksPerDoc_ThrowsValidationException() { + Map taskMap = Map.of( + CohereRerankTaskSettings.RETURN_DOCUMENTS, + true, + CohereRerankTaskSettings.TOP_N_DOCS_ONLY, + 5, + CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> CohereRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [max_chunks_per_doc] is not of the expected type")); + } + + public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + var initialSettings = new CohereRerankTaskSettings(5, true, 10); + CohereRerankTaskSettings updatedSettings = (CohereRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); + assertEquals(initialSettings, updatedSettings); + } + + public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { + var initialSettings = new CohereRerankTaskSettings(5, true, 10); + Map newSettings = Map.of(CohereRerankTaskSettings.RETURN_DOCUMENTS, false); + CohereRerankTaskSettings updatedSettings = (CohereRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly()); + assertEquals(initialSettings.getMaxChunksPerDoc(), updatedSettings.getMaxChunksPerDoc()); + } + + public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { + var initialSettings = new CohereRerankTaskSettings(5, true, 10); + Map newSettings = Map.of(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, 7); + CohereRerankTaskSettings updatedSettings = (CohereRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); +
assertEquals(initialSettings.getMaxChunksPerDoc(), updatedSettings.getMaxChunksPerDoc()); + } + + public void testUpdatedTaskSettings_WithNewMaxChunksPerDoc_ReturnsUpdatedSettings() { + var initialSettings = new CohereRerankTaskSettings(5, true, 10); + Map newSettings = Map.of(CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, 15); + CohereRerankTaskSettings updatedSettings = (CohereRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertEquals(15, updatedSettings.getMaxChunksPerDoc().intValue()); + assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); + assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly()); + } + + public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { + var initialSettings = new CohereRerankTaskSettings(5, true, 10); + Map newSettings = Map.of( + CohereRerankTaskSettings.RETURN_DOCUMENTS, + false, + CohereRerankTaskSettings.TOP_N_DOCS_ONLY, + 7, + CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, + 15 + ); + CohereRerankTaskSettings updatedSettings = (CohereRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + assertEquals(15, updatedSettings.getMaxChunksPerDoc().intValue()); + } + + @Override + protected Writeable.Reader instanceReader() { + return CohereRerankTaskSettings::new; + } + + @Override + protected CohereRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected CohereRerankTaskSettings mutateInstance(CohereRerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, CohereRerankTaskSettingsTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java index 72e6daa911c1d..4207896fc54f3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java @@ -15,7 +15,9 @@ import org.elasticsearch.xcontent.XContentType; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; +import java.util.Map; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; @@ -28,6 +30,23 @@ public void testIsEmpty() { assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + Map newSettingsMap = new HashMap<>(); + if (newSettings.returnDocuments() != null) { + newSettingsMap.put(CustomElandRerankTaskSettings.RETURN_DOCUMENTS, newSettings.returnDocuments()); + } + CustomElandRerankTaskSettings updatedSettings = (CustomElandRerankTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + if (newSettings.returnDocuments() == null) { + assertEquals(initialSettings.returnDocuments(), updatedSettings.returnDocuments()); + } else { + assertEquals(newSettings.returnDocuments(), updatedSettings.returnDocuments()); + } + } + public void testDefaultsFromMap_MapIsNull_ReturnsDefaultSettings() { var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(null); @@ -69,18 +88,6 @@ public void testToXContent_WritesAllValues() throws IOException { {"return_documents":true}""")); } - public void testToXContent_DoesNotWriteReturnDocuments_IfNull() throws IOException { - Boolean bool = null; - var serviceSettings = new 
CustomElandRerankTaskSettings(bool); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - serviceSettings.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {}""")); - } - public void testOf_PrefersNonNullRequestTaskSettings() { var originalSettings = new CustomElandRerankTaskSettings(Boolean.FALSE); var requestTaskSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); @@ -90,16 +97,6 @@ public void testOf_PrefersNonNullRequestTaskSettings() { assertThat(taskSettings, sameInstance(requestTaskSettings)); } - public void testOf_UseOriginalSettings_IfRequestSettingsValuesAreNull() { - Boolean bool = null; - var originalSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); - var requestTaskSettings = new CustomElandRerankTaskSettings(bool); - - var taskSettings = CustomElandRerankTaskSettings.of(originalSettings, requestTaskSettings); - - assertThat(taskSettings, sameInstance(originalSettings)); - } - private static CustomElandRerankTaskSettings createRandom() { return new CustomElandRerankTaskSettings(randomOptionalBoolean()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettingsTests.java index 95d3522b863a9..90738d43aacb3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettingsTests.java @@ -29,6 +29,15 @@ public static GoogleVertexAiSecretSettings createRandom() { return new GoogleVertexAiSecretSettings(randomSecureStringOfLength(30)); } + public void testNewSecretSettings() { + GoogleVertexAiSecretSettings 
initialSettings = createRandom(); + GoogleVertexAiSecretSettings newSettings = createRandom(); + GoogleVertexAiSecretSettings newGoogleVertexAiSecretSettings = (GoogleVertexAiSecretSettings) initialSettings.newSecretSettings( + Map.of(GoogleVertexAiSecretSettings.SERVICE_ACCOUNT_JSON, newSettings.serviceAccountJson.toString()) + ); + assertEquals(newSettings, newGoogleVertexAiSecretSettings); + } + public void testFromMap_ReturnsNull_WhenMapIsNUll() { assertNull(GoogleVertexAiSecretSettings.fromMap(null)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettingsTests.java index ac7e9348b370b..5b87bbc3c42c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettingsTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -31,6 +32,23 @@ public void testIsEmpty() { assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + Map newSettingsMap = new HashMap<>(); + if (newSettings.autoTruncate() != null) { + newSettingsMap.put(GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, newSettings.autoTruncate()); + } + GoogleVertexAiEmbeddingsTaskSettings updatedSettings = (GoogleVertexAiEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + 
Collections.unmodifiableMap(newSettingsMap) + ); + if (newSettings.autoTruncate() == null) { + assertEquals(initialSettings.autoTruncate(), updatedSettings.autoTruncate()); + } else { + assertEquals(newSettings.autoTruncate(), updatedSettings.autoTruncate()); + } + } + public void testFromMap_AutoTruncateIsSet() { var autoTruncate = true; var taskSettingsMap = getTaskSettingsMap(autoTruncate); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankTaskSettingsTests.java index 03f89b6a2c042..957defb54d846 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankTaskSettingsTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -33,6 +34,23 @@ public void testIsEmpty() { assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + Map newSettingsMap = new HashMap<>(); + if (newSettings.topN() != null) { + newSettingsMap.put(GoogleVertexAiRerankTaskSettings.TOP_N, newSettings.topN()); + } + GoogleVertexAiRerankTaskSettings updatedSettings = (GoogleVertexAiRerankTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + if (newSettings.topN() == null) { + assertEquals(initialSettings.topN(), updatedSettings.topN()); + } else { + assertEquals(newSettings.topN(), 
updatedSettings.topN()); + } + } + public void testFromMap_TopNIsSet() { var topN = 1; var taskSettingsMap = getTaskSettingsMap(topN); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java index 16d7e8f1db9be..9d1170bb23dbb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -31,6 +32,23 @@ public void testIsEmpty() { assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandomWithUser(); + var newSettings = createRandomWithUser(); + Map newSettingsMap = new HashMap<>(); + if (newSettings.user() != null) { + newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user()); + } + OpenAiChatCompletionTaskSettings updatedSettings = (OpenAiChatCompletionTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + if (newSettings.user() == null) { + assertEquals(initialSettings.user(), updatedSettings.user()); + } else { + assertEquals(newSettings.user(), updatedSettings.user()); + } + } + public void testFromMap_WithUser() { assertEquals( new OpenAiChatCompletionTaskSettings("user"), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java index a5ae2f0a3a44b..0512c36e64de5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java @@ -14,9 +14,10 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields; import org.hamcrest.MatcherAssert; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -37,11 +38,29 @@ public static OpenAiEmbeddingsTaskSettings createRandom() { } public void testIsEmpty() { - var randomSettings = createRandom(); + // Instantiate the embeddings settings under test directly, randomly with or without a user. + var randomSettings = new OpenAiEmbeddingsTaskSettings(randomBoolean() ? 
null : "username"); var stringRep = Strings.toString(randomSettings); assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); } + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + Map newSettingsMap = new HashMap<>(); + if (newSettings.user() != null) { + newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user()); + } + OpenAiEmbeddingsTaskSettings updatedSettings = (OpenAiEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + if (newSettings.user() == null) { + assertEquals(initialSettings.user(), updatedSettings.user()); + } else { + assertEquals(newSettings.user(), updatedSettings.user()); + } + } + public void testFromMap_WithUser() { assertEquals( new OpenAiEmbeddingsTaskSettings("user"), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettingsTests.java index 212a867349e5c..118cf25a452a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettingsTests.java @@ -26,6 +26,15 @@ public static DefaultSecretSettings createRandom() { return new DefaultSecretSettings(new SecureString(randomAlphaOfLength(15).toCharArray())); } + public void testNewSecretSettings() { + DefaultSecretSettings initialSettings = createRandom(); + DefaultSecretSettings newSettings = createRandom(); + DefaultSecretSettings finalSettings = (DefaultSecretSettings) initialSettings.newSecretSettings( + Map.of(DefaultSecretSettings.API_KEY, newSettings.apiKey().toString()) + ); + assertEquals(newSettings, finalSettings); + } + public void testFromMap() { var apiKey = 
"abc"; var serviceSettings = DefaultSecretSettings.fromMap(new HashMap<>(Map.of(DefaultSecretSettings.API_KEY, apiKey))); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 853d0fd9318ae..d791873eb3142 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -174,6 +174,7 @@ public class Constants { "cluster:admin/xpack/enrich/reindex", "cluster:admin/xpack/inference/delete", "cluster:admin/xpack/inference/put", + "cluster:admin/xpack/inference/update", "cluster:admin/xpack/license/basic_status", // "cluster:admin/xpack/license/delete", "cluster:admin/xpack/license/feature_usage",