Merged (changes from 9 commits)
@@ -52,7 +52,8 @@ public InferenceTextEmbeddingBitResults(StreamInput in) throws IOException {

@Override
public int getFirstEmbeddingSize() {
return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
// bit embeddings are encoded as bytes so convert this to bits
return Byte.SIZE * TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
}

@Override
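For illustration (not part of the diff), a minimal sketch of what the conversion means in practice, reusing the two-byte embedding from the new unit test further down:

    // A BIT embedding packed into 2 bytes now reports 2 * Byte.SIZE = 16 dimensions.
    var results = new InferenceTextEmbeddingBitResults(
        List.of(new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }))
    );
    int size = results.getFirstEmbeddingSize(); // 16, previously 2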
@@ -20,6 +20,7 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -46,13 +47,14 @@ public class EmbeddingRequestChunker {
public enum EmbeddingType {
FLOAT,
BYTE,
BIT,
SPARSE;

public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.ElementType elementType) {
return switch (elementType) {
case BYTE -> EmbeddingType.BYTE;
case FLOAT -> EmbeddingType.FLOAT;
case BIT -> throw new IllegalArgumentException("Bit vectors are not supported");
case BIT -> EmbeddingType.BIT;
};
}
};
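As a usage sketch (illustrative, not from the diff), the behavioral change in this switch is that BIT element types now map to an embedding type instead of throwing:

    // Before this change the BIT case threw IllegalArgumentException("Bit vectors are not supported").
    var type = EmbeddingType.fromDenseVectorElementType(DenseVectorFieldMapper.ElementType.BIT);
    assert type == EmbeddingType.BIT;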
@@ -71,6 +73,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
private List<ChunkOffsetsAndInput> chunkedOffsets;
private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
private List<AtomicArray<List<InferenceByteEmbedding>>> byteResults;
private List<AtomicArray<List<InferenceByteEmbedding>>> bitResults;
private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
private AtomicArray<Exception> errors;
private ActionListener<List<ChunkedInference>> finalListener;
@@ -122,6 +125,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
switch (embeddingType) {
case FLOAT -> floatResults = new ArrayList<>(inputs.size());
case BYTE -> byteResults = new ArrayList<>(inputs.size());
case BIT -> bitResults = new ArrayList<>(inputs.size());
case SPARSE -> sparseResults = new ArrayList<>(inputs.size());
}
errors = new AtomicArray<>(inputs.size());
@@ -134,6 +138,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
switch (embeddingType) {
case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
case BIT -> bitResults.add(new AtomicArray<>(numberOfSubBatches));
case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches));
}
chunkedOffsets.add(offSetsAndInput);
@@ -233,6 +238,7 @@ public void onResponse(InferenceServiceResults inferenceServiceResults) {
switch (embeddingType) {
case FLOAT -> handleFloatResults(inferenceServiceResults);
case BYTE -> handleByteResults(inferenceServiceResults);
case BIT -> handleBitResults(inferenceServiceResults);
case SPARSE -> handleSparseResults(inferenceServiceResults);
}
}
@@ -283,6 +289,27 @@ private void handleByteResults(InferenceServiceResults inferenceServiceResults)
}
}

private void handleBitResults(InferenceServiceResults inferenceServiceResults) {
if (inferenceServiceResults instanceof InferenceTextEmbeddingBitResults bitEmbeddings) {
if (failIfNumRequestsDoNotMatch(bitEmbeddings.embeddings().size())) {
return;
}

int start = 0;
for (var pos : positions) {
bitResults.get(pos.inputIndex())
.setOnce(pos.chunkIndex(), bitEmbeddings.embeddings().subList(start, start + pos.embeddingCount()));
start += pos.embeddingCount();
}

if (resultCount.incrementAndGet() == totalNumberOfRequests) {
sendResponse();
}
} else {
onFailure(unexpectedResultTypeException(inferenceServiceResults.getWriteableName(), InferenceTextEmbeddingBitResults.NAME));
}
}

private void handleSparseResults(InferenceServiceResults inferenceServiceResults) {
if (inferenceServiceResults instanceof SparseEmbeddingResults sparseEmbeddings) {
if (failIfNumRequestsDoNotMatch(sparseEmbeddings.embeddings().size())) {
@@ -358,6 +385,7 @@ private ChunkedInference mergeResultsWithInputs(int resultIndex) {
return switch (embeddingType) {
case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex), floatResults.get(resultIndex));
case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex), byteResults.get(resultIndex));
case BIT -> mergeBitResultsWithInputs(chunkedOffsets.get(resultIndex), bitResults.get(resultIndex));
case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex), sparseResults.get(resultIndex));
};
}
@@ -414,6 +442,32 @@ private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs(
return new ChunkedInferenceEmbeddingByte(embeddingChunks);
}

private ChunkedInferenceEmbeddingByte mergeBitResultsWithInputs(
ChunkOffsetsAndInput chunks,
AtomicArray<List<InferenceByteEmbedding>> debatchedResults
) {
var all = new ArrayList<InferenceByteEmbedding>();
for (int i = 0; i < debatchedResults.length(); i++) {
var subBatch = debatchedResults.get(i);
all.addAll(subBatch);
}

assert chunks.size() == all.size();

var embeddingChunks = new ArrayList<ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk>();
for (int i = 0; i < chunks.size(); i++) {
embeddingChunks.add(
new ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk(
all.get(i).values(),
chunks.chunkText(i),
new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end())
)
);
}

return new ChunkedInferenceEmbeddingByte(embeddingChunks);
}
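Note that bit results are merged into ChunkedInferenceEmbeddingByte chunks rather than a dedicated bit chunk type, since the embeddings are already packed into bytes; the new testMergingListener_Bit test below asserts this shape.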

private ChunkedInferenceEmbeddingSparse mergeSparseResultsWithInputs(
ChunkOffsetsAndInput chunks,
AtomicArray<List<SparseEmbeddingResults.Embedding>> debatchedResults
@@ -709,7 +709,10 @@ yield new SparseVectorQueryBuilder(

MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults;
float[] inference = textEmbeddingResults.getInferenceAsFloat();
if (inference.length != modelSettings.dimensions()) {
var inferenceLength = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
? inference.length * Byte.SIZE
: inference.length;
if (inferenceLength != modelSettings.dimensions()) {
throw new IllegalArgumentException(
generateDimensionCountMismatchMessage(inference.length, modelSettings.dimensions())
[Review comment, Contributor]: Let's make sure to update this error message with the correct dimension count as well

);
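A worked instance of the new length check (illustrative values, not from the PR): a bit vector field declared with 16 dimensions receives its embedding packed one bit per dimension, eight per byte, so the raw inference array has length 2:

    float[] inference = new float[2];  // packed model output for a BIT field
    int dimensions = 16;               // modelSettings.dimensions()
    // elementType == BIT, so compare the bit count rather than the array length:
    int inferenceLength = inference.length * Byte.SIZE; // 2 * 8 = 16
    assert inferenceLength == dimensions; // passes; a FLOAT field would need length 16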
@@ -39,6 +39,7 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel;
@@ -314,7 +315,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof CohereEmbeddingsModel embeddingsModel) {
var serviceSettings = embeddingsModel.getServiceSettings();
var similarityFromModel = serviceSettings.similarity();
var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel;

var updatedServiceSettings = new CohereEmbeddingsServiceSettings(
new CohereServiceSettings(
@@ -342,7 +343,11 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
*
* @param embeddingType the type of embedding the model produces, used to choose a suitable default
* @return The default similarity.
*/
static SimilarityMeasure defaultSimilarity() {
static SimilarityMeasure defaultSimilarity(CohereEmbeddingType embeddingType) {
if (embeddingType == CohereEmbeddingType.BIT || embeddingType == CohereEmbeddingType.BINARY) {
return SimilarityMeasure.L2_NORM;
}

return SimilarityMeasure.DOT_PRODUCT;
}
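A quick sketch of the new defaults, mirroring the expectations in the updated tests at the end of this PR:

    // Binary-type embeddings now default to L2 norm:
    assert CohereService.defaultSimilarity(CohereEmbeddingType.BIT) == SimilarityMeasure.L2_NORM;
    assert CohereService.defaultSimilarity(CohereEmbeddingType.BINARY) == SimilarityMeasure.L2_NORM;
    // All other embedding types keep the previous DOT_PRODUCT default:
    assert CohereService.defaultSimilarity(CohereEmbeddingType.FLOAT) == SimilarityMeasure.DOT_PRODUCT;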

@@ -15,6 +15,7 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -421,6 +422,79 @@ public void testMergingListener_Byte() {
}
}

public void testMergingListener_Bit() {
int batchSize = 5;
int chunkSize = 20;
int overlap = 0;
// passage will be chunked into batchSize + 1 parts
// and spread over 2 batch requests
int numberOfWordsInPassage = (chunkSize * batchSize) + 5;

var passageBuilder = new StringBuilder();
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");

var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap, EmbeddingRequestChunker.EmbeddingType.BIT)
.batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(2));

// 4 inputs in 2 batches
{
var embeddings = new ArrayList<InferenceByteEmbedding>();
for (int i = 0; i < batchSize; i++) {
embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
}
batches.get(0).listener().onResponse(new InferenceTextEmbeddingBitResults(embeddings));
}
{
var embeddings = new ArrayList<InferenceByteEmbedding>();
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
}
batches.get(1).listener().onResponse(new InferenceTextEmbeddingBitResults(embeddings));
}

assertNotNull(finalListener.results);
assertThat(finalListener.results, hasSize(4));
{
var chunkedResult = finalListener.results.get(0);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
}
{
// this is the large input split in multiple chunks
var chunkedResult = finalListener.results.get(1);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(6));
assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
}
{
var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
}
}

public void testMergingListener_Sparse() {
int batchSize = 4;
int chunkSize = 10;
@@ -102,6 +102,17 @@ public void testTransformToCoordinationFormat() {
);
}

public void testGetFirstEmbeddingSize() {
var firstEmbeddingSize = new InferenceTextEmbeddingBitResults(
List.of(
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
)
).getFirstEmbeddingSize();

assertThat(firstEmbeddingSize, is(16));
}

@Override
protected Writeable.Reader<InferenceTextEmbeddingBitResults> instanceReader() {
return InferenceTextEmbeddingBitResults::new;
@@ -102,6 +102,17 @@ public void testTransformToCoordinationFormat() {
);
}

public void testGetFirstEmbeddingSize() {
var firstEmbeddingSize = new InferenceTextEmbeddingByteResults(
List.of(
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
)
).getFirstEmbeddingSize();

assertThat(firstEmbeddingSize, is(2));
}

@Override
protected Writeable.Reader<InferenceTextEmbeddingByteResults> instanceReader() {
return InferenceTextEmbeddingByteResults::new;
@@ -108,6 +108,17 @@ public void testTransformToCoordinationFormat() {
);
}

public void testGetFirstEmbeddingSize() {
var firstEmbeddingSize = new InferenceTextEmbeddingFloatResults(
List.of(
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.1F, 0.2F }),
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.3F, 0.4F })
)
).getFirstEmbeddingSize();

assertThat(firstEmbeddingSize, is(2));
}

@Override
protected Writeable.Reader<InferenceTextEmbeddingFloatResults> instanceReader() {
return InferenceTextEmbeddingFloatResults::new;
@@ -1099,20 +1099,23 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si

try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
var embeddingSize = randomNonNegativeInt();
var embeddingType = randomFrom(CohereEmbeddingType.values());
var model = CohereEmbeddingsModelTests.createModel(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
randomNonNegativeInt(),
randomNonNegativeInt(),
randomAlphaOfLength(10),
randomFrom(CohereEmbeddingType.values()),
embeddingType,
similarityMeasure
);

Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);

SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? CohereService.defaultSimilarity() : similarityMeasure;
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null
? CohereService.defaultSimilarity(embeddingType)
: similarityMeasure;
assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
}
@@ -1579,8 +1582,15 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
}
}

public void testDefaultSimilarity() {
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity());
public void testDefaultSimilarity_BinaryEmbedding() {
assertEquals(SimilarityMeasure.L2_NORM, CohereService.defaultSimilarity(CohereEmbeddingType.BINARY));
assertEquals(SimilarityMeasure.L2_NORM, CohereService.defaultSimilarity(CohereEmbeddingType.BIT));
}

public void testDefaultSimilarity_NotBinaryEmbedding() {
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.FLOAT));
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.BYTE));
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.INT8));
}

public void testInfer_StreamRequest() throws Exception {