
Commit 9e7cb3f

Adding chunking tests
1 parent 37e2729 commit 9e7cb3f

6 files changed: +394 -15 lines changed


server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
@@ -196,6 +196,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48);
     public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
     public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
+    public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_51);
     public static final TransportVersion V_9_0_0 = def(9_000_0_09);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);

@@ -298,6 +299,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion HEAP_USAGE_IN_CLUSTER_INFO = def(9_096_0_00);
     public static final TransportVersion NONE_CHUNKING_STRATEGY = def(9_097_0_00);
     public static final TransportVersion PROJECT_DELETION_GLOBAL_BLOCK = def(9_098_0_00);
+    public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_099_0_00);

     /*
      * STOP! READ THIS FIRST! No, really,
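
Why two constants: the batch_size field introduced below must only be put on the wire when the receiving node knows how to read it, so the change defines one transport version for the 9.x line and a second one backported to the 8.19 patch stream. The standalone sketch below models that gate with plain ints; it only approximates the onOrAfter/isPatchFrom checks used in CustomServiceSettings further down (the real isPatchFrom is stricter and requires the peer to be on the same patch stream), so treat it as an illustration rather than the actual TransportVersion API.

import java.io.IOException;

public class VersionGateSketch {
    // Illustrative copies of the two new ids: the 9.x version and its 8.19 backport.
    static final int BATCH_SIZE_9_X = 9_099_0_00;
    static final int BATCH_SIZE_8_19 = 8_841_0_51;

    // Roughly: "onOrAfter(9.x id) || isPatchFrom(8.19 backport id)".
    static boolean peerSupportsBatchSize(int peerTransportVersion) {
        return peerTransportVersion >= BATCH_SIZE_9_X
            || (peerTransportVersion >= BATCH_SIZE_8_19 && peerTransportVersion < 9_000_0_00);
    }

    public static void main(String[] args) throws IOException {
        System.out.println(peerSupportsBatchSize(8_841_0_50)); // false: 8.19 node without the backport
        System.out.println(peerSupportsBatchSize(8_841_0_51)); // true: patched 8.19 node
        System.out.println(peerSupportsBatchSize(9_099_0_00)); // true: 9.x node with the change
    }
}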

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java

Lines changed: 39 additions & 0 deletions
@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.services.custom;

 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;

@@ -51,6 +52,27 @@ public CustomModel(
         );
     }

+    public CustomModel(
+        String inferenceId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets,
+        @Nullable ChunkingSettings chunkingSettings,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceId,
+            taskType,
+            service,
+            CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
+            CustomTaskSettings.fromMap(taskSettings),
+            CustomSecretSettings.fromMap(secrets),
+            chunkingSettings
+        );
+    }
+
     // should only be used for testing
     CustomModel(
         String inferenceId,

@@ -67,6 +89,23 @@ public CustomModel(
         );
     }

+    // should only be used for testing
+    CustomModel(
+        String inferenceId,
+        TaskType taskType,
+        String service,
+        CustomServiceSettings serviceSettings,
+        CustomTaskSettings taskSettings,
+        @Nullable CustomSecretSettings secretSettings,
+        @Nullable ChunkingSettings chunkingSettings
+    ) {
+        this(
+            new ModelConfigurations(inferenceId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
+            new ModelSecrets(secretSettings),
+            serviceSettings
+        );
+    }
+
     protected CustomModel(CustomModel model, TaskSettings taskSettings) {
         super(model, taskSettings);
         rateLimitServiceSettings = model.rateLimitServiceSettings();
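
The two added constructors follow the split this class already uses: a public map-based constructor that turns the raw request maps into typed settings (now threading ChunkingSettings through), and a package-private typed constructor reserved for tests. A standalone sketch of that pattern, using hypothetical stand-in classes rather than the real ones:

import java.util.Map;

public class SettingsExample {
    // Hypothetical typed settings, standing in for CustomServiceSettings and friends.
    record TypedSettings(int batchSize) {
        static TypedSettings fromMap(Map<String, Object> map) {
            return new TypedSettings((Integer) map.getOrDefault("batch_size", 1));
        }
    }

    private final TypedSettings settings;

    // Request path: raw maps in, typed settings out, then delegate.
    public SettingsExample(Map<String, Object> serviceSettings) {
        this(TypedSettings.fromMap(serviceSettings));
    }

    // Test path: typed settings passed in directly.
    SettingsExample(TypedSettings settings) {
        this.settings = settings;
    }

    public static void main(String[] args) {
        System.out.println(new SettingsExample(Map.of("batch_size", 4)).settings); // TypedSettings[batchSize=4]
    }
}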

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 61 additions & 5 deletions
@@ -17,6 +17,7 @@
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;

@@ -27,6 +28,8 @@
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;

@@ -45,6 +48,7 @@
 import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;

@@ -81,12 +85,15 @@ public void parseRequestConfig(
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

+        var chunkingSettings = extractChunkingSettings(config, taskType);
+
         CustomModel model = createModel(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
             serviceSettingsMap,
+            chunkingSettings,
             ConfigurationParseContext.REQUEST
         );

@@ -100,6 +107,14 @@ public void parseRequestConfig(
         }
     }

+    private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
+        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
+        return null;
+    }
+
     @Override
     public InferenceServiceConfiguration getConfiguration() {
         return Configuration.get();

@@ -125,14 +140,16 @@ private static CustomModel createModelWithoutLoggingDeprecations(
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secretSettings
+        @Nullable Map<String, Object> secretSettings,
+        @Nullable ChunkingSettings chunkingSettings
     ) {
         return createModel(
             inferenceEntityId,
             taskType,
             serviceSettings,
             taskSettings,
             secretSettings,
+            chunkingSettings,
             ConfigurationParseContext.PERSISTENT
         );
     }

@@ -143,12 +160,13 @@ private static CustomModel createModel(
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
         @Nullable Map<String, Object> secretSettings,
+        @Nullable ChunkingSettings chunkingSettings,
         ConfigurationParseContext context
     ) {
         if (supportedTaskTypes.contains(taskType) == false) {
             throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
         }
-        return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
+        return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context);
     }

     @Override

@@ -162,15 +180,33 @@ public CustomModel parsePersistedConfigWithSecrets(
         Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

-        return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap);
+        var chunkingSettings = extractChunkingSettings(config, taskType);
+
+        return createModelWithoutLoggingDeprecations(
+            inferenceEntityId,
+            taskType,
+            serviceSettingsMap,
+            taskSettingsMap,
+            secretSettingsMap,
+            chunkingSettings
+        );
     }

     @Override
     public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

-        return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null);
+        var chunkingSettings = extractChunkingSettings(config, taskType);
+
+        return createModelWithoutLoggingDeprecations(
+            inferenceEntityId,
+            taskType,
+            serviceSettingsMap,
+            taskSettingsMap,
+            null,
+            chunkingSettings
+        );
     }

     @Override

@@ -211,7 +247,27 @@ protected void doChunkedInfer(
         TimeValue timeout,
         ActionListener<List<ChunkedInference>> listener
     ) {
-        listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
+        if (model instanceof CustomModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        var customModel = (CustomModel) model;
+        var overriddenModel = CustomModel.of(customModel, taskSettings);
+
+        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(SERVICE_NAME);
+        var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool());
+
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
+            inputs.getInputs(),
+            customModel.getServiceSettings().getBatchSize(),
+            customModel.getConfigurations().getChunkingSettings()
+        ).batchRequestsWithListeners(listener);
+
+        for (var request : batchedRequests) {
+            var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage);
+            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+        }
     }

     @Override
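
The rewritten doChunkedInfer stops rejecting chunked inference for the custom service: the inputs are chunked according to the model's ChunkingSettings, the resulting chunks are grouped into batches of getBatchSize(), and one request is sent per batch. The grouping itself is done by EmbeddingRequestChunker; the standalone sketch below only illustrates the batch arithmetic, and all names in it are hypothetical.

import java.util.ArrayList;
import java.util.List;

public class BatchSizeSketch {
    public static void main(String[] args) {
        // Five chunks produced by the chunker, with a batch_size of 2 from the service settings.
        List<String> chunks = List.of("c1", "c2", "c3", "c4", "c5");
        int batchSize = 2;

        // Group chunks into ceil(5 / 2) = 3 batches; each batch would become one HTTP request.
        List<List<String>> batches = new ArrayList<>();
        for (int i = 0; i < chunks.size(); i += batchSize) {
            batches.add(chunks.subList(i, Math.min(i + batchSize, chunks.size())));
        }

        System.out.println(batches); // [[c1, c2], [c3, c4], [c5]]
    }
}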

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 55 additions & 5 deletions
@@ -44,6 +44,7 @@
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;

@@ -53,16 +54,18 @@
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;

 public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings {
+
     public static final String NAME = "custom_service_settings";
     public static final String URL = "url";
+    public static final String BATCH_SIZE = "batch_size";
     public static final String HEADERS = "headers";
     public static final String REQUEST = "request";
     public static final String RESPONSE = "response";
     public static final String JSON_PARSER = "json_parser";
     public static final String ERROR_PARSER = "error_parser";
-
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
     private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
+    private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 1;

     public static CustomServiceSettings fromMap(
         Map<String, Object> map,

@@ -117,6 +120,8 @@ public static CustomServiceSettings fromMap(
             context
         );

+        var batchSize = extractOptionalPositiveInteger(map, BATCH_SIZE, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
         if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) {
             throw validationException;
         }

@@ -137,7 +142,8 @@ public static CustomServiceSettings fromMap(
             requestContentString,
             responseJsonParser,
             rateLimitSettings,
-            errorParser
+            errorParser,
+            batchSize
         );
     }

@@ -155,7 +161,6 @@ public record TextEmbeddingSettings(
         null,
         DenseVectorFieldMapper.ElementType.FLOAT
     );
-
     // This refers to settings that are not related to the text embedding task type (all the settings should be null)
     public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null);

@@ -210,6 +215,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
     private final CustomResponseParser responseJsonParser;
     private final RateLimitSettings rateLimitSettings;
     private final ErrorResponseParser errorParser;
+    private final int batchSize;

     public CustomServiceSettings(
         TextEmbeddingSettings textEmbeddingSettings,

@@ -220,6 +226,30 @@ public CustomServiceSettings(
         CustomResponseParser responseJsonParser,
         @Nullable RateLimitSettings rateLimitSettings,
         ErrorResponseParser errorParser
+    ) {
+        this(
+            textEmbeddingSettings,
+            url,
+            headers,
+            queryParameters,
+            requestContentString,
+            responseJsonParser,
+            rateLimitSettings,
+            errorParser,
+            null
+        );
+    }
+
+    public CustomServiceSettings(
+        TextEmbeddingSettings textEmbeddingSettings,
+        String url,
+        @Nullable Map<String, String> headers,
+        @Nullable QueryParameters queryParameters,
+        String requestContentString,
+        CustomResponseParser responseJsonParser,
+        @Nullable RateLimitSettings rateLimitSettings,
+        ErrorResponseParser errorParser,
+        @Nullable Integer batchSize
     ) {
         this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
         this.url = Objects.requireNonNull(url);

@@ -229,6 +259,7 @@ public CustomServiceSettings(
         this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
         this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
         this.errorParser = Objects.requireNonNull(errorParser);
+        this.batchSize = Objects.requireNonNullElse(batchSize, DEFAULT_EMBEDDING_BATCH_SIZE);
     }

     public CustomServiceSettings(StreamInput in) throws IOException {

@@ -240,6 +271,12 @@ public CustomServiceSettings(StreamInput in) throws IOException {
         responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
         rateLimitSettings = new RateLimitSettings(in);
         errorParser = new ErrorResponseParser(in);
+        if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE)
+            || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) {
+            batchSize = in.readVInt();
+        } else {
+            batchSize = DEFAULT_EMBEDDING_BATCH_SIZE;
+        }
     }

     @Override

@@ -291,6 +328,10 @@ public ErrorResponseParser getErrorParser() {
         return errorParser;
     }

+    public int getBatchSize() {
+        return batchSize;
+    }
+
     @Override
     public RateLimitSettings rateLimitSettings() {
         return rateLimitSettings;

@@ -337,6 +378,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder

         rateLimitSettings.toXContent(builder, params);

+        builder.field(BATCH_SIZE, batchSize);
+
         return builder;
     }

@@ -360,6 +403,11 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeNamedWriteable(responseJsonParser);
         rateLimitSettings.writeTo(out);
         errorParser.writeTo(out);
+
+        if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE)
+            || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) {
+            out.writeVInt(batchSize);
+        }
     }

     @Override

@@ -374,7 +422,8 @@ public boolean equals(Object o) {
             && Objects.equals(requestContentString, that.requestContentString)
             && Objects.equals(responseJsonParser, that.responseJsonParser)
             && Objects.equals(rateLimitSettings, that.rateLimitSettings)
-            && Objects.equals(errorParser, that.errorParser);
+            && Objects.equals(errorParser, that.errorParser)
+            && Objects.equals(batchSize, that.batchSize);
     }

     @Override

@@ -387,7 +436,8 @@ public int hashCode() {
             requestContentString,
             responseJsonParser,
             rateLimitSettings,
-            errorParser
+            errorParser,
+            batchSize
         );
     }
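
batch_size is optional and validated as a positive integer: when the key is absent, extractOptionalPositiveInteger yields null and the constructor falls back to DEFAULT_EMBEDDING_BATCH_SIZE (1) via Objects.requireNonNullElse. A standalone illustration of that defaulting step in plain Java, separate from the Elasticsearch code:

import java.util.Objects;

public class BatchSizeDefaultSketch {
    public static void main(String[] args) {
        Integer fromRequest = null; // batch_size omitted from service_settings
        int batchSize = Objects.requireNonNullElse(fromRequest, 1);
        System.out.println(batchSize); // 1

        System.out.println(Objects.requireNonNullElse(8, 1)); // an explicit value wins: 8
    }
}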
