|
22 | 22 | import org.elasticsearch.xpack.core.inference.action.InferenceAction; |
23 | 23 | import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; |
24 | 24 | import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; |
| 25 | +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; |
25 | 26 | import org.elasticsearch.xpack.inference.external.http.HttpClientManager; |
26 | 27 | import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; |
27 | 28 | import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; |
28 | 29 | import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; |
29 | 30 | import org.elasticsearch.xpack.inference.logging.ThrottlerManager; |
30 | 31 | import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests; |
| 32 | +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; |
31 | 33 | import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; |
32 | 34 | import org.elasticsearch.xpack.inference.telemetry.TraceContext; |
33 | 35 | import org.junit.After; |
@@ -256,6 +258,213 @@ public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOExc |
256 | 258 | } |
257 | 259 | } |
258 | 260 |
|
| 261 | + @SuppressWarnings("unchecked") |
| 262 | + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() throws IOException { |
| 263 | + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); |
| 264 | + |
| 265 | + try (var sender = createSender(senderFactory)) { |
| 266 | + sender.start(); |
| 267 | + |
| 268 | + String responseJson = """ |
| 269 | + { |
| 270 | + "data": [ |
| 271 | + [ |
| 272 | + 2.1259406, |
| 273 | + 1.7073475, |
| 274 | + 0.9020516 |
| 275 | + ], |
| 276 | + [ |
| 277 | + 1.8342123, |
| 278 | + 2.3456789, |
| 279 | + 0.7654321 |
| 280 | + ] |
| 281 | + ] |
| 282 | + } |
| 283 | + """; |
| 284 | + |
| 285 | + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); |
| 286 | + |
| 287 | + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); |
| 288 | + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); |
| 289 | + var action = actionCreator.create(model); |
| 290 | + |
| 291 | + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); |
| 292 | + action.execute( |
| 293 | + new EmbeddingsInput(List.of("hello world", "second text"), null, InputType.UNSPECIFIED), |
| 294 | + InferenceAction.Request.DEFAULT_TIMEOUT, |
| 295 | + listener |
| 296 | + ); |
| 297 | + |
| 298 | + var result = listener.actionGet(TIMEOUT); |
| 299 | + |
| 300 | + assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); |
| 301 | + var textEmbeddingResults = (TextEmbeddingFloatResults) result; |
| 302 | + assertThat(textEmbeddingResults.embeddings(), hasSize(2)); |
| 303 | + |
| 304 | + var firstEmbedding = textEmbeddingResults.embeddings().get(0); |
| 305 | + assertThat(firstEmbedding.values(), is(new float[]{2.1259406f, 1.7073475f, 0.9020516f})); |
| 306 | + |
| 307 | + var secondEmbedding = textEmbeddingResults.embeddings().get(1); |
| 308 | + assertThat(secondEmbedding.values(), is(new float[]{1.8342123f, 2.3456789f, 0.7654321f})); |
| 309 | + |
| 310 | + assertThat(webServer.requests(), hasSize(1)); |
| 311 | + assertNull(webServer.requests().get(0).getUri().getQuery()); |
| 312 | + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); |
| 313 | + |
| 314 | + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); |
| 315 | + assertThat(requestMap.size(), is(2)); |
| 316 | + assertThat(requestMap.get("input"), instanceOf(List.class)); |
| 317 | + var inputList = (List<String>) requestMap.get("input"); |
| 318 | + assertThat(inputList, contains("hello world", "second text")); |
| 319 | + assertThat(requestMap.get("model"), is("my-dense-model-id")); |
| 320 | + } |
| 321 | + } |
| 322 | + |
| 323 | + @SuppressWarnings("unchecked") |
| 324 | + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_WithUsageContext() throws IOException { |
| 325 | + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); |
| 326 | + |
| 327 | + try (var sender = createSender(senderFactory)) { |
| 328 | + sender.start(); |
| 329 | + |
| 330 | + String responseJson = """ |
| 331 | + { |
| 332 | + "data": [ |
| 333 | + [ |
| 334 | + 0.1234567, |
| 335 | + 0.9876543 |
| 336 | + ] |
| 337 | + ] |
| 338 | + } |
| 339 | + """; |
| 340 | + |
| 341 | + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); |
| 342 | + |
| 343 | + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); |
| 344 | + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); |
| 345 | + var action = actionCreator.create(model); |
| 346 | + |
| 347 | + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); |
| 348 | + action.execute( |
| 349 | + new EmbeddingsInput(List.of("search query"), null, InputType.SEARCH), |
| 350 | + InferenceAction.Request.DEFAULT_TIMEOUT, |
| 351 | + listener |
| 352 | + ); |
| 353 | + |
| 354 | + var result = listener.actionGet(TIMEOUT); |
| 355 | + |
| 356 | + assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); |
| 357 | + var textEmbeddingResults = (TextEmbeddingFloatResults) result; |
| 358 | + assertThat(textEmbeddingResults.embeddings(), hasSize(1)); |
| 359 | + |
| 360 | + var embedding = textEmbeddingResults.embeddings().get(0); |
| 361 | + assertThat(embedding.values(), is(new float[]{0.1234567f, 0.9876543f})); |
| 362 | + |
| 363 | + assertThat(webServer.requests(), hasSize(1)); |
| 364 | + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); |
| 365 | + assertThat(requestMap.size(), is(3)); |
| 366 | + assertThat(requestMap.get("input"), instanceOf(List.class)); |
| 367 | + var inputList = (List<String>) requestMap.get("input"); |
| 368 | + assertThat(inputList, contains("search query")); |
| 369 | + assertThat(requestMap.get("model"), is("my-dense-model-id")); |
| 370 | + assertThat(requestMap.get("usage_context"), is("search")); |
| 371 | + } |
| 372 | + } |
| 373 | + |
| 374 | + @SuppressWarnings("unchecked") |
| 375 | + public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction() throws IOException { |
| 376 | + // timeout as zero for no retries |
| 377 | + var settings = buildSettingsWithRetryFields( |
| 378 | + TimeValue.timeValueMillis(1), |
| 379 | + TimeValue.timeValueMinutes(1), |
| 380 | + TimeValue.timeValueSeconds(0) |
| 381 | + ); |
| 382 | + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); |
| 383 | + |
| 384 | + try (var sender = createSender(senderFactory)) { |
| 385 | + sender.start(); |
| 386 | + |
| 387 | + // This will fail because the expected output is {"data": [[...]]} |
| 388 | + String responseJson = """ |
| 389 | + { |
| 390 | + "data": { |
| 391 | + "embedding": [2.1259406, 1.7073475] |
| 392 | + } |
| 393 | + } |
| 394 | + """; |
| 395 | + |
| 396 | + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); |
| 397 | + |
| 398 | + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); |
| 399 | + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); |
| 400 | + var action = actionCreator.create(model); |
| 401 | + |
| 402 | + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); |
| 403 | + action.execute( |
| 404 | + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), |
| 405 | + InferenceAction.Request.DEFAULT_TIMEOUT, |
| 406 | + listener |
| 407 | + ); |
| 408 | + |
| 409 | + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); |
| 410 | + assertThat( |
| 411 | + thrownException.getMessage(), |
| 412 | + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") |
| 413 | + ); |
| 414 | + |
| 415 | + assertThat(webServer.requests(), hasSize(1)); |
| 416 | + assertNull(webServer.requests().get(0).getUri().getQuery()); |
| 417 | + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); |
| 418 | + |
| 419 | + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); |
| 420 | + assertThat(requestMap.size(), is(2)); |
| 421 | + assertThat(requestMap.get("input"), instanceOf(List.class)); |
| 422 | + var inputList = (List<String>) requestMap.get("input"); |
| 423 | + assertThat(inputList, contains("hello world")); |
| 424 | + assertThat(requestMap.get("model"), is("my-dense-model-id")); |
| 425 | + } |
| 426 | + } |
| 427 | + |
| 428 | + @SuppressWarnings("unchecked") |
| 429 | + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_EmptyEmbeddings() throws IOException { |
| 430 | + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); |
| 431 | + |
| 432 | + try (var sender = createSender(senderFactory)) { |
| 433 | + sender.start(); |
| 434 | + |
| 435 | + String responseJson = """ |
| 436 | + { |
| 437 | + "data": [] |
| 438 | + } |
| 439 | + """; |
| 440 | + |
| 441 | + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); |
| 442 | + |
| 443 | + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); |
| 444 | + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); |
| 445 | + var action = actionCreator.create(model); |
| 446 | + |
| 447 | + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); |
| 448 | + action.execute( |
| 449 | + new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), |
| 450 | + InferenceAction.Request.DEFAULT_TIMEOUT, |
| 451 | + listener |
| 452 | + ); |
| 453 | + |
| 454 | + var result = listener.actionGet(TIMEOUT); |
| 455 | + |
| 456 | + assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); |
| 457 | + var textEmbeddingResults = (TextEmbeddingFloatResults) result; |
| 458 | + assertThat(textEmbeddingResults.embeddings(), hasSize(0)); |
| 459 | + |
| 460 | + assertThat(webServer.requests(), hasSize(1)); |
| 461 | + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); |
| 462 | + assertThat(requestMap.get("input"), instanceOf(List.class)); |
| 463 | + var inputList = (List<String>) requestMap.get("input"); |
| 464 | + assertThat(inputList, hasSize(0)); |
| 465 | + } |
| 466 | + } |
| 467 | + |
259 | 468 | public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException { |
260 | 469 | var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); |
261 | 470 |
|
|
0 commit comments