|
15 | 15 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; |
16 | 16 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; |
17 | 17 | import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; |
| 18 | +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults; |
18 | 19 | import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; |
19 | 20 | import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; |
20 | 21 | import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; |
@@ -421,6 +422,79 @@ public void testMergingListener_Byte() { |
421 | 422 | } |
422 | 423 | } |
423 | 424 |
|
/**
 * Checks that the per-batch responses of a {@code BIT} embedding request are merged back into one
 * chunked result per original input, in input order.
 * <p>
 * Three short inputs and one long passage are submitted. The test expects the requests to be split
 * over two batches, feeds a bit-embedding response to each batch listener, and then asserts that
 * every merged result is a {@link ChunkedInferenceEmbeddingByte} whose chunks carry the expected
 * matched text.
 */
public void testMergingListener_Bit() {
    final int batchSize = 5;
    final int chunkSize = 20;
    final int overlap = 0;
    // The passage is expected to be chunked into batchSize + 1 parts,
    // which spreads the requests over 2 batches.
    final int passageWordCount = (chunkSize * batchSize) + 5;

    var passage = new StringBuilder();
    for (int wordIndex = 0; wordIndex < passageWordCount; wordIndex++) {
        // words are separated by a single space so the chunker can split on whitespace
        passage.append("passage_input");
        passage.append(wordIndex);
        passage.append(" ");
    }
    List<String> inputs = List.of("1st small", passage.toString(), "2nd small", "3rd small");

    var finalListener = testListener();
    var chunker = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap, EmbeddingRequestChunker.EmbeddingType.BIT);
    var batches = chunker.batchRequestsWithListeners(finalListener);
    assertThat(batches, hasSize(2));

    // 4 inputs in 2 batches: the first batch is full, the second one holds the remaining 4 requests
    int[] requestsPerBatch = { batchSize, 4 };
    for (int batchIndex = 0; batchIndex < requestsPerBatch.length; batchIndex++) {
        var batchEmbeddings = new ArrayList<InferenceByteEmbedding>();
        for (int request = 0; request < requestsPerBatch[batchIndex]; request++) {
            batchEmbeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
        }
        batches.get(batchIndex).listener().onResponse(new InferenceTextEmbeddingBitResults(batchEmbeddings));
    }

    assertNotNull(finalListener.results);
    assertThat(finalListener.results, hasSize(4));

    // Position 1 is the large passage; every other position is a small input with a single chunk.
    final int passagePosition = 1;
    for (int position = 0; position < inputs.size(); position++) {
        var mergedResult = finalListener.results.get(position);
        assertThat(mergedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
        var chunks = ((ChunkedInferenceEmbeddingByte) mergedResult).chunks();

        if (position != passagePosition) {
            assertThat(chunks, hasSize(1));
            assertEquals(inputs.get(position), chunks.get(0).matchedText());
            continue;
        }

        // this is the large input split in multiple chunks
        assertThat(chunks, hasSize(6));
        // only the first chunk has no leading whitespace
        assertThat(chunks.get(0).matchedText(), startsWith("passage_input0 "));
        for (int chunkIndex = 1; chunkIndex < 6; chunkIndex++) {
            String expectedPrefix = " passage_input" + (chunkIndex * chunkSize) + " ";
            assertThat(chunks.get(chunkIndex).matchedText(), startsWith(expectedPrefix));
        }
    }
}
| 497 | + |
424 | 498 | public void testMergingListener_Sparse() { |
425 | 499 | int batchSize = 4; |
426 | 500 | int chunkSize = 10; |
|
0 commit comments