Skip to content

Commit 9b276e3

Browse files
authored
[Inference API] alibabacloud ai search service support chunk infer to support semantic_text field (#112652)
1 parent 5daa82a commit 9b276e3

File tree

4 files changed

+124
-1
lines changed

4 files changed

+124
-1
lines changed

docs/changelog/112652.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 110399
2+
summary: "[Inference API] alibabacloud ai search service support chunk infer to support semantic_text field"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.inference.SimilarityMeasure;
2525
import org.elasticsearch.inference.TaskType;
2626
import org.elasticsearch.rest.RestStatus;
27+
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
2728
import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator;
2829
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
2930
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -49,6 +50,7 @@
4950
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
5051
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
5152
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
53+
import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceFields.EMBEDDING_MAX_BATCH_SIZE;
5254

5355
public class AlibabaCloudSearchService extends SenderService {
5456
public static final String NAME = AlibabaCloudSearchUtils.SERVICE_NAME;
@@ -253,7 +255,20 @@ protected void doChunkedInfer(
253255
TimeValue timeout,
254256
ActionListener<List<ChunkedInferenceServiceResults>> listener
255257
) {
256-
listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
258+
if (model instanceof AlibabaCloudSearchModel == false) {
259+
listener.onFailure(createInvalidModelException(model));
260+
return;
261+
}
262+
263+
AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
264+
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
265+
266+
var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
267+
.batchRequestsWithListeners(listener);
268+
for (var request : batchedRequests) {
269+
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
270+
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
271+
}
257272
}
258273

259274
/**
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
9+
10+
public class AlibabaCloudSearchServiceFields {
11+
/**
12+
* Taken from https://help.aliyun.com/zh/open-search/search-platform/developer-reference/text-embedding-api-details
13+
*/
14+
static final int EMBEDDING_MAX_BATCH_SIZE = 32;
15+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,37 @@
1111
import org.elasticsearch.action.support.PlainActionFuture;
1212
import org.elasticsearch.common.settings.Settings;
1313
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
15+
import org.elasticsearch.inference.ChunkingOptions;
1416
import org.elasticsearch.inference.InferenceServiceResults;
1517
import org.elasticsearch.inference.InputType;
1618
import org.elasticsearch.inference.Model;
1719
import org.elasticsearch.inference.ModelConfigurations;
1820
import org.elasticsearch.inference.TaskType;
1921
import org.elasticsearch.test.ESTestCase;
2022
import org.elasticsearch.threadpool.ThreadPool;
23+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
24+
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
2125
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
26+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
27+
import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor;
2228
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
2329
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
2430
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
31+
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
2532
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
2633
import org.elasticsearch.xpack.inference.services.ServiceFields;
2734
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
2835
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests;
2936
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests;
3037
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests;
38+
import org.hamcrest.CoreMatchers;
3139
import org.hamcrest.MatcherAssert;
3240
import org.junit.After;
3341
import org.junit.Before;
3442

3543
import java.io.IOException;
44+
import java.util.Arrays;
3645
import java.util.HashMap;
3746
import java.util.List;
3847
import java.util.Map;
@@ -44,6 +53,7 @@
4453
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
4554
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
4655
import static org.hamcrest.CoreMatchers.is;
56+
import static org.hamcrest.Matchers.hasSize;
4757
import static org.hamcrest.Matchers.instanceOf;
4858
import static org.mockito.Mockito.mock;
4959

@@ -156,6 +166,84 @@ public void doInfer(
156166
}
157167
}
158168

169+
public void testChunkedInfer_Batches() throws IOException {
170+
var input = List.of("foo", "bar");
171+
172+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
173+
174+
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
175+
Map<String, Object> serviceSettingsMap = new HashMap<>();
176+
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id");
177+
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host");
178+
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default");
179+
serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536);
180+
181+
Map<String, Object> taskSettingsMap = new HashMap<>();
182+
183+
Map<String, Object> secretSettingsMap = new HashMap<>();
184+
secretSettingsMap.put("api_key", "secret");
185+
186+
var model = new AlibabaCloudSearchEmbeddingsModel(
187+
"service",
188+
TaskType.TEXT_EMBEDDING,
189+
AlibabaCloudSearchUtils.SERVICE_NAME,
190+
serviceSettingsMap,
191+
taskSettingsMap,
192+
secretSettingsMap,
193+
null
194+
) {
195+
public ExecutableAction accept(
196+
AlibabaCloudSearchActionVisitor visitor,
197+
Map<String, Object> taskSettings,
198+
InputType inputType
199+
) {
200+
return (inferenceInputs, timeout, listener) -> {
201+
InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults(
202+
List.of(
203+
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123f, -0.0123f }),
204+
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0456f, -0.0456f })
205+
)
206+
);
207+
208+
listener.onResponse(results);
209+
};
210+
}
211+
};
212+
213+
PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
214+
service.chunkedInfer(
215+
model,
216+
input,
217+
new HashMap<>(),
218+
InputType.INGEST,
219+
new ChunkingOptions(null, null),
220+
InferenceAction.Request.DEFAULT_TIMEOUT,
221+
listener
222+
);
223+
224+
var results = listener.actionGet(TIMEOUT);
225+
assertThat(results, hasSize(2));
226+
227+
// first result
228+
{
229+
assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
230+
var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
231+
assertThat(floatResult.chunks(), hasSize(1));
232+
assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
233+
assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding()));
234+
}
235+
236+
// second result
237+
{
238+
assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
239+
var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
240+
assertThat(floatResult.chunks(), hasSize(1));
241+
assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
242+
assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding()));
243+
}
244+
}
245+
}
246+
159247
private Map<String, Object> getRequestConfigMap(
160248
Map<String, Object> serviceSettings,
161249
Map<String, Object> taskSettings,

0 commit comments

Comments
 (0)