Skip to content

Commit 27faf59

Browse files
Adding chunking settings to GoogleVertexAiService, AzureAiStudioService, and AlibabaCloudSearchService (#113981)
* Adding chunking settings to GoogleVertexAiService, AzureAiStudioService, and AlibabaCloudSearchService * Update docs/changelog/113981.yaml * Updating AlibabaService chunkedInfer to handle sparse embedding task types --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent e07c935 commit 27faf59

File tree

16 files changed

+1350
-87
lines changed

16 files changed

+1350
-87
lines changed

docs/changelog/113981.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 113981
2+
summary: "Adding chunking settings to `GoogleVertexAiService,` `AzureAiStudioService,`\
3+
\ and `AlibabaCloudSearchService`"
4+
area: Machine Learning
5+
type: enhancement
6+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.core.TimeValue;
1616
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
1717
import org.elasticsearch.inference.ChunkingOptions;
18+
import org.elasticsearch.inference.ChunkingSettings;
1819
import org.elasticsearch.inference.InferenceService;
1920
import org.elasticsearch.inference.InferenceServiceResults;
2021
import org.elasticsearch.inference.InputType;
@@ -24,6 +25,8 @@
2425
import org.elasticsearch.inference.SimilarityMeasure;
2526
import org.elasticsearch.inference.TaskType;
2627
import org.elasticsearch.rest.RestStatus;
28+
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
29+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
2730
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
2831
import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator;
2932
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -74,11 +77,19 @@ public void parseRequestConfig(
7477
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
7578
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
7679

80+
ChunkingSettings chunkingSettings = null;
81+
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
82+
chunkingSettings = ChunkingSettingsBuilder.fromMap(
83+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
84+
);
85+
}
86+
7787
AlibabaCloudSearchModel model = createModel(
7888
inferenceEntityId,
7989
taskType,
8090
serviceSettingsMap,
8191
taskSettingsMap,
92+
chunkingSettings,
8293
serviceSettingsMap,
8394
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
8495
ConfigurationParseContext.REQUEST
@@ -99,6 +110,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
99110
TaskType taskType,
100111
Map<String, Object> serviceSettings,
101112
Map<String, Object> taskSettings,
113+
ChunkingSettings chunkingSettings,
102114
@Nullable Map<String, Object> secretSettings,
103115
String failureMessage
104116
) {
@@ -107,6 +119,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
107119
taskType,
108120
serviceSettings,
109121
taskSettings,
122+
chunkingSettings,
110123
secretSettings,
111124
failureMessage,
112125
ConfigurationParseContext.PERSISTENT
@@ -118,6 +131,7 @@ private static AlibabaCloudSearchModel createModel(
118131
TaskType taskType,
119132
Map<String, Object> serviceSettings,
120133
Map<String, Object> taskSettings,
134+
ChunkingSettings chunkingSettings,
121135
@Nullable Map<String, Object> secretSettings,
122136
String failureMessage,
123137
ConfigurationParseContext context
@@ -129,6 +143,7 @@ private static AlibabaCloudSearchModel createModel(
129143
NAME,
130144
serviceSettings,
131145
taskSettings,
146+
chunkingSettings,
132147
secretSettings,
133148
context
134149
);
@@ -138,6 +153,7 @@ private static AlibabaCloudSearchModel createModel(
138153
NAME,
139154
serviceSettings,
140155
taskSettings,
156+
chunkingSettings,
141157
secretSettings,
142158
context
143159
);
@@ -174,11 +190,17 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets(
174190
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
175191
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
176192

193+
ChunkingSettings chunkingSettings = null;
194+
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
195+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
196+
}
197+
177198
return createModelWithoutLoggingDeprecations(
178199
inferenceEntityId,
179200
taskType,
180201
serviceSettingsMap,
181202
taskSettingsMap,
203+
chunkingSettings,
182204
secretSettingsMap,
183205
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
184206
);
@@ -189,11 +211,17 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
189211
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
190212
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
191213

214+
ChunkingSettings chunkingSettings = null;
215+
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
216+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
217+
}
218+
192219
return createModelWithoutLoggingDeprecations(
193220
inferenceEntityId,
194221
taskType,
195222
serviceSettingsMap,
196223
taskSettingsMap,
224+
chunkingSettings,
197225
null,
198226
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
199227
);
@@ -238,17 +266,36 @@ protected void doChunkedInfer(
238266
AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
239267
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
240268

241-
var batchedRequests = new EmbeddingRequestChunker(
242-
inputs.getInputs(),
243-
EMBEDDING_MAX_BATCH_SIZE,
244-
EmbeddingRequestChunker.EmbeddingType.FLOAT
245-
).batchRequestsWithListeners(listener);
269+
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
270+
if (ChunkingSettingsFeatureFlag.isEnabled()) {
271+
batchedRequests = new EmbeddingRequestChunker(
272+
inputs.getInputs(),
273+
EMBEDDING_MAX_BATCH_SIZE,
274+
getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType()),
275+
alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
276+
).batchRequestsWithListeners(listener);
277+
} else {
278+
batchedRequests = new EmbeddingRequestChunker(
279+
inputs.getInputs(),
280+
EMBEDDING_MAX_BATCH_SIZE,
281+
getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType())
282+
).batchRequestsWithListeners(listener);
283+
}
284+
246285
for (var request : batchedRequests) {
247286
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
248287
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
249288
}
250289
}
251290

291+
private EmbeddingRequestChunker.EmbeddingType getEmbeddingTypeFromTaskType(TaskType taskType) {
292+
return switch (taskType) {
293+
case TaskType.TEXT_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.FLOAT;
294+
case TaskType.SPARSE_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.SPARSE;
295+
default -> throw new IllegalArgumentException("Unsupported task type for chunking: " + taskType);
296+
};
297+
}
298+
252299
/**
253300
* For text embedding models get the embedding size and
254301
* update the service settings.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.ChunkingSettings;
1112
import org.elasticsearch.inference.InputType;
1213
import org.elasticsearch.inference.ModelConfigurations;
1314
import org.elasticsearch.inference.ModelSecrets;
@@ -39,6 +40,7 @@ public AlibabaCloudSearchEmbeddingsModel(
3940
String service,
4041
Map<String, Object> serviceSettings,
4142
Map<String, Object> taskSettings,
43+
ChunkingSettings chunkingSettings,
4244
@Nullable Map<String, Object> secrets,
4345
ConfigurationParseContext context
4446
) {
@@ -48,6 +50,7 @@ public AlibabaCloudSearchEmbeddingsModel(
4850
service,
4951
AlibabaCloudSearchEmbeddingsServiceSettings.fromMap(serviceSettings, context),
5052
AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(taskSettings),
53+
chunkingSettings,
5154
DefaultSecretSettings.fromMap(secrets)
5255
);
5356
}
@@ -59,10 +62,11 @@ public AlibabaCloudSearchEmbeddingsModel(
5962
String service,
6063
AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings,
6164
AlibabaCloudSearchEmbeddingsTaskSettings taskSettings,
65+
ChunkingSettings chunkingSettings,
6266
@Nullable DefaultSecretSettings secretSettings
6367
) {
6468
super(
65-
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
69+
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
6670
new ModelSecrets(secretSettings),
6771
serviceSettings.getCommonSettings()
6872
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.common.io.stream.StreamInput;
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.core.Nullable;
16+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1617
import org.elasticsearch.inference.ModelConfigurations;
1718
import org.elasticsearch.inference.ServiceSettings;
1819
import org.elasticsearch.inference.SimilarityMeasure;
@@ -81,10 +82,21 @@ public SimilarityMeasure getSimilarity() {
8182
return similarity;
8283
}
8384

84-
public Integer getDimensions() {
85+
@Override
86+
public Integer dimensions() {
8587
return dimensions;
8688
}
8789

90+
@Override
91+
public SimilarityMeasure similarity() {
92+
return similarity;
93+
}
94+
95+
@Override
96+
public DenseVectorFieldMapper.ElementType elementType() {
97+
return DenseVectorFieldMapper.ElementType.FLOAT;
98+
}
99+
88100
public Integer getMaxInputTokens() {
89101
return maxInputTokens;
90102
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.ChunkingSettings;
1112
import org.elasticsearch.inference.InputType;
1213
import org.elasticsearch.inference.ModelConfigurations;
1314
import org.elasticsearch.inference.ModelSecrets;
@@ -39,6 +40,7 @@ public AlibabaCloudSearchSparseModel(
3940
String service,
4041
Map<String, Object> serviceSettings,
4142
Map<String, Object> taskSettings,
43+
ChunkingSettings chunkingSettings,
4244
@Nullable Map<String, Object> secrets,
4345
ConfigurationParseContext context
4446
) {
@@ -48,6 +50,7 @@ public AlibabaCloudSearchSparseModel(
4850
service,
4951
AlibabaCloudSearchSparseServiceSettings.fromMap(serviceSettings, context),
5052
AlibabaCloudSearchSparseTaskSettings.fromMap(taskSettings),
53+
chunkingSettings,
5154
DefaultSecretSettings.fromMap(secrets)
5255
);
5356
}
@@ -59,10 +62,11 @@ public AlibabaCloudSearchSparseModel(
5962
String service,
6063
AlibabaCloudSearchSparseServiceSettings serviceSettings,
6164
AlibabaCloudSearchSparseTaskSettings taskSettings,
65+
ChunkingSettings chunkingSettings,
6266
@Nullable DefaultSecretSettings secretSettings
6367
) {
6468
super(
65-
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
69+
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
6670
new ModelSecrets(secretSettings),
6771
serviceSettings.getCommonSettings()
6872
);

0 commit comments

Comments
 (0)