Skip to content

Commit 05ee2f5

Browse files
committed
Set default similarity for Cohere model to cosine
Cohere embeddings are expected to be normalized to unit vectors, but due to floating-point precision issues, our check ({@link DenseVectorFieldMapper#isNotUnitVector(float)}) often fails. This change fixes this bug by setting the default similarity for newly created Cohere inference endpoints to cosine. Closes #122878
1 parent 0ff526a commit 05ee2f5

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.common.util.LazyInitializable;
1616
import org.elasticsearch.core.Nullable;
1717
import org.elasticsearch.core.TimeValue;
18+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1819
import org.elasticsearch.inference.ChunkedInference;
1920
import org.elasticsearch.inference.ChunkingSettings;
2021
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -349,19 +350,18 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
349350
}
350351

351352
/**
352-
* Return the default similarity measure for the embedding type.
353-
* Cohere embeddings are normalized to unit vectors therefor Dot
354-
* Product similarity can be used and is the default for all Cohere
355-
* models.
353+
* Returns the default similarity measure for the embedding type.
354+
* Cohere embeddings are expected to be normalized to unit vectors, but due to floating point precision issues,
355+
* our check ({@link DenseVectorFieldMapper#isNotUnitVector(float)}) often fails.
356+
* Therefore, we use cosine similarity to ensure compatibility.
356357
*
357-
* @return The default similarity.
358+
* @return The default similarity measure.
358359
*/
359360
static SimilarityMeasure defaultSimilarity(CohereEmbeddingType embeddingType) {
360361
if (embeddingType == CohereEmbeddingType.BIT || embeddingType == CohereEmbeddingType.BINARY) {
361362
return SimilarityMeasure.L2_NORM;
362363
}
363-
364-
return SimilarityMeasure.DOT_PRODUCT;
364+
return SimilarityMeasure.COSINE;
365365
}
366366

367367
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,9 +1605,9 @@ public void testDefaultSimilarity_BinaryEmbedding() {
16051605
}
16061606

16071607
public void testDefaultSimilarity_NotBinaryEmbedding() {
1608-
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.FLOAT));
1609-
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.BYTE));
1610-
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.INT8));
1608+
assertEquals(SimilarityMeasure.COSINE, CohereService.defaultSimilarity(CohereEmbeddingType.FLOAT));
1609+
assertEquals(SimilarityMeasure.COSINE, CohereService.defaultSimilarity(CohereEmbeddingType.BYTE));
1610+
assertEquals(SimilarityMeasure.COSINE, CohereService.defaultSimilarity(CohereEmbeddingType.INT8));
16111611
}
16121612

16131613
public void testInfer_StreamRequest() throws Exception {

0 commit comments

Comments
 (0)