6 changes: 6 additions & 0 deletions docs/changelog/121827.yaml
@@ -0,0 +1,6 @@
pr: 121827
summary: Updates to allow using Cohere binary embedding response in semantic search queries
area: Machine Learning
type: bug
issues: []
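Everything below follows from one fact: Cohere's binary (bit) embeddings come back bit-packed, eight dimensions per byte, so any length compared against the model's configured dimension count must first be scaled by Byte.SIZE. A minimal sketch of that arithmetic, using illustrative names that are not from this PR:

// Illustrative sketch only: the packed-bit arithmetic behind this change.
public final class BitPackingSketch {
    public static void main(String[] args) {
        int modelDimensions = 1024;                              // dimensions configured on the model
        byte[] response = new byte[modelDimensions / Byte.SIZE]; // a bit-embedding response payload
        // Each byte carries eight bit dimensions, so recover the logical size:
        int logicalDimensions = response.length * Byte.SIZE;
        assert logicalDimensions == modelDimensions;
        System.out.println(response.length + " bytes -> " + logicalDimensions + " bit dimensions");
    }
}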
@@ -55,7 +55,8 @@ public int getFirstEmbeddingSize() {
if (embeddings.isEmpty()) {
throw new IllegalStateException("Embeddings list is empty");
}
-        return embeddings.getFirst().values().length;
+        // bit embeddings are encoded as bytes so convert this to bits
+        return Byte.SIZE * embeddings.getFirst().values().length;
}

@Override
@@ -709,9 +709,12 @@ yield new SparseVectorQueryBuilder(

MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults;
float[] inference = textEmbeddingResults.getInferenceAsFloat();
-            if (inference.length != modelSettings.dimensions()) {
+            var inferenceLength = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
+                ? inference.length * Byte.SIZE
+                : inference.length;
+            if (inferenceLength != modelSettings.dimensions()) {
throw new IllegalArgumentException(
-                    generateDimensionCountMismatchMessage(inference.length, modelSettings.dimensions())
+                    generateDimensionCountMismatchMessage(inferenceLength, modelSettings.dimensions())
);
}

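The same normalize-then-compare logic as a standalone sketch; validateDimensions and the boolean flag are assumed names for illustration, not the actual Elasticsearch helpers. With it, a 128-byte BIT response validates cleanly against a 1024-dimension model:

// Sketch under assumed names: scale BIT responses to bit dimensions before comparing.
static void validateDimensions(float[] inference, boolean isBitElementType, int expectedDimensions) {
    // BIT element types pack eight dimensions per byte, so scale the raw length.
    int inferenceLength = isBitElementType ? inference.length * Byte.SIZE : inference.length;
    if (inferenceLength != expectedDimensions) {
        throw new IllegalArgumentException(
            "Expected [" + expectedDimensions + "] dimensions, got [" + inferenceLength + "]"
        );
    }
}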
@@ -39,6 +39,7 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel;
@@ -313,7 +314,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof CohereEmbeddingsModel embeddingsModel) {
var serviceSettings = embeddingsModel.getServiceSettings();
var similarityFromModel = serviceSettings.similarity();
-        var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
+        var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel;

var updatedServiceSettings = new CohereEmbeddingsServiceSettings(
new CohereServiceSettings(
@@ -341,7 +342,11 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
*
* @return The default similarity.
*/
-    static SimilarityMeasure defaultSimilarity() {
+    static SimilarityMeasure defaultSimilarity(CohereEmbeddingType embeddingType) {
+        if (embeddingType == CohereEmbeddingType.BIT || embeddingType == CohereEmbeddingType.BINARY) {
+            return SimilarityMeasure.L2_NORM;
+        }

return SimilarityMeasure.DOT_PRODUCT;
}

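Why L2_NORM for BIT and BINARY: a packed bit vector is compared by counting differing bits (Hamming distance), which Elasticsearch surfaces through the l2_norm similarity for bit element types, whereas a dot product over packed bytes has no useful meaning. A minimal sketch of that distance, not code from this PR:

// Minimal sketch: Hamming distance between two 8-dimension bit vectors,
// each packed into one byte. Counting differing bits is what makes
// L2_NORM the sensible default for BIT/BINARY embeddings.
public final class HammingSketch {
    public static void main(String[] args) {
        byte a = (byte) 0b1011_0010;
        byte b = (byte) 0b1001_0110;
        int differingBits = Integer.bitCount((a ^ b) & 0xFF);
        System.out.println("hamming distance: " + differingBits); // prints 2
    }
}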
@@ -13,6 +13,7 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
@@ -377,6 +378,78 @@ public void testMergingListener_Byte() {
}
}

public void testMergingListener_Bit() {
int batchSize = 5;
int chunkSize = 20;
int overlap = 0;
// passage will be chunked into batchSize + 1 parts
// and spread over 2 batch requests
int numberOfWordsInPassage = (chunkSize * batchSize) + 5;

var passageBuilder = new StringBuilder();
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");

var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(2));

// 4 inputs in 2 batches
{
var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
for (int i = 0; i < batchSize; i++) {
embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
batches.get(0).listener().onResponse(new TextEmbeddingBitResults(embeddings));
}
{
var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
batches.get(1).listener().onResponse(new TextEmbeddingBitResults(embeddings));
}

assertNotNull(finalListener.results);
assertThat(finalListener.results, hasSize(4));
{
var chunkedResult = finalListener.results.get(0);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
}
{
// this is the large input split in multiple chunks
var chunkedResult = finalListener.results.get(1);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(6));
assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
}
{
var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
}
}

public void testMergingListener_Sparse() {
int batchSize = 4;
int chunkSize = 10;
@@ -105,6 +105,17 @@ public void testTransformToCoordinationFormat() {
);
}

public void testGetFirstEmbeddingSize() {
var firstEmbeddingSize = new TextEmbeddingBitResults(
List.of(
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).getFirstEmbeddingSize();

assertThat(firstEmbeddingSize, is(16));
}

@Override
protected Writeable.Reader<TextEmbeddingBitResults> instanceReader() {
return TextEmbeddingBitResults::new;
@@ -104,6 +104,17 @@ public void testTransformToCoordinationFormat() {
);
}

public void testGetFirstEmbeddingSize() {
var firstEmbeddingSize = new TextEmbeddingByteResults(
List.of(
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).getFirstEmbeddingSize();

assertThat(firstEmbeddingSize, is(2));
}

@Override
protected Writeable.Reader<TextEmbeddingByteResults> instanceReader() {
return TextEmbeddingByteResults::new;
@@ -105,6 +105,17 @@ public void testTransformToCoordinationFormat() {
);
}

public void testGetFirstEmbeddingSize() {
var firstEmbeddingSize = new TextEmbeddingFloatResults(
List.of(
new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
)
).getFirstEmbeddingSize();

assertThat(firstEmbeddingSize, is(2));
}

@Override
protected Writeable.Reader<TextEmbeddingFloatResults> instanceReader() {
return TextEmbeddingFloatResults::new;
@@ -1100,20 +1100,23 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si

try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
var embeddingSize = randomNonNegativeInt();
+            var embeddingType = randomFrom(CohereEmbeddingType.values());
var model = CohereEmbeddingsModelTests.createModel(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
randomNonNegativeInt(),
randomNonNegativeInt(),
randomAlphaOfLength(10),
-                randomFrom(CohereEmbeddingType.values()),
+                embeddingType,
similarityMeasure
);

Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);

-            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? CohereService.defaultSimilarity() : similarityMeasure;
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null
+                ? CohereService.defaultSimilarity(embeddingType)
+                : similarityMeasure;
assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
}
@@ -1590,8 +1593,15 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
}
}

-    public void testDefaultSimilarity() {
-        assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity());
+    public void testDefaultSimilarity_BinaryEmbedding() {
+        assertEquals(SimilarityMeasure.L2_NORM, CohereService.defaultSimilarity(CohereEmbeddingType.BINARY));
+        assertEquals(SimilarityMeasure.L2_NORM, CohereService.defaultSimilarity(CohereEmbeddingType.BIT));
+    }
+
+    public void testDefaultSimilarity_NotBinaryEmbedding() {
+        assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.FLOAT));
+        assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.BYTE));
+        assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.INT8));
+    }

public void testInfer_StreamRequest() throws Exception {