|
11 | 11 | import org.elasticsearch.action.support.PlainActionFuture;
|
12 | 12 | import org.elasticsearch.common.settings.Settings;
|
13 | 13 | import org.elasticsearch.core.TimeValue;
|
| 14 | +import org.elasticsearch.inference.ChunkedInferenceServiceResults; |
| 15 | +import org.elasticsearch.inference.ChunkingOptions; |
14 | 16 | import org.elasticsearch.inference.InferenceServiceResults;
|
15 | 17 | import org.elasticsearch.inference.InputType;
|
16 | 18 | import org.elasticsearch.inference.Model;
|
17 | 19 | import org.elasticsearch.inference.ModelConfigurations;
|
18 | 20 | import org.elasticsearch.inference.TaskType;
|
19 | 21 | import org.elasticsearch.test.ESTestCase;
|
20 | 22 | import org.elasticsearch.threadpool.ThreadPool;
|
| 23 | +import org.elasticsearch.xpack.core.inference.action.InferenceAction; |
| 24 | +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; |
21 | 25 | import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
| 26 | +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; |
| 27 | +import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor; |
22 | 28 | import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
23 | 29 | import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
24 | 30 | import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
| 31 | +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; |
25 | 32 | import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
26 | 33 | import org.elasticsearch.xpack.inference.services.ServiceFields;
|
27 | 34 | import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
|
28 | 35 | import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests;
|
29 | 36 | import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests;
|
30 | 37 | import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests;
|
| 38 | +import org.hamcrest.CoreMatchers; |
31 | 39 | import org.hamcrest.MatcherAssert;
|
32 | 40 | import org.junit.After;
|
33 | 41 | import org.junit.Before;
|
34 | 42 |
|
35 | 43 | import java.io.IOException;
|
| 44 | +import java.util.Arrays; |
36 | 45 | import java.util.HashMap;
|
37 | 46 | import java.util.List;
|
38 | 47 | import java.util.Map;
|
|
44 | 53 | import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
45 | 54 | import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
|
46 | 55 | import static org.hamcrest.CoreMatchers.is;
|
| 56 | +import static org.hamcrest.Matchers.hasSize; |
47 | 57 | import static org.hamcrest.Matchers.instanceOf;
|
48 | 58 | import static org.mockito.Mockito.mock;
|
49 | 59 |
|
@@ -156,6 +166,84 @@ public void doInfer(
|
156 | 166 | }
|
157 | 167 | }
|
158 | 168 |
|
| 169 | + public void testChunkedInfer_Batches() throws IOException { |
| 170 | + var input = List.of("foo", "bar"); |
| 171 | + |
| 172 | + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); |
| 173 | + |
| 174 | + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { |
| 175 | + Map<String, Object> serviceSettingsMap = new HashMap<>(); |
| 176 | + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); |
| 177 | + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); |
| 178 | + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); |
| 179 | + serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); |
| 180 | + |
| 181 | + Map<String, Object> taskSettingsMap = new HashMap<>(); |
| 182 | + |
| 183 | + Map<String, Object> secretSettingsMap = new HashMap<>(); |
| 184 | + secretSettingsMap.put("api_key", "secret"); |
| 185 | + |
| 186 | + var model = new AlibabaCloudSearchEmbeddingsModel( |
| 187 | + "service", |
| 188 | + TaskType.TEXT_EMBEDDING, |
| 189 | + AlibabaCloudSearchUtils.SERVICE_NAME, |
| 190 | + serviceSettingsMap, |
| 191 | + taskSettingsMap, |
| 192 | + secretSettingsMap, |
| 193 | + null |
| 194 | + ) { |
| 195 | + public ExecutableAction accept( |
| 196 | + AlibabaCloudSearchActionVisitor visitor, |
| 197 | + Map<String, Object> taskSettings, |
| 198 | + InputType inputType |
| 199 | + ) { |
| 200 | + return (inferenceInputs, timeout, listener) -> { |
| 201 | + InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults( |
| 202 | + List.of( |
| 203 | + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123f, -0.0123f }), |
| 204 | + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0456f, -0.0456f }) |
| 205 | + ) |
| 206 | + ); |
| 207 | + |
| 208 | + listener.onResponse(results); |
| 209 | + }; |
| 210 | + } |
| 211 | + }; |
| 212 | + |
| 213 | + PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>(); |
| 214 | + service.chunkedInfer( |
| 215 | + model, |
| 216 | + input, |
| 217 | + new HashMap<>(), |
| 218 | + InputType.INGEST, |
| 219 | + new ChunkingOptions(null, null), |
| 220 | + InferenceAction.Request.DEFAULT_TIMEOUT, |
| 221 | + listener |
| 222 | + ); |
| 223 | + |
| 224 | + var results = listener.actionGet(TIMEOUT); |
| 225 | + assertThat(results, hasSize(2)); |
| 226 | + |
| 227 | + // first result |
| 228 | + { |
| 229 | + assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); |
| 230 | + var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); |
| 231 | + assertThat(floatResult.chunks(), hasSize(1)); |
| 232 | + assertEquals(input.get(0), floatResult.chunks().get(0).matchedText()); |
| 233 | + assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding())); |
| 234 | + } |
| 235 | + |
| 236 | + // second result |
| 237 | + { |
| 238 | + assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); |
| 239 | + var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); |
| 240 | + assertThat(floatResult.chunks(), hasSize(1)); |
| 241 | + assertEquals(input.get(1), floatResult.chunks().get(0).matchedText()); |
| 242 | + assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding())); |
| 243 | + } |
| 244 | + } |
| 245 | + } |
| 246 | + |
159 | 247 | private Map<String, Object> getRequestConfigMap(
|
160 | 248 | Map<String, Object> serviceSettings,
|
161 | 249 | Map<String, Object> taskSettings,
|
|
0 commit comments