Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48);
public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_51);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -298,6 +299,7 @@ static TransportVersion def(int id) {
public static final TransportVersion HEAP_USAGE_IN_CLUSTER_INFO = def(9_096_0_00);
public static final TransportVersion NONE_CHUNKING_STRATEGY = def(9_097_0_00);
public static final TransportVersion PROJECT_DELETION_GLOBAL_BLOCK = def(9_098_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_099_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.custom;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand Down Expand Up @@ -51,6 +52,27 @@ public CustomModel(
);
}

/**
 * Builds a {@code CustomModel} from raw, map-based settings as received from a request or
 * persisted configuration, parsing each map into its typed settings object before delegating
 * to the canonical constructor.
 *
 * @param inferenceId      unique id of the inference endpoint
 * @param taskType         the task this model serves (e.g. text embedding)
 * @param service          the service name this model belongs to
 * @param serviceSettings  raw service settings map; parsed via {@code CustomServiceSettings.fromMap}
 * @param taskSettings     raw task settings map; parsed via {@code CustomTaskSettings.fromMap}
 * @param secrets          raw secret settings map, or null when no secrets are available
 * @param chunkingSettings chunking configuration, or null when the task type does not support chunking
 * @param context          whether the maps come from a request or from persisted storage
 */
public CustomModel(
String inferenceId,
TaskType taskType,
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secrets,
@Nullable ChunkingSettings chunkingSettings,
ConfigurationParseContext context
) {
this(
inferenceId,
taskType,
service,
CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
CustomTaskSettings.fromMap(taskSettings),
CustomSecretSettings.fromMap(secrets),
chunkingSettings
);
}

// should only be used for testing
CustomModel(
String inferenceId,
Expand All @@ -67,6 +89,23 @@ public CustomModel(
);
}

// should only be used for testing
/**
 * Test-only constructor taking already-parsed settings objects, bypassing map parsing.
 * Wraps the settings into {@code ModelConfigurations}/{@code ModelSecrets} and delegates
 * to the canonical constructor; {@code serviceSettings} is also passed separately as the
 * rate-limit settings source.
 *
 * @param secretSettings   may be null when the model has no secrets
 * @param chunkingSettings may be null when the task type does not support chunking
 */
CustomModel(
String inferenceId,
TaskType taskType,
String service,
CustomServiceSettings serviceSettings,
CustomTaskSettings taskSettings,
@Nullable CustomSecretSettings secretSettings,
@Nullable ChunkingSettings chunkingSettings
) {
this(
new ModelConfigurations(inferenceId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
serviceSettings
);
}

protected CustomModel(CustomModel model, TaskSettings taskSettings) {
super(model, taskSettings);
rateLimitServiceSettings = model.rateLimitServiceSettings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
Expand All @@ -27,6 +28,8 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand All @@ -45,6 +48,7 @@
import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
Expand Down Expand Up @@ -81,12 +85,15 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

var chunkingSettings = extractChunkingSettings(config, taskType);

CustomModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
serviceSettingsMap,
chunkingSettings,
ConfigurationParseContext.REQUEST
);

Expand All @@ -100,6 +107,14 @@ public void parseRequestConfig(
}
}

/**
 * Removes and parses the chunking settings entry from the config map.
 * Chunking only applies to text embedding; for any other task type the map is left
 * untouched and null is returned.
 */
private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
    if (TaskType.TEXT_EMBEDDING.equals(taskType) == false) {
        return null;
    }

    var chunkingSettingsMap = removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS);
    return ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
}

@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
Expand All @@ -125,14 +140,16 @@ private static CustomModel createModelWithoutLoggingDeprecations(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secretSettings
@Nullable Map<String, Object> secretSettings,
@Nullable ChunkingSettings chunkingSettings
) {
return createModel(
inferenceEntityId,
taskType,
serviceSettings,
taskSettings,
secretSettings,
chunkingSettings,
ConfigurationParseContext.PERSISTENT
);
}
Expand All @@ -143,12 +160,13 @@ private static CustomModel createModel(
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secretSettings,
@Nullable ChunkingSettings chunkingSettings,
ConfigurationParseContext context
) {
if (supportedTaskTypes.contains(taskType) == false) {
throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
}
return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context);
}

@Override
Expand All @@ -162,15 +180,33 @@ public CustomModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap);
var chunkingSettings = extractChunkingSettings(config, taskType);

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
secretSettingsMap,
chunkingSettings
);
}

/**
 * Rebuilds a {@code CustomModel} from persisted (non-secret) configuration.
 * Service and task settings are required and removed from the config map first (throwing if
 * absent); chunking settings are then extracted for task types that support them. No secret
 * settings exist in this path, so null is passed in their place.
 */
@Override
public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
    var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
    var taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
    ChunkingSettings chunkingSettings = extractChunkingSettings(config, taskType);

    // Persisted configs carry no secrets, hence the explicit null.
    return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, chunkingSettings);
}

@Override
Expand Down Expand Up @@ -211,7 +247,27 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
if (model instanceof CustomModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}

var customModel = (CustomModel) model;
var overriddenModel = CustomModel.of(customModel, taskSettings);

var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(SERVICE_NAME);
var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
customModel.getServiceSettings().getBatchSize(),
customModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage);
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
Expand All @@ -53,16 +54,18 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;

public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings {

public static final String NAME = "custom_service_settings";
public static final String URL = "url";
public static final String BATCH_SIZE = "batch_size";
public static final String HEADERS = "headers";
public static final String REQUEST = "request";
public static final String RESPONSE = "response";
public static final String JSON_PARSER = "json_parser";
public static final String ERROR_PARSER = "error_parser";

private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is a batch of 1 still a batch?

Suggested change
private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 1;
private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10;


public static CustomServiceSettings fromMap(
Map<String, Object> map,
Expand Down Expand Up @@ -117,6 +120,8 @@ public static CustomServiceSettings fromMap(
context
);

var batchSize = extractOptionalPositiveInteger(map, BATCH_SIZE, ModelConfigurations.SERVICE_SETTINGS, validationException);

if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) {
throw validationException;
}
Expand All @@ -137,7 +142,8 @@ public static CustomServiceSettings fromMap(
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
errorParser,
batchSize
);
}

Expand All @@ -155,7 +161,6 @@ public record TextEmbeddingSettings(
null,
DenseVectorFieldMapper.ElementType.FLOAT
);

// This refers to settings that are not related to the text embedding task type (all the settings should be null)
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null);

Expand Down Expand Up @@ -210,6 +215,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
private final CustomResponseParser responseJsonParser;
private final RateLimitSettings rateLimitSettings;
private final ErrorResponseParser errorParser;
private final int batchSize;

public CustomServiceSettings(
TextEmbeddingSettings textEmbeddingSettings,
Expand All @@ -220,6 +226,30 @@ public CustomServiceSettings(
CustomResponseParser responseJsonParser,
@Nullable RateLimitSettings rateLimitSettings,
ErrorResponseParser errorParser
) {
this(
textEmbeddingSettings,
url,
headers,
queryParameters,
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser,
null
);
}

public CustomServiceSettings(
TextEmbeddingSettings textEmbeddingSettings,
String url,
@Nullable Map<String, String> headers,
@Nullable QueryParameters queryParameters,
String requestContentString,
CustomResponseParser responseJsonParser,
@Nullable RateLimitSettings rateLimitSettings,
ErrorResponseParser errorParser,
@Nullable Integer batchSize
) {
this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
this.url = Objects.requireNonNull(url);
Expand All @@ -229,6 +259,7 @@ public CustomServiceSettings(
this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
this.errorParser = Objects.requireNonNull(errorParser);
this.batchSize = Objects.requireNonNullElse(batchSize, DEFAULT_EMBEDDING_BATCH_SIZE);
}

public CustomServiceSettings(StreamInput in) throws IOException {
Expand All @@ -240,6 +271,12 @@ public CustomServiceSettings(StreamInput in) throws IOException {
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
rateLimitSettings = new RateLimitSettings(in);
errorParser = new ErrorResponseParser(in);
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE)
|| in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) {
batchSize = in.readVInt();
} else {
batchSize = DEFAULT_EMBEDDING_BATCH_SIZE;
}
}

@Override
Expand Down Expand Up @@ -291,6 +328,10 @@ public ErrorResponseParser getErrorParser() {
return errorParser;
}

/**
 * Returns the number of inputs sent per embedding request; defaults to
 * {@code DEFAULT_EMBEDDING_BATCH_SIZE} when not configured.
 */
public int getBatchSize() {
return batchSize;
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
Expand Down Expand Up @@ -337,6 +378,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder

rateLimitSettings.toXContent(builder, params);

builder.field(BATCH_SIZE, batchSize);

return builder;
}

Expand All @@ -360,6 +403,11 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(responseJsonParser);
rateLimitSettings.writeTo(out);
errorParser.writeTo(out);

if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE)
|| out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) {
out.writeVInt(batchSize);
}
}

@Override
Expand All @@ -374,7 +422,8 @@ public boolean equals(Object o) {
&& Objects.equals(requestContentString, that.requestContentString)
&& Objects.equals(responseJsonParser, that.responseJsonParser)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
&& Objects.equals(errorParser, that.errorParser);
&& Objects.equals(errorParser, that.errorParser)
&& Objects.equals(batchSize, that.batchSize);
}

@Override
Expand All @@ -387,7 +436,8 @@ public int hashCode() {
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
errorParser,
batchSize
);
}

Expand Down
Loading