
Commit b06e298 ("wip")
1 parent: 5ab175e

4 files changed (+70, -5 lines)


x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingBitResults.java
(2 additions, 1 deletion)

@@ -52,7 +52,8 @@ public InferenceTextEmbeddingBitResults(StreamInput in) throws IOException {

     @Override
     public int getFirstEmbeddingSize() {
-        return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
+        // bit embeddings are encoded as bytes so convert this to bits
+        return Byte.SIZE * TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
     }

     @Override
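
For context on the Byte.SIZE change above: a bit embedding is packed eight dimensions per byte, so the logical dimension count is the packed byte length multiplied by Byte.SIZE. A standalone sketch of that arithmetic (the class and method below are illustrative, not types from the repository):

// Illustrative only: bit embeddings travel packed as bytes, 8 dimensions per byte.
final class BitDimensionSketch {
    // e.g. a 96-dimension bit embedding is transported as 12 bytes
    static int dimensionsInBits(byte[] packedEmbedding) {
        return Byte.SIZE * packedEmbedding.length; // 8 * 12 = 96
    }

    public static void main(String[] args) {
        System.out.println(dimensionsInBits(new byte[12])); // prints 96
    }
}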

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java
(57 additions, 1 deletion)

@@ -20,6 +20,7 @@
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
+import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;

@@ -46,13 +47,14 @@ public class EmbeddingRequestChunker
     public enum EmbeddingType {
         FLOAT,
         BYTE,
+        BIT,
         SPARSE;

        public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.ElementType elementType) {
            return switch (elementType) {
                case BYTE -> EmbeddingType.BYTE;
                case FLOAT -> EmbeddingType.FLOAT;
-               case BIT -> throw new IllegalArgumentException("Bit vectors are not supported");
+               case BIT -> EmbeddingType.BIT;
            };
        }
    };

@@ -71,6 +73,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
     private List<ChunkOffsetsAndInput> chunkedOffsets;
     private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
     private List<AtomicArray<List<InferenceByteEmbedding>>> byteResults;
+    private List<AtomicArray<List<InferenceByteEmbedding>>> bitResults;
     private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
     private AtomicArray<Exception> errors;
     private ActionListener<List<ChunkedInference>> finalListener;

@@ -122,6 +125,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
        switch (embeddingType) {
            case FLOAT -> floatResults = new ArrayList<>(inputs.size());
            case BYTE -> byteResults = new ArrayList<>(inputs.size());
+           case BIT -> bitResults = new ArrayList<>(inputs.size());
            case SPARSE -> sparseResults = new ArrayList<>(inputs.size());
        }
        errors = new AtomicArray<>(inputs.size());

@@ -134,6 +138,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
            switch (embeddingType) {
                case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
                case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
+               case BIT -> bitResults.add(new AtomicArray<>(numberOfSubBatches));
                case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches));
            }
            chunkedOffsets.add(offSetsAndInput);

@@ -233,6 +238,7 @@ public void onResponse(InferenceServiceResults inferenceServiceResults) {
            switch (embeddingType) {
                case FLOAT -> handleFloatResults(inferenceServiceResults);
                case BYTE -> handleByteResults(inferenceServiceResults);
+               case BIT -> handleBitResults(inferenceServiceResults);
                case SPARSE -> handleSparseResults(inferenceServiceResults);
            }
        }

@@ -283,6 +289,29 @@ private void handleByteResults(InferenceServiceResults inferenceServiceResults)
            }
        }

+       private void handleBitResults(InferenceServiceResults inferenceServiceResults) {
+           if (inferenceServiceResults instanceof InferenceTextEmbeddingBitResults bitEmbeddings) {
+               if (failIfNumRequestsDoNotMatch(bitEmbeddings.embeddings().size())) {
+                   return;
+               }
+
+               int start = 0;
+               for (var pos : positions) {
+                   bitResults.get(pos.inputIndex())
+                       .setOnce(pos.chunkIndex(), bitEmbeddings.embeddings().subList(start, start + pos.embeddingCount()));
+                   start += pos.embeddingCount();
+               }
+
+               if (resultCount.incrementAndGet() == totalNumberOfRequests) {
+                   sendResponse();
+               }
+           } else {
+               onFailure(
+                   unexpectedResultTypeException(inferenceServiceResults.getWriteableName(), InferenceTextEmbeddingBitResults.NAME)
+               );
+           }
+       }
+
        private void handleSparseResults(InferenceServiceResults inferenceServiceResults) {
            if (inferenceServiceResults instanceof SparseEmbeddingResults sparseEmbeddings) {
                if (failIfNumRequestsDoNotMatch(sparseEmbeddings.embeddings().size())) {

@@ -358,6 +387,7 @@ private ChunkedInference mergeResultsWithInputs(int resultIndex) {
        return switch (embeddingType) {
            case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex), floatResults.get(resultIndex));
            case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex), byteResults.get(resultIndex));
+           case BIT -> mergeBitResultsWithInputs(chunkedOffsets.get(resultIndex), bitResults.get(resultIndex));
            case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex), sparseResults.get(resultIndex));
        };
    }

@@ -414,6 +444,32 @@ private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs(
        return new ChunkedInferenceEmbeddingByte(embeddingChunks);
    }

+   private ChunkedInferenceEmbeddingByte mergeBitResultsWithInputs(
+       ChunkOffsetsAndInput chunks,
+       AtomicArray<List<InferenceByteEmbedding>> debatchedResults
+   ) {
+       var all = new ArrayList<InferenceByteEmbedding>();
+       for (int i = 0; i < debatchedResults.length(); i++) {
+           var subBatch = debatchedResults.get(i);
+           all.addAll(subBatch);
+       }
+
+       assert chunks.size() == all.size();
+
+       var embeddingChunks = new ArrayList<ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk>();
+       for (int i = 0; i < chunks.size(); i++) {
+           embeddingChunks.add(
+               new ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk(
+                   all.get(i).values(),
+                   chunks.chunkText(i),
+                   new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end())
+               )
+           );
+       }
+
+       return new ChunkedInferenceEmbeddingByte(embeddingChunks);
+   }
+
    private ChunkedInferenceEmbeddingSparse mergeSparseResultsWithInputs(
        ChunkOffsetsAndInput chunks,
        AtomicArray<List<SparseEmbeddingResults.Embedding>> debatchedResults
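
The new handleBitResults mirrors handleByteResults: the service returns one flat list of embeddings per batched request, and the positions bookkeeping records how many of them belong to each (input, chunk) slot, so the handler slices the flat list with subList and a running offset. A minimal, self-contained sketch of that debatching arithmetic (the Position record and the printout are illustrative stand-ins for the chunker's internal types, not its API):

import java.util.List;

// Illustrative sketch: slice a flat batch response back into per-(input, chunk)
// sub-lists using a running start offset, as handleBitResults/handleByteResults do.
final class DebatchSketch {
    record Position(int inputIndex, int chunkIndex, int embeddingCount) {}

    static <E> void debatch(List<E> flatEmbeddings, List<Position> positions) {
        int start = 0;
        for (Position pos : positions) {
            List<E> slice = flatEmbeddings.subList(start, start + pos.embeddingCount());
            // the real code stores the slice via bitResults.get(inputIndex).setOnce(chunkIndex, slice)
            System.out.println("input " + pos.inputIndex() + ", chunk " + pos.chunkIndex() + " -> " + slice);
            start += pos.embeddingCount();
        }
    }

    public static void main(String[] args) {
        // two embeddings for input 0, one for input 1
        debatch(List.of("e0", "e1", "e2"), List.of(new Position(0, 0, 2), new Position(1, 0, 1)));
    }
}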

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
(4 additions, 1 deletion)

@@ -708,7 +708,10 @@ yield new SparseVectorQueryBuilder(

                 MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults;
                 float[] inference = textEmbeddingResults.getInferenceAsFloat();
-                if (inference.length != modelSettings.dimensions()) {
+                var inferenceLength = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT ? inference.length * Byte.SIZE : inference.length;
+                System.out.println("AAA inference.length " + inference.length);
+                System.out.println("AAA inferenceLength " + inferenceLength);
+                if (inferenceLength != modelSettings.dimensions()) {
                     throw new IllegalArgumentException(
                         generateDimensionCountMismatchMessage(inference.length, modelSettings.dimensions())
                     );
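
The mapper change applies the same byte-to-bit conversion before validating: a BIT field reports its dimension count in bits, while the inference result carries one array entry per packed byte, so the length is scaled by Byte.SIZE before the comparison. A sketch of that check in isolation (class, method, and message text are illustrative, not the mapper's API):

// Illustrative sketch of the scaled dimension check for BIT element types.
final class DimensionCheckSketch {
    static void checkDimensions(float[] inference, boolean bitElementType, int configuredDimensions) {
        // for bit fields each array entry is one packed byte, i.e. 8 dimensions
        int inferenceLength = bitElementType ? inference.length * Byte.SIZE : inference.length;
        if (inferenceLength != configuredDimensions) {
            throw new IllegalArgumentException(
                "expected [" + configuredDimensions + "] dimensions, got [" + inferenceLength + "]"
            );
        }
    }

    public static void main(String[] args) {
        checkDimensions(new float[12], true, 96);  // passes: 12 bytes * 8 = 96 bits
        checkDimensions(new float[12], false, 96); // throws: 12 != 96
    }
}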

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
(7 additions, 2 deletions)

@@ -39,6 +39,7 @@
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
 import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel;

@@ -314,7 +315,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
        if (model instanceof CohereEmbeddingsModel embeddingsModel) {
            var serviceSettings = embeddingsModel.getServiceSettings();
            var similarityFromModel = serviceSettings.similarity();
-           var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
+           var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel;

            var updatedServiceSettings = new CohereEmbeddingsServiceSettings(
                new CohereServiceSettings(

@@ -342,7 +343,11 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
     *
     * @return The default similarity.
     */
-   static SimilarityMeasure defaultSimilarity() {
+   static SimilarityMeasure defaultSimilarity(CohereEmbeddingType embeddingType) {
+       if (embeddingType == CohereEmbeddingType.BIT || embeddingType == CohereEmbeddingType.BINARY) {
+           return SimilarityMeasure.L2_NORM;
+       }
+
        return SimilarityMeasure.DOT_PRODUCT;
    }

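
The CohereService change threads the embedding type into defaultSimilarity so that bit/binary embeddings fall back to L2_NORM instead of DOT_PRODUCT. A compact stand-alone restatement of that selection (local stand-in enums for illustration; the real code uses CohereEmbeddingType and SimilarityMeasure):

// Stand-in enums only; values and names are illustrative.
final class DefaultSimilaritySketch {
    enum EmbeddingType { FLOAT, BYTE, BIT, BINARY }
    enum Similarity { DOT_PRODUCT, L2_NORM }

    static Similarity defaultSimilarity(EmbeddingType embeddingType) {
        // packed bit ("binary") embeddings default to L2_NORM; everything else keeps DOT_PRODUCT
        if (embeddingType == EmbeddingType.BIT || embeddingType == EmbeddingType.BINARY) {
            return Similarity.L2_NORM;
        }
        return Similarity.DOT_PRODUCT;
    }

    public static void main(String[] args) {
        System.out.println(defaultSimilarity(EmbeddingType.BIT));   // L2_NORM
        System.out.println(defaultSimilarity(EmbeddingType.FLOAT)); // DOT_PRODUCT
    }
}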
