@@ -342,8 +342,7 @@ public void testVeryLongInput_Float() {
         }
         assertThat(batches.get(2000).batch().inputs(), hasSize(2));
 
-        // Produce inference results for each request, with just the token
-        // "word" and increasing weights.
+        // Produce inference results for each request, with increasing weights.
         float weight = 0f;
         for (var batch : batches) {
             var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
@@ -403,6 +402,91 @@ public void testVeryLongInput_Float() {
         assertThat(chunk.embedding(), equalTo(new float[] { 10002 / 16384f }));
     }
 
+    public void testVeryLongInput_Byte() {
+        int batchSize = 5;
+        int chunkSize = 20;
+        int numberOfWordsInPassage = (chunkSize * 10000);
+
+        var passageBuilder = new StringBuilder();
+        for (int i = 0; i < numberOfWordsInPassage; i++) {
+            passageBuilder.append("word").append(i).append(" "); // chunk on whitespace
+        }
+
+        List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small");
+
+        var finalListener = testListener();
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0)
+            .batchRequestsWithListeners(finalListener);
+
+        // The very long passage is split into 10000 chunks for inference, so
+        // there are 10002 inference requests, resulting in 2001 batches.
+        assertThat(batches, hasSize(2001));
+        for (int i = 0; i < 2000; i++) {
+            assertThat(batches.get(i).batch().inputs(), hasSize(5));
+        }
+        assertThat(batches.get(2000).batch().inputs(), hasSize(2));
+
+        // Produce inference results for each request, with increasing weights.
+        byte weight = 0;
+        for (var batch : batches) {
+            var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
+            for (int i = 0; i < batch.batch().requests().size(); i++) {
+                weight += 1;
+                embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { weight }));
+            }
+            batch.listener().onResponse(new TextEmbeddingByteResults(embeddings));
+        }
+
+        assertNotNull(finalListener.results);
+        assertThat(finalListener.results, hasSize(3));
+
+        // The first input has the embedding with weight 1.
+        ChunkedInference inference = finalListener.results.get(0);
+        assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
+        ChunkedInferenceEmbedding embedding = (ChunkedInferenceEmbedding) inference;
+        assertThat(embedding.chunks(), hasSize(1));
+        assertThat(embedding.chunks().get(0).matchedText(), equalTo("1st small"));
+        assertThat(embedding.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
+        TextEmbeddingByteResults.Chunk chunk = (TextEmbeddingByteResults.Chunk) embedding.chunks().get(0);
+        assertThat(chunk.embedding(), equalTo(new byte[] { 1 }));
+
+        // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
+        // inference. They get the embeddings with weights 2, 3, ..., 10001 (modulo 256, as the
+        // bytes overflow). Next, they are merged into 512 larger chunks of 19 or 20 smaller
+        // chunks (and therefore 380 or 400 words) each, and the average weight is collected.
+        inference = finalListener.results.get(1);
+        assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
+        embedding = (ChunkedInferenceEmbedding) inference;
+        assertThat(embedding.chunks(), hasSize(512));
+
+        // The first merged chunk consists of 20 small chunks (so 400 words) and the weight
+        // is the average of the weights 2 ... 21, with some round-off errors.
+        assertThat(embedding.chunks().get(0).matchedText(), startsWith("word0 word1 "));
+        assertThat(embedding.chunks().get(0).matchedText(), endsWith(" word398 word399"));
+        assertThat(embedding.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
+        chunk = (TextEmbeddingByteResults.Chunk) embedding.chunks().get(0);
+        assertThat(chunk.embedding(), equalTo(new byte[] { 12 }));
+
+        // The last merged chunk consists of 19 small chunks (so 380 words) and the weight
+        // is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so
+        // the average of -1, 0, 1, ..., 17, with some round-off errors.
+        assertThat(embedding.chunks().get(511).matchedText(), startsWith(" word199620 word199621 "));
+        assertThat(embedding.chunks().get(511).matchedText(), endsWith(" word199998 word199999"));
+        assertThat(embedding.chunks().get(511), instanceOf(TextEmbeddingByteResults.Chunk.class));
+        chunk = (TextEmbeddingByteResults.Chunk) embedding.chunks().get(511);
+        assertThat(chunk.embedding(), equalTo(new byte[] { 8 }));
+
+        // The last input has the embedding with weight 10002 % 256 = 18.
+        inference = finalListener.results.get(2);
+        assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
+        embedding = (ChunkedInferenceEmbedding) inference;
+        assertThat(embedding.chunks(), hasSize(1));
+        assertThat(embedding.chunks().get(0).matchedText(), equalTo("2nd small"));
+        assertThat(embedding.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
+        chunk = (TextEmbeddingByteResults.Chunk) embedding.chunks().get(0);
+        assertThat(chunk.embedding(), equalTo(new byte[] { 18 }));
+    }
+
     public void testMergingListener_Float() {
         int batchSize = 5;
         int chunkSize = 20;
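
The batch counts asserted in the new test (2001 batches: 2000 of size 5 plus one of size 2) follow from integer arithmetic over the 10002 inference requests. A minimal standalone sketch, not part of the commit, with a hypothetical class name:

public class BatchCountSketch {
    public static void main(String[] args) {
        int requests = 1 + 10_000 + 1;                // two small inputs plus 10000 passage chunks
        int batchSize = 5;
        int fullBatches = requests / batchSize;       // 2000 full batches of 5 requests
        int remainder = requests % batchSize;         // 2 requests left over for a final batch
        System.out.println(fullBatches + (remainder > 0 ? 1 : 0)); // prints 2001
    }
}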
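Likewise, the 19-or-20 split of the merged chunks can be checked by hand: 10000 small chunks over 512 merged chunks leaves 272 merged chunks holding 20 small chunks and 240 holding 19. A sketch under the assumption that the remainder is spread one extra small chunk per merged chunk; the test itself only pins down the first and last merged chunk:

public class MergeSizeSketch {
    public static void main(String[] args) {
        int smallChunks = 10_000;
        int mergedChunks = 512;
        int minSize = smallChunks / mergedChunks;     // 19 small chunks at minimum
        int withExtra = smallChunks % mergedChunks;   // 272 merged chunks hold one extra (20)
        // Sanity check: the sizes add back up to all 10000 small chunks.
        System.out.println(withExtra * (minSize + 1) + (mergedChunks - withExtra) * minSize); // prints 10000
    }
}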
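Finally, the expected byte embeddings 12, 8 and 18 come from averaging weights that wrap on byte overflow. A sketch reproducing that arithmetic; rounding the average to the nearest integer is an assumption here, and the production merge logic may round differently:

public class ByteAverageSketch {
    public static void main(String[] args) {
        System.out.println(Math.round(average(2, 21)));        // 12: first merged chunk, weights 2..21
        System.out.println(Math.round(average(9983, 10001)));  // 8: last merged chunk, bytes wrap to -1..17
        System.out.println((byte) 10002);                      // 18: weight of the final small input
    }

    static float average(int from, int to) {
        float sum = 0;
        for (int w = from; w <= to; w++) {
            sum += (byte) w; // the cast wraps modulo 256, e.g. (byte) 9983 == -1
        }
        return sum / (to - from + 1);
    }
}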