Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public SageMakerModel override(Map<String, Object> taskSettingsOverride) {
getConfigurations(),
getSecrets(),
serviceSettings,
taskSettings.updatedTaskSettings(taskSettingsOverride),
taskSettings.override(taskSettingsOverride),
awsSecretSettings
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,21 @@ public boolean isEmpty() {
@Override
public SageMakerTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
var validationException = new ValidationException();

var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException);
validationException.throwIfValidationErrorsExist();

return override(updateTaskSettings);
}

public SageMakerTaskSettings override(Map<String, Object> newSettings) {
var validationException = new ValidationException();
var updateTaskSettings = fromMap(newSettings, apiTaskSettings.override(newSettings), validationException);
validationException.throwIfValidationErrorsExist();

return override(updateTaskSettings);
}

private SageMakerTaskSettings override(SageMakerTaskSettings updateTaskSettings) {
var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP)
? apiTaskSettings
: updateTaskSettings.apiTaskSettings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,8 @@ default boolean isFragment() {

@Override
SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings);

default SageMakerStoredTaskSchema override(Map<String, Object> newSettings) {
return updatedTaskSettings(newSettings);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,6 @@ default SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest re

@Override
default SageMakerElasticTaskSettings apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
if (taskSettings != null && (taskSettings.isEmpty() == false)) {
validationException.addValidationError(
InferenceAction.Request.TASK_SETTINGS.getPreferredName()
+ " is only supported during the inference request and cannot be stored in the inference endpoint."
);
}
return SageMakerElasticTaskSettings.empty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;

import java.io.IOException;
Expand Down Expand Up @@ -40,6 +42,16 @@ public boolean isEmpty() {

@Override
public SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings) {
var validationException = new ValidationException();
validationException.addValidationError(
InferenceAction.Request.TASK_SETTINGS.getPreferredName()
+ " is only supported during the inference request and cannot be stored in the inference endpoint."
);
throw validationException;
}

@Override
public SageMakerStoredTaskSchema override(Map<String, Object> newSettings) {
return new SageMakerElasticTaskSettings(newSettings);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public final void testWithUnknownApiTaskSettings() {
}
}

public final void testUpdate() throws IOException {
public void testUpdate() throws IOException {
var taskSettings = randomApiTaskSettings();
if (taskSettings != SageMakerStoredTaskSchema.NO_OP) {
var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -50,6 +50,7 @@ protected SageMakerModel mockModel(SageMakerElasticTaskSettings taskSettings) {
return model;
}

@Override
public void testApiTaskSettings() {
{
var validationException = new ValidationException();
Expand All @@ -67,14 +68,21 @@ public void testApiTaskSettings() {
var validationException = new ValidationException();
var actualApiTaskSettings = payload.apiTaskSettings(Map.of("hello", "world"), validationException);
assertTrue(actualApiTaskSettings.isEmpty());
assertFalse(validationException.validationErrors().isEmpty());
assertThat(
validationException.validationErrors().get(0),
is(equalTo("task_settings is only supported during the inference request and cannot be stored in the inference endpoint."))
);
assertTrue(validationException.validationErrors().isEmpty());
}
}

@Override
public void testUpdate() {
var taskSettings = randomApiTaskSettings();
var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);
var e = assertThrows(ValidationException.class, () -> taskSettings.updatedTaskSettings(toMap(otherTaskSettings)));
assertThat(
e.getMessage(),
containsString("task_settings is only supported during the inference request and cannot be stored in the inference endpoint")
);
}

public void testRequestWithRequiredFields() throws Exception {
var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), false, InputType.UNSPECIFIED);
var sdkByes = payload.requestBytes(mockModel(), request);
Expand Down