diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java index 48e32c741a601..0975f8616da03 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java @@ -116,7 +116,7 @@ public SageMakerModel override(Map taskSettingsOverride) { getConfigurations(), getSecrets(), serviceSettings, - taskSettings.updatedTaskSettings(taskSettingsOverride), + taskSettings.override(taskSettingsOverride), awsSecretSettings ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java index fd9eb2d20c5d3..a36944c51f104 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java @@ -71,11 +71,21 @@ public boolean isEmpty() { @Override public SageMakerTaskSettings updatedTaskSettings(Map newSettings) { var validationException = new ValidationException(); - var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException); + validationException.throwIfValidationErrorsExist(); + + return override(updateTaskSettings); + } + public SageMakerTaskSettings override(Map newSettings) { + var validationException = new ValidationException(); + var updateTaskSettings = fromMap(newSettings, apiTaskSettings.override(newSettings), validationException); validationException.throwIfValidationErrorsExist(); + return override(updateTaskSettings); + } + + private SageMakerTaskSettings override(SageMakerTaskSettings updateTaskSettings) { var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP) ? apiTaskSettings : updateTaskSettings.apiTaskSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java index 09a73f0f42ea4..a3ff632f466c2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java @@ -68,4 +68,8 @@ default boolean isFragment() { @Override SageMakerStoredTaskSchema updatedTaskSettings(Map newSettings); + + default SageMakerStoredTaskSchema override(Map newSettings) { + return updatedTaskSettings(newSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java index 46c5a9eb30a9a..781b1e906a17f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java @@ -88,12 +88,6 @@ default SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest re @Override default SageMakerElasticTaskSettings apiTaskSettings(Map taskSettings, ValidationException validationException) { - if (taskSettings != null && (taskSettings.isEmpty() == false)) { - validationException.addValidationError( - InferenceAction.Request.TASK_SETTINGS.getPreferredName() - + " is only supported during the inference request and cannot be stored in the inference endpoint." - ); - } return SageMakerElasticTaskSettings.empty(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java index 088de2068741c..dc0bc91fccd75 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java @@ -9,10 +9,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; import java.io.IOException; @@ -40,6 +42,16 @@ public boolean isEmpty() { @Override public SageMakerStoredTaskSchema updatedTaskSettings(Map newSettings) { + var validationException = new ValidationException(); + validationException.addValidationError( + InferenceAction.Request.TASK_SETTINGS.getPreferredName() + + " is only supported during the inference request and cannot be stored in the inference endpoint." + ); + throw validationException; + } + + @Override + public SageMakerStoredTaskSchema override(Map newSettings) { return new SageMakerElasticTaskSettings(newSettings); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java index adffbb366fb02..90b5042d3dec4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java @@ -119,7 +119,7 @@ public final void testWithUnknownApiTaskSettings() { } } - public final void testUpdate() throws IOException { + public void testUpdate() throws IOException { var taskSettings = randomApiTaskSettings(); if (taskSettings != SageMakerStoredTaskSchema.NO_OP) { var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java index 65dcd62bb149a..9e4cfc52e9568 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java @@ -18,8 +18,8 @@ import java.util.List; import java.util.Map; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; +import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -50,6 +50,7 @@ protected SageMakerModel mockModel(SageMakerElasticTaskSettings taskSettings) { return model; } + @Override public void testApiTaskSettings() { { var validationException = new ValidationException(); @@ -67,14 +68,21 @@ public void testApiTaskSettings() { var validationException = new ValidationException(); var actualApiTaskSettings = payload.apiTaskSettings(Map.of("hello", "world"), validationException); assertTrue(actualApiTaskSettings.isEmpty()); - assertFalse(validationException.validationErrors().isEmpty()); - assertThat( - validationException.validationErrors().get(0), - is(equalTo("task_settings is only supported during the inference request and cannot be stored in the inference endpoint.")) - ); + assertTrue(validationException.validationErrors().isEmpty()); } } + @Override + public void testUpdate() { + var taskSettings = randomApiTaskSettings(); + var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings); + var e = assertThrows(ValidationException.class, () -> taskSettings.updatedTaskSettings(toMap(otherTaskSettings))); + assertThat( + e.getMessage(), + containsString("task_settings is only supported during the inference request and cannot be stored in the inference endpoint") + ); + } + public void testRequestWithRequiredFields() throws Exception { var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), false, InputType.UNSPECIFIED); var sdkByes = payload.requestBytes(mockModel(), request);