|
13 | 13 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; |
14 | 14 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; |
15 | 15 | import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; |
| 16 | +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; |
16 | 17 | import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; |
17 | 18 | import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; |
18 | 19 | import org.elasticsearch.xpack.core.ml.search.WeightedToken; |
@@ -377,78 +378,78 @@ public void testMergingListener_Byte() { |
377 | 378 | } |
378 | 379 | } |
379 | 380 |
|
380 | | -// public void testMergingListener_Bit() { |
381 | | -// int batchSize = 5; |
382 | | -// int chunkSize = 20; |
383 | | -// int overlap = 0; |
384 | | -// // passage will be chunked into batchSize + 1 parts |
385 | | -// // and spread over 2 batch requests |
386 | | -// int numberOfWordsInPassage = (chunkSize * batchSize) + 5; |
387 | | -// |
388 | | -// var passageBuilder = new StringBuilder(); |
389 | | -// for (int i = 0; i < numberOfWordsInPassage; i++) { |
390 | | -// passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace |
391 | | -// } |
392 | | -// List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); |
393 | | -// |
394 | | -// var finalListener = testListener(); |
395 | | -// var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap, EmbeddingRequestChunker.EmbeddingType.BIT) |
396 | | -// .batchRequestsWithListeners(finalListener); |
397 | | -// assertThat(batches, hasSize(2)); |
398 | | -// |
399 | | -// // 4 inputs in 2 batches |
400 | | -// { |
401 | | -// var embeddings = new ArrayList<InferenceByteEmbedding>(); |
402 | | -// for (int i = 0; i < batchSize; i++) { |
403 | | -// embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() })); |
404 | | -// } |
405 | | -// batches.get(0).listener().onResponse(new InferenceTextEmbeddingBitResults(embeddings)); |
406 | | -// } |
407 | | -// { |
408 | | -// var embeddings = new ArrayList<InferenceByteEmbedding>(); |
409 | | -// for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch |
410 | | -// embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() })); |
411 | | -// } |
412 | | -// batches.get(1).listener().onResponse(new InferenceTextEmbeddingBitResults(embeddings)); |
413 | | -// } |
414 | | -// |
415 | | -// assertNotNull(finalListener.results); |
416 | | -// assertThat(finalListener.results, hasSize(4)); |
417 | | -// { |
418 | | -// var chunkedResult = finalListener.results.get(0); |
419 | | -// assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); |
420 | | -// var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; |
421 | | -// assertThat(chunkedByteResult.chunks(), hasSize(1)); |
422 | | -// assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText()); |
423 | | -// } |
424 | | -// { |
425 | | -// // this is the large input split in multiple chunks |
426 | | -// var chunkedResult = finalListener.results.get(1); |
427 | | -// assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); |
428 | | -// var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; |
429 | | -// assertThat(chunkedByteResult.chunks(), hasSize(6)); |
430 | | -// assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); |
431 | | -// assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 ")); |
432 | | -// assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 ")); |
433 | | -// assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 ")); |
434 | | -// assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 ")); |
435 | | -// assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 ")); |
436 | | -// } |
437 | | -// { |
438 | | -// var chunkedResult = finalListener.results.get(2); |
439 | | -// assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); |
440 | | -// var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; |
441 | | -// assertThat(chunkedByteResult.chunks(), hasSize(1)); |
442 | | -// assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText()); |
443 | | -// } |
444 | | -// { |
445 | | -// var chunkedResult = finalListener.results.get(3); |
446 | | -// assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); |
447 | | -// var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; |
448 | | -// assertThat(chunkedByteResult.chunks(), hasSize(1)); |
449 | | -// assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText()); |
450 | | -// } |
451 | | -// } |
| 381 | + public void testMergingListener_Bit() { |
| 382 | + int batchSize = 5; |
| 383 | + int chunkSize = 20; |
| 384 | + int overlap = 0; |
| 385 | + // passage will be chunked into batchSize + 1 parts |
| 386 | + // and spread over 2 batch requests |
| 387 | + int numberOfWordsInPassage = (chunkSize * batchSize) + 5; |
| 388 | + |
| 389 | + var passageBuilder = new StringBuilder(); |
| 390 | + for (int i = 0; i < numberOfWordsInPassage; i++) { |
| 391 | + passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace |
| 392 | + } |
| 393 | + List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); |
| 394 | + |
| 395 | + var finalListener = testListener(); |
| 396 | + var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap) |
| 397 | + .batchRequestsWithListeners(finalListener); |
| 398 | + assertThat(batches, hasSize(2)); |
| 399 | + |
| 400 | + // 4 inputs in 2 batches |
| 401 | + { |
| 402 | + var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>(); |
| 403 | + for (int i = 0; i < batchSize; i++) { |
| 404 | + embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); |
| 405 | + } |
| 406 | + batches.get(0).listener().onResponse(new TextEmbeddingBitResults(embeddings)); |
| 407 | + } |
| 408 | + { |
| 409 | + var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>(); |
| 410 | + for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch |
| 411 | + embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); |
| 412 | + } |
| 413 | + batches.get(1).listener().onResponse(new TextEmbeddingBitResults(embeddings)); |
| 414 | + } |
| 415 | + |
| 416 | + assertNotNull(finalListener.results); |
| 417 | + assertThat(finalListener.results, hasSize(4)); |
| 418 | + { |
| 419 | + var chunkedResult = finalListener.results.get(0); |
| 420 | + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); |
| 421 | + var chunkedBitResult = (ChunkedInferenceEmbedding) chunkedResult;
| 422 | + assertThat(chunkedBitResult.chunks(), hasSize(1));
| 423 | + assertEquals("1st small", chunkedBitResult.chunks().get(0).matchedText());
| 424 | + } |
| 425 | + { |
| 426 | + // this is the large input split in multiple chunks |
| 427 | + var chunkedResult = finalListener.results.get(1); |
| 428 | + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); |
| 429 | + var chunkedBitResult = (ChunkedInferenceEmbedding) chunkedResult;
| 430 | + assertThat(chunkedBitResult.chunks(), hasSize(6));
| 431 | + assertThat(chunkedBitResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
| 432 | + assertThat(chunkedBitResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
| 433 | + assertThat(chunkedBitResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
| 434 | + assertThat(chunkedBitResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
| 435 | + assertThat(chunkedBitResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
| 436 | + assertThat(chunkedBitResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
| 437 | + } |
| 438 | + { |
| 439 | + var chunkedResult = finalListener.results.get(2); |
| 440 | + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); |
| 441 | + var chunkedBitResult = (ChunkedInferenceEmbedding) chunkedResult;
| 442 | + assertThat(chunkedBitResult.chunks(), hasSize(1));
| 443 | + assertEquals("2nd small", chunkedBitResult.chunks().get(0).matchedText());
| 444 | + } |
| 445 | + { |
| 446 | + var chunkedResult = finalListener.results.get(3); |
| 447 | + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); |
| 448 | + var chunkedBitResult = (ChunkedInferenceEmbedding) chunkedResult;
| 449 | + assertThat(chunkedBitResult.chunks(), hasSize(1));
| 450 | + assertEquals("3rd small", chunkedBitResult.chunks().get(0).matchedText());
| 451 | + } |
| 452 | + } |
452 | 453 |
|
453 | 454 | public void testMergingListener_Sparse() { |
454 | 455 | int batchSize = 4; |
|
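For reference, here is a minimal sketch of the chunk-and-batch arithmetic the re-enabled test asserts. The helper class below is illustrative only, not part of EmbeddingRequestChunker's API: a 105-word passage at 20 words per chunk with no overlap yields 6 chunks (batchSize + 1); together with the 3 single-chunk inputs that makes 9 embedding requests, which a batch size of 5 splits into batches of 5 and 4, matching the hasSize(2), hasSize(6), and 4-embedding second batch asserted above.

    // Illustrative arithmetic only; ChunkBatchMath is a hypothetical helper,
    // not a class from the Elasticsearch code base.
    final class ChunkBatchMath {
        // Ceiling division: how many chunks/batches are needed to cover `count` items.
        static int ceilDiv(int count, int size) {
            return (count + size - 1) / size;
        }

        public static void main(String[] args) {
            int batchSize = 5;
            int chunkSize = 20;
            int words = (chunkSize * batchSize) + 5;             // 105 words in the long passage
            int passageChunks = ceilDiv(words, chunkSize);       // 6 chunks == batchSize + 1
            int totalRequests = 3 + passageChunks;               // 3 small inputs + 6 chunks = 9
            int batches = ceilDiv(totalRequests, batchSize);     // ceil(9 / 5) = 2 batches
            int secondBatch = totalRequests - batchSize;         // 9 - 5 = 4 requests
            System.out.printf("chunks=%d requests=%d batches=%d secondBatch=%d%n",
                    passageChunks, totalRequests, batches, secondBatch);
        }
    }

Running this prints chunks=6 requests=9 batches=2 secondBatch=4, which is exactly the shape the assertions check.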
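Note also that TextEmbeddingBitResults is constructed from TextEmbeddingByteResults.Embedding values, which suggests bit embeddings are stored packed into bytes. A hypothetical sketch of such packing, assuming eight bits per byte with the most significant bit first (this is an assumption for illustration, not code from TextEmbeddingBitResults):

    // Assumption: each byte carries 8 bits of the bit vector, MSB first.
    // packBits is a hypothetical helper, not an Elasticsearch method.
    static byte[] packBits(boolean[] bits) {
        byte[] packed = new byte[(bits.length + 7) / 8]; // ceil(bits / 8) bytes
        for (int i = 0; i < bits.length; i++) {
            if (bits[i]) {
                packed[i / 8] |= (byte) (1 << (7 - (i % 8)));
            }
        }
        return packed;
    }

Under that assumption, a 10-bit vector packs into two bytes, with bit 0 landing in the high bit of packed[0], which would explain why the test can drive the bit result type with single-byte embeddings.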