2020import org .elasticsearch .xpack .core .inference .results .ChunkedInferenceEmbeddingSparse ;
2121import org .elasticsearch .xpack .core .inference .results .ChunkedInferenceError ;
2222import org .elasticsearch .xpack .core .inference .results .InferenceByteEmbedding ;
23+ import org .elasticsearch .xpack .core .inference .results .InferenceTextEmbeddingBitResults ;
2324import org .elasticsearch .xpack .core .inference .results .InferenceTextEmbeddingByteResults ;
2425import org .elasticsearch .xpack .core .inference .results .InferenceTextEmbeddingFloatResults ;
2526import org .elasticsearch .xpack .core .inference .results .SparseEmbeddingResults ;
@@ -46,13 +47,14 @@ public class EmbeddingRequestChunker {
4647 public enum EmbeddingType {
4748 FLOAT ,
4849 BYTE ,
50+ BIT ,
4951 SPARSE ;
5052
5153 public static EmbeddingType fromDenseVectorElementType (DenseVectorFieldMapper .ElementType elementType ) {
5254 return switch (elementType ) {
5355 case BYTE -> EmbeddingType .BYTE ;
5456 case FLOAT -> EmbeddingType .FLOAT ;
55- case BIT -> throw new IllegalArgumentException ( "Bit vectors are not supported" ) ;
57+ case BIT -> EmbeddingType . BIT ;
5658 };
5759 }
5860 };
@@ -71,6 +73,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
7173 private List <ChunkOffsetsAndInput > chunkedOffsets ;
7274 private List <AtomicArray <List <InferenceTextEmbeddingFloatResults .InferenceFloatEmbedding >>> floatResults ;
7375 private List <AtomicArray <List <InferenceByteEmbedding >>> byteResults ;
76+ private List <AtomicArray <List <InferenceByteEmbedding >>> bitResults ;
7477 private List <AtomicArray <List <SparseEmbeddingResults .Embedding >>> sparseResults ;
7578 private AtomicArray <Exception > errors ;
7679 private ActionListener <List <ChunkedInference >> finalListener ;
@@ -122,6 +125,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
122125 switch (embeddingType ) {
123126 case FLOAT -> floatResults = new ArrayList <>(inputs .size ());
124127 case BYTE -> byteResults = new ArrayList <>(inputs .size ());
128+ case BIT -> bitResults = new ArrayList <>(inputs .size ());
125129 case SPARSE -> sparseResults = new ArrayList <>(inputs .size ());
126130 }
127131 errors = new AtomicArray <>(inputs .size ());
@@ -134,6 +138,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
134138 switch (embeddingType ) {
135139 case FLOAT -> floatResults .add (new AtomicArray <>(numberOfSubBatches ));
136140 case BYTE -> byteResults .add (new AtomicArray <>(numberOfSubBatches ));
141+ case BIT -> bitResults .add (new AtomicArray <>(numberOfSubBatches ));
137142 case SPARSE -> sparseResults .add (new AtomicArray <>(numberOfSubBatches ));
138143 }
139144 chunkedOffsets .add (offSetsAndInput );
@@ -233,6 +238,7 @@ public void onResponse(InferenceServiceResults inferenceServiceResults) {
233238 switch (embeddingType ) {
234239 case FLOAT -> handleFloatResults (inferenceServiceResults );
235240 case BYTE -> handleByteResults (inferenceServiceResults );
241+ case BIT -> handleBitResults (inferenceServiceResults );
236242 case SPARSE -> handleSparseResults (inferenceServiceResults );
237243 }
238244 }
@@ -283,6 +289,29 @@ private void handleByteResults(InferenceServiceResults inferenceServiceResults)
283289 }
284290 }
285291
292+ private void handleBitResults (InferenceServiceResults inferenceServiceResults ) {
293+ if (inferenceServiceResults instanceof InferenceTextEmbeddingBitResults bitEmbeddings ) {
294+ if (failIfNumRequestsDoNotMatch (bitEmbeddings .embeddings ().size ())) {
295+ return ;
296+ }
297+
298+ int start = 0 ;
299+ for (var pos : positions ) {
300+ bitResults .get (pos .inputIndex ())
301+ .setOnce (pos .chunkIndex (), bitEmbeddings .embeddings ().subList (start , start + pos .embeddingCount ()));
302+ start += pos .embeddingCount ();
303+ }
304+
305+ if (resultCount .incrementAndGet () == totalNumberOfRequests ) {
306+ sendResponse ();
307+ }
308+ } else {
309+ onFailure (
310+ unexpectedResultTypeException (inferenceServiceResults .getWriteableName (), InferenceTextEmbeddingBitResults .NAME )
311+ );
312+ }
313+ }
314+
286315 private void handleSparseResults (InferenceServiceResults inferenceServiceResults ) {
287316 if (inferenceServiceResults instanceof SparseEmbeddingResults sparseEmbeddings ) {
288317 if (failIfNumRequestsDoNotMatch (sparseEmbeddings .embeddings ().size ())) {
@@ -358,6 +387,7 @@ private ChunkedInference mergeResultsWithInputs(int resultIndex) {
358387 return switch (embeddingType ) {
359388 case FLOAT -> mergeFloatResultsWithInputs (chunkedOffsets .get (resultIndex ), floatResults .get (resultIndex ));
360389 case BYTE -> mergeByteResultsWithInputs (chunkedOffsets .get (resultIndex ), byteResults .get (resultIndex ));
390+ case BIT -> mergeBitResultsWithInputs (chunkedOffsets .get (resultIndex ), bitResults .get (resultIndex ));
361391 case SPARSE -> mergeSparseResultsWithInputs (chunkedOffsets .get (resultIndex ), sparseResults .get (resultIndex ));
362392 };
363393 }
@@ -414,6 +444,32 @@ private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs(
414444 return new ChunkedInferenceEmbeddingByte (embeddingChunks );
415445 }
416446
447+ private ChunkedInferenceEmbeddingByte mergeBitResultsWithInputs (
448+ ChunkOffsetsAndInput chunks ,
449+ AtomicArray <List <InferenceByteEmbedding >> debatchedResults
450+ ) {
451+ var all = new ArrayList <InferenceByteEmbedding >();
452+ for (int i = 0 ; i < debatchedResults .length (); i ++) {
453+ var subBatch = debatchedResults .get (i );
454+ all .addAll (subBatch );
455+ }
456+
457+ assert chunks .size () == all .size ();
458+
459+ var embeddingChunks = new ArrayList <ChunkedInferenceEmbeddingByte .ByteEmbeddingChunk >();
460+ for (int i = 0 ; i < chunks .size (); i ++) {
461+ embeddingChunks .add (
462+ new ChunkedInferenceEmbeddingByte .ByteEmbeddingChunk (
463+ all .get (i ).values (),
464+ chunks .chunkText (i ),
465+ new ChunkedInference .TextOffset (chunks .offsets ().get (i ).start (), chunks .offsets ().get (i ).end ())
466+ )
467+ );
468+ }
469+
470+ return new ChunkedInferenceEmbeddingByte (embeddingChunks );
471+ }
472+
417473 private ChunkedInferenceEmbeddingSparse mergeSparseResultsWithInputs (
418474 ChunkOffsetsAndInput chunks ,
419475 AtomicArray <List <SparseEmbeddingResults .Embedding >> debatchedResults
0 commit comments