Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/changelog/114176.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
pr: 114176
summary: "[Inference API ] Add `endpoint_version` to deprecate task settings"
area: Machine Learning
type: deprecation
issues: []
deprecation:
title: "[Inference API ] Add `endpoint_version` to deprecate task settings"
area: Machine Learning
details: Please describe the details of this change for the release notes. You can
use asciidoc.
impact: Please describe the impact of this change to users
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RRF_QUERY_REWRITE = def(8_758_00_0);
public static final TransportVersion SEARCH_FAILURE_STATS = def(8_759_00_0);
public static final TransportVersion INGEST_GEO_DATABASE_PROVIDERS = def(8_760_00_0);
public static final TransportVersion INFERENCE_API_PARAMATERS_INTRODUCED = def(8_761_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,40 +36,55 @@ default void init(Client client) {}
* If the map contains unrecognized configuration option an
* {@code ElasticsearchStatusException} is thrown.
*
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options including the secrets
* @param parsedModelListener A listener which will handle the resulting model or failure
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options including the secrets
* @param endpointVersion
* @param parsedModelListener A listener which will handle the resulting model or failure
*/
void parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener);
void parseRequestConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
String endpointVersion,
ActionListener<Model> parsedModelListener
);

/**
* Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. This requires that
* secrets and service settings be in two separate maps.
* This function modifies {@code config map}, fields are removed from the map as they are read.
*
* <p>
* If the map contains unrecognized configuration options, no error is thrown.
*
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @param secrets Sensitive configuration options (e.g. api key)
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @param secrets Sensitive configuration options (e.g. api key)
* @param endpointVersion
* @return The parsed {@link Model}
*/
Model parsePersistedConfigWithSecrets(String modelId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets);
Model parsePersistedConfigWithSecrets(
String modelId,
TaskType taskType,
Map<String, Object> config,
Map<String, Object> secrets,
String endpointVersion
);

/**
* Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}.
* This function modifies {@code config map}, fields are removed from the map as they are read.
*
* <p>
* If the map contains unrecognized configuration options, no error is thrown.
*
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @param endpointVersion
* @return The parsed {@link Model}
*/
Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config);
Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config, String endpointVersion);

/**
* Perform inference on the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@ public class ModelConfigurations implements ToFilteredXContentObject, VersionedN
public static final String USE_ID_FOR_INDEX = "for_index";
public static final String SERVICE = "service";
public static final String SERVICE_SETTINGS = "service_settings";
public static final String TASK_SETTINGS = "task_settings";
public static final String OLD_TASK_SETTINGS = "task_settings";
public static final String PARAMETERS = "parameters";
public static final String CHUNKING_SETTINGS = "chunking_settings";
public static final String INCLUDE_PARAMETERS = "include_parameters";
public static final String ENDPOINT_VERSION_FIELD_NAME = "endpoint_version";
public static final String FIRST_ENDPOINT_VERSION = "2023-09-29";
public static final String PARAMETERS_INTRODUCED_ENDPOINT_VERSION = "2024-10-17";
private static final String NAME = "inference_model";

public static ModelConfigurations of(Model model, TaskSettings taskSettings) {
Expand All @@ -42,7 +47,8 @@ public static ModelConfigurations of(Model model, TaskSettings taskSettings) {
model.getConfigurations().getService(),
model.getServiceSettings(),
taskSettings,
model.getConfigurations().getChunkingSettings()
model.getConfigurations().getChunkingSettings(),
model.getConfigurations().getEndpointVersion()
);
}

Expand All @@ -56,7 +62,8 @@ public static ModelConfigurations of(Model model, ServiceSettings serviceSetting
model.getConfigurations().getService(),
serviceSettings,
model.getTaskSettings(),
model.getConfigurations().getChunkingSettings()
model.getConfigurations().getChunkingSettings(),
model.getConfigurations().getEndpointVersion()
);
}

Expand All @@ -66,36 +73,46 @@ public static ModelConfigurations of(Model model, ServiceSettings serviceSetting
private final ServiceSettings serviceSettings;
private final TaskSettings taskSettings;
private final ChunkingSettings chunkingSettings;
private final String endpointVersion;

/**
* Allows no task settings to be defined. This will default to the {@link EmptyTaskSettings} object.
*/
public ModelConfigurations(String inferenceEntityId, TaskType taskType, String service, ServiceSettings serviceSettings) {
this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE);
public ModelConfigurations(
String inferenceEntityId,
TaskType taskType,
String service,
ServiceSettings serviceSettings,
String endpointVersion
) {
this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, endpointVersion);
}

public ModelConfigurations(
String inferenceEntityId,
TaskType taskType,
String service,
ServiceSettings serviceSettings,
ChunkingSettings chunkingSettings
ChunkingSettings chunkingSettings,
String endpointVersion
) {
this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings);
this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings, endpointVersion);
}

public ModelConfigurations(
String inferenceEntityId,
TaskType taskType,
String service,
ServiceSettings serviceSettings,
TaskSettings taskSettings
TaskSettings taskSettings,
String endpointVersion
) {
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
this.taskType = Objects.requireNonNull(taskType);
this.service = Objects.requireNonNull(service);
this.serviceSettings = Objects.requireNonNull(serviceSettings);
this.taskSettings = Objects.requireNonNull(taskSettings);
this.endpointVersion = endpointVersion;
this.chunkingSettings = null;
}

Expand All @@ -105,14 +122,16 @@ public ModelConfigurations(
String service,
ServiceSettings serviceSettings,
TaskSettings taskSettings,
ChunkingSettings chunkingSettings
ChunkingSettings chunkingSettings,
String endpointVersion
) {
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
this.taskType = Objects.requireNonNull(taskType);
this.service = Objects.requireNonNull(service);
this.serviceSettings = Objects.requireNonNull(serviceSettings);
this.taskSettings = Objects.requireNonNull(taskSettings);
this.chunkingSettings = chunkingSettings;
this.endpointVersion = endpointVersion;
}

public ModelConfigurations(StreamInput in) throws IOException {
Expand All @@ -124,6 +143,9 @@ public ModelConfigurations(StreamInput in) throws IOException {
this.chunkingSettings = in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS)
? in.readOptionalNamedWriteable(ChunkingSettings.class)
: null;
this.endpointVersion = in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_PARAMATERS_INTRODUCED)
? Objects.requireNonNullElse(in.readOptionalString(), FIRST_ENDPOINT_VERSION)
: FIRST_ENDPOINT_VERSION;
}

@Override
Expand All @@ -136,6 +158,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS)) {
out.writeOptionalNamedWriteable(chunkingSettings);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_PARAMATERS_INTRODUCED)) {
out.writeOptionalString(endpointVersion); // not nullable after 9.0
}
}

public String getInferenceEntityId() {
Expand All @@ -162,6 +187,10 @@ public ChunkingSettings getChunkingSettings() {
return chunkingSettings;
}

public String getEndpointVersion() {
return endpointVersion;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand All @@ -173,10 +202,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(TaskType.NAME, taskType.toString());
builder.field(SERVICE, service);
builder.field(SERVICE_SETTINGS, serviceSettings);
builder.field(TASK_SETTINGS, taskSettings);
builder.field(OLD_TASK_SETTINGS, taskSettings);
if (chunkingSettings != null) {
builder.field(CHUNKING_SETTINGS, chunkingSettings);
}
if (params.paramAsBoolean(INCLUDE_PARAMETERS, true)) { // default true so that REST requests get parameters
builder.field(PARAMETERS, taskSettings);
}
builder.field(ENDPOINT_VERSION_FIELD_NAME, endpointVersion);
builder.endObject();
return builder;
}
Expand All @@ -192,10 +225,14 @@ public XContentBuilder toFilteredXContent(XContentBuilder builder, Params params
builder.field(TaskType.NAME, taskType.toString());
builder.field(SERVICE, service);
builder.field(SERVICE_SETTINGS, serviceSettings.getFilteredXContentObject());
builder.field(TASK_SETTINGS, taskSettings);
builder.field(OLD_TASK_SETTINGS, taskSettings);
if (chunkingSettings != null) {
builder.field(CHUNKING_SETTINGS, chunkingSettings);
}
if (params.paramAsBoolean(INCLUDE_PARAMETERS, true)) { // default true so that REST requests get parameters
builder.field(PARAMETERS, taskSettings);
}
builder.field(ENDPOINT_VERSION_FIELD_NAME, endpointVersion);
builder.endObject();
return builder;
}
Expand All @@ -219,11 +256,12 @@ public boolean equals(Object o) {
&& taskType == model.taskType
&& Objects.equals(service, model.service)
&& Objects.equals(serviceSettings, model.serviceSettings)
&& Objects.equals(taskSettings, model.taskSettings);
&& Objects.equals(taskSettings, model.taskSettings)
&& Objects.equals(endpointVersion, model.endpointVersion);
}

@Override
public int hashCode() {
return Objects.hash(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
return Objects.hash(inferenceEntityId, taskType, service, serviceSettings, taskSettings, endpointVersion);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ public record UnparsedModel(
TaskType taskType,
String service,
Map<String, Object> settings,
Map<String, Object> secrets
Map<String, Object> secrets,
String endpointVersion
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
Expand All @@ -17,15 +18,24 @@
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.ModelConfigurations.ENDPOINT_VERSION_FIELD_NAME;
import static org.elasticsearch.inference.ModelConfigurations.FIRST_ENDPOINT_VERSION;
import static org.elasticsearch.inference.ModelConfigurations.OLD_TASK_SETTINGS;
import static org.elasticsearch.inference.ModelConfigurations.PARAMETERS;
import static org.elasticsearch.inference.ModelConfigurations.PARAMETERS_INTRODUCED_ENDPOINT_VERSION;

public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.Response> {

public static final PutInferenceModelAction INSTANCE = new PutInferenceModelAction();
Expand All @@ -40,6 +50,7 @@ public static class Request extends AcknowledgedRequest<Request> {
private final TaskType taskType;
private final String inferenceEntityId;
private final BytesReference content;
private BytesReference rewrittenContent;
private final XContentType contentType;

public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
Expand Down Expand Up @@ -70,6 +81,33 @@ public BytesReference getContent() {
return content;
}

public BytesReference getRewrittenContent() {
if (rewrittenContent == null) { // rewrittenContent is deterministic on content, so we only need to calculate it once
Map<String, Object> newContent = XContentHelper.convertToMap(content, false, contentType).v2();
if (newContent.containsKey(PARAMETERS) && newContent.containsKey(OLD_TASK_SETTINGS)) {
throw new ElasticsearchStatusException(
"Request cannot contain both [task_settings] and [parameters], use only [parameters]",
RestStatus.BAD_REQUEST
);
} else if (newContent.containsKey(PARAMETERS)) {
newContent.put(OLD_TASK_SETTINGS, newContent.get(PARAMETERS));
newContent.put(ENDPOINT_VERSION_FIELD_NAME, PARAMETERS_INTRODUCED_ENDPOINT_VERSION);
newContent.remove(PARAMETERS);
} else if (newContent.containsKey(OLD_TASK_SETTINGS)) {
newContent.put(ENDPOINT_VERSION_FIELD_NAME, FIRST_ENDPOINT_VERSION);
} else {
newContent.put(ENDPOINT_VERSION_FIELD_NAME, FIRST_ENDPOINT_VERSION);
}
try (XContentBuilder builder = XContentFactory.contentBuilder(this.contentType)) {
builder.map(newContent);
this.rewrittenContent = BytesReference.bytes(builder);
} catch (IOException e) {
throw new ElasticsearchStatusException("Failed to parse rewritten request", RestStatus.INTERNAL_SERVER_ERROR, e);
}
}
return rewrittenContent;
}

public XContentType getContentType() {
return contentType;
}
Expand Down
Loading
Loading