Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/changelog/114176.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
pr: 114176
summary: "[Inference API] Deprecate task_settings, renamed to parameters"
area: Machine Learning
type: deprecation
issues: []
deprecation:
title: "[Inference API] Deprecate task_settings, renamed to parameters"
area: REST API
details: In 8.16 the inference API is renaming the `task_settings` component of inference endpoints (used in the Create _inference API, and GET _inference API) to `parameters`. Users are asked to update any code accessing or creating `task_settings` to use `parameters` instead. Support for requests and responses including `task_settings` will be removed in 9.0.
impact: The inference API is maintaining backwards compatibility until 9.0, but we now recommend replacing usages of `task_settings` with `parameters`.
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ static TransportVersion def(int id) {
public static final TransportVersion DATE_TIME_DOC_VALUES_LOCALES = def(8_761_00_0);
public static final TransportVersion FAST_REFRESH_RCO = def(8_762_00_0);
public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_763_00_0);
public static final TransportVersion INFERENCE_API_PARAMATERS_INTRODUCED = def(8_764_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,37 +36,37 @@ default void init(Client client) {}
* If the map contains unrecognized configuration option an
* {@code ElasticsearchStatusException} is thrown.
*
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options including the secrets
* @param parsedModelListener A listener which will handle the resulting model or failure
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options including the secrets
* @param parsedModelListener A listener which will handle the resulting model or failure
*/
void parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener);

/**
* Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. This requires that
* secrets and service settings be in two separate maps.
* This function modifies {@code config map}, fields are removed from the map as they are read.
*
* <p>
* If the map contains unrecognized configuration options, no error is thrown.
*
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @param secrets Sensitive configuration options (e.g. api key)
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @param secrets Sensitive configuration options (e.g. api key)
* @return The parsed {@link Model}
*/
Model parsePersistedConfigWithSecrets(String modelId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets);

/**
* Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}.
* This function modifies {@code config map}, fields are removed from the map as they are read.
*
* <p>
* If the map contains unrecognized configuration options, no error is thrown.
*
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @return The parsed {@link Model}
*/
Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ public class ModelConfigurations implements ToFilteredXContentObject, VersionedN
// is returned as part of a GetInferenceModelAction
public static final String INDEX_ONLY_ID_FIELD_NAME = "model_id";
public static final String INFERENCE_ID_FIELD_NAME = "inference_id";
public static final String USE_ID_FOR_INDEX = "for_index";
public static final String FOR_INDEX = "for_index"; // true if writing to index
public static final String SERVICE = "service";
public static final String SERVICE_SETTINGS = "service_settings";
public static final String TASK_SETTINGS = "task_settings";
public static final String PARAMETERS = "parameters";
public static final String CHUNKING_SETTINGS = "chunking_settings";
private static final String NAME = "inference_model";

Expand Down Expand Up @@ -165,7 +166,7 @@ public ChunkingSettings getChunkingSettings() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (params.paramAsBoolean(USE_ID_FOR_INDEX, false)) {
if (params.paramAsBoolean(FOR_INDEX, false)) {
builder.field(INDEX_ONLY_ID_FIELD_NAME, inferenceEntityId);
} else {
builder.field(INFERENCE_ID_FIELD_NAME, inferenceEntityId);
Expand All @@ -177,14 +178,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (chunkingSettings != null) {
builder.field(CHUNKING_SETTINGS, chunkingSettings);
}
if (params.paramAsBoolean(FOR_INDEX, false)) {
// Don't write parameter to index, but do write parameters the rest of the time
} else {
builder.field(PARAMETERS, taskSettings);
}
builder.endObject();
return builder;
}

@Override
public XContentBuilder toFilteredXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (params.paramAsBoolean(USE_ID_FOR_INDEX, false)) {
if (params.paramAsBoolean(FOR_INDEX, false)) {
builder.field(INDEX_ONLY_ID_FIELD_NAME, inferenceEntityId);
} else {
builder.field(INFERENCE_ID_FIELD_NAME, inferenceEntityId);
Expand All @@ -196,6 +202,11 @@ public XContentBuilder toFilteredXContent(XContentBuilder builder, Params params
if (chunkingSettings != null) {
builder.field(CHUNKING_SETTINGS, chunkingSettings);
}
if (params.paramAsBoolean(FOR_INDEX, false)) {
// Don't write parameter to index, but do write parameters the rest of the time
} else {
builder.field(PARAMETERS, taskSettings);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,39 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.ModelConfigurations.PARAMETERS;
import static org.elasticsearch.inference.ModelConfigurations.TASK_SETTINGS;

public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.Response> {

public static final PutInferenceModelAction INSTANCE = new PutInferenceModelAction();
public static final String NAME = "cluster:admin/xpack/inference/put";
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(PutInferenceModelAction.class);

public PutInferenceModelAction() {
super(NAME);
Expand All @@ -40,6 +50,7 @@ public static class Request extends AcknowledgedRequest<Request> {
private final TaskType taskType;
private final String inferenceEntityId;
private final BytesReference content;
private BytesReference rewrittenContent;
private final XContentType contentType;

public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
Expand Down Expand Up @@ -70,6 +81,36 @@ public BytesReference getContent() {
return content;
}

/**
 * Returns the request body with a top-level {@code parameters} entry renamed to {@code task_settings}.
 * The result is derived solely from {@code content}, so it is computed on the first call and cached.
 *
 * <p>
 * A body that already uses {@code task_settings} is re-serialized as is and a critical deprecation
 * warning is logged.
 *
 * @return the rewritten request body
 * @throws ElasticsearchStatusException with {@code BAD_REQUEST} if the body contains both
 *         {@code task_settings} and {@code parameters}, or with {@code INTERNAL_SERVER_ERROR}
 *         if serializing the rewritten body fails
 */
public BytesReference getRewrittenContent() {
    if (rewrittenContent != null) {
        // Already computed on an earlier call
        return rewrittenContent;
    }

    Map<String, Object> requestBody = XContentHelper.convertToMap(content, false, contentType).v2();
    boolean hasParameters = requestBody.containsKey(PARAMETERS);
    boolean hasTaskSettings = requestBody.containsKey(TASK_SETTINGS);

    if (hasParameters && hasTaskSettings) {
        throw new ElasticsearchStatusException(
            "Request cannot contain both [task_settings] and [parameters], use only [parameters]",
            RestStatus.BAD_REQUEST
        );
    }

    if (hasTaskSettings) {
        DEPRECATION_LOGGER.critical(
            DeprecationCategory.API,
            "inference_api_task_settings_deprecated_use_parameters",
            "The [task_settings] field is deprecated and will be removed in a future release. "
                + "Please use only the [parameters] field instead."
        );
    } else if (hasParameters) {
        // Move the value from the new key to the legacy key
        requestBody.put(TASK_SETTINGS, requestBody.remove(PARAMETERS));
    }

    try (XContentBuilder builder = XContentFactory.contentBuilder(this.contentType)) {
        builder.map(requestBody);
        this.rewrittenContent = BytesReference.bytes(builder);
    } catch (IOException e) {
        throw new ElasticsearchStatusException("Failed to parse rewritten request", RestStatus.INTERNAL_SERVER_ERROR, e);
    }
    return rewrittenContent;
}

public XContentType getContentType() {
return contentType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,25 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.utils.MlStringsTests;
import org.junit.Before;

import java.io.IOException;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.inference.ModelConfigurations.PARAMETERS;
import static org.elasticsearch.inference.ModelConfigurations.TASK_SETTINGS;

public class PutInferenceModelActionTests extends ESTestCase {
public static TaskType TASK_TYPE;
Expand Down Expand Up @@ -57,4 +66,55 @@ public void testValidate() {
validationException = invalidRequest3.validate();
assertNotNull(validationException);
}

// A body that uses only [parameters] must come back from getRewrittenContent() with the same
// values under [task_settings], no [parameters] key, and the other fields untouched.
public void testWithParameters() throws IOException {
    Map<String, Object> expectedParameters = Map.of("top_n", 1, "top_p", 0.1);
    Map<String, Object> expectedServiceSettings = Map.of("model_id", "embed", "dimensions", 1024);
    Map<String, Object> requestBody = Map.of(
        PARAMETERS,
        expectedParameters,
        "service",
        "elasticsearch",
        "service_settings",
        expectedServiceSettings
    );

    XContentBuilder bodyBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    bodyBuilder.map(requestBody);
    var putRequest = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID, BytesReference.bytes(bodyBuilder), XContentType.JSON);

    BytesReference rewritten = putRequest.getRewrittenContent();
    Map<String, Object> rewrittenMap = XContentHelper.convertToMap(rewritten, false, putRequest.getContentType()).v2();

    assertEquals(expectedParameters, rewrittenMap.get(TASK_SETTINGS));
    assertNull(rewrittenMap.get(PARAMETERS));
    assertEquals("elasticsearch", rewrittenMap.get("service"));
    assertEquals(expectedServiceSettings, rewrittenMap.get("service_settings"));
}

// A body that carries both [parameters] and [task_settings] must be rejected when rewritten.
public void testWithParametersAndTaskSettings() throws IOException {
    Map<String, Object> parameters = Map.of("top_n", 1, "top_p", 0.1);
    Map<String, Object> taskSettings = Map.of("top_n", 2, "top_p", 0.2);
    Map<String, Object> serviceSettings = Map.of("model_id", "embed", "dimensions", 1024);
    Map<String, Object> requestBody = Map.of(
        PARAMETERS,
        parameters,
        TASK_SETTINGS,
        taskSettings,
        "service",
        "elasticsearch",
        "service_settings",
        serviceSettings
    );

    XContentBuilder bodyBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    bodyBuilder.map(requestBody);

    assertThrows(ElasticsearchStatusException.class, () -> {
        var putRequest = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID, BytesReference.bytes(bodyBuilder), XContentType.JSON);
        putRequest.getRewrittenContent();
    });
}

// A body that uses only the deprecated [task_settings] field must pass through getRewrittenContent()
// unchanged and must trigger the critical deprecation warning.
// The assertions read getRewrittenContent() (not getContent()) so that the rewrite path is the code
// under test; reading getContent() would only round-trip the test's own input.
public void testWithTaskSettings() throws IOException {
    XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
    Map<String, Object> taskSettingsValues = Map.of("top_n", 2, "top_p", 0.2);
    Map<String, Object> serviceSettingsValues = Map.of("model_id", "embed", "dimensions", 1024);
    builder.map(Map.of(TASK_SETTINGS, taskSettingsValues, "service", "elasticsearch", "service_settings", serviceSettingsValues));
    var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID, BytesReference.bytes(builder), XContentType.JSON);
    Map<String, Object> map = XContentHelper.convertToMap(request.getRewrittenContent(), false, request.getContentType()).v2();
    assertEquals(taskSettingsValues, map.get(TASK_SETTINGS));
    assertNull(map.get(PARAMETERS));
    assertEquals("elasticsearch", map.get("service"));
    assertEquals(serviceSettingsValues, map.get("service_settings"));
    // The rewrite logs a critical deprecation for [task_settings]; acknowledge it so the
    // test framework's unexpected-warning check does not fail the test.
    assertCriticalWarnings(
        "The [task_settings] field is deprecated and will be removed in a future release. "
            + "Please use only the [parameters] field instead."
    );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.client.WarningsHandler;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
Expand Down Expand Up @@ -81,6 +82,45 @@ static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) {
""", taskType);
}

/**
 * Builds a JSON request body for the {@code test_service} service that supplies its settings under
 * the {@code parameters} key.
 *
 * @param taskTypeInBody task type to include as a {@code task_type} field, or {@code null} to omit it
 * @return the JSON body as a string
 */
static String mockSparseServiceModelConfigWithParameters(@Nullable TaskType taskTypeInBody) {
    String taskTypeField;
    if (taskTypeInBody == null) {
        taskTypeField = "";
    } else {
        taskTypeField = "\"task_type\": \"" + taskTypeInBody + "\",";
    }
    return Strings.format("""
        {
        %s
        "service": "test_service",
        "service_settings": {
        "model": "my_model",
        "hidden_field": "my_hidden_value",
        "api_key": "abc64"
        },
        "parameters": {
        "temperature": 3
        }
        }
        """, taskTypeField);
}

/**
 * Builds a JSON request body for the {@code test_service} service that contains both a
 * {@code parameters} object and a {@code task_settings} object.
 *
 * @param taskTypeInBody task type to include as a {@code task_type} field, or {@code null} to omit it
 * @return the JSON body as a string
 */
static String mockSparseServiceModelConfigWithParametersAndTaskSettings(@Nullable TaskType taskTypeInBody) {
    String taskTypeField;
    if (taskTypeInBody == null) {
        taskTypeField = "";
    } else {
        taskTypeField = "\"task_type\": \"" + taskTypeInBody + "\",";
    }
    return Strings.format("""
        {
        %s
        "service": "test_service",
        "service_settings": {
        "model": "my_model",
        "hidden_field": "my_hidden_value",
        "api_key": "abc64"
        },
        "parameters": {
        "temperature": 3
        },
        "task_settings": {
        "temperature": 3
        }
        }
        """, taskTypeField);
}

static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) {
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
return Strings.format("""
Expand Down Expand Up @@ -230,6 +270,8 @@ protected Map<String, Object> putModel(String modelId, String modelConfig) throw
Map<String, Object> putRequest(String endpoint, String body) throws IOException {
var request = new Request("PUT", endpoint);
request.setJsonEntity(body);
request.setOptions(RequestOptions.DEFAULT.toBuilder().setWarningsHandler(WarningsHandler.PERMISSIVE).build()); // TODO remove
// permissive warnings once the deprecation warnings are removed in 9.0
var response = client().performRequest(request);
assertOkOrCreated(response);
return entityAsMap(response);
Expand Down
Loading
Loading