diff --git a/docs/changelog/125370.yaml b/docs/changelog/125370.yaml new file mode 100644 index 0000000000000..113988089776c --- /dev/null +++ b/docs/changelog/125370.yaml @@ -0,0 +1,6 @@ +pr: 125370 +summary: Set default similarity for Cohere model to cosine +area: Machine Learning +type: bug +issues: + - 122878 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index e22852f0e78e5..9f1181a5d4382 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -349,19 +350,18 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { } /** - * Return the default similarity measure for the embedding type. - * Cohere embeddings are normalized to unit vectors therefor Dot - * Product similarity can be used and is the default for all Cohere - * models. + * Returns the default similarity measure for the embedding type. + * Cohere embeddings are expected to be normalized to unit vectors, but due to floating point precision issues, + * our check ({@link DenseVectorFieldMapper#isNotUnitVector(float)}) often fails. + * Therefore, we use cosine similarity to ensure compatibility. * - * @return The default similarity. + * @return The default similarity measure. */ static SimilarityMeasure defaultSimilarity(CohereEmbeddingType embeddingType) { if (embeddingType == CohereEmbeddingType.BIT || embeddingType == CohereEmbeddingType.BINARY) { return SimilarityMeasure.L2_NORM; } - - return SimilarityMeasure.DOT_PRODUCT; + return SimilarityMeasure.COSINE; } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 877cc294fec67..dec1052589c93 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -1605,9 +1605,9 @@ public void testDefaultSimilarity_BinaryEmbedding() { } public void testDefaultSimilarity_NotBinaryEmbedding() { - assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.FLOAT)); - assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.BYTE)); - assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.INT8)); + assertEquals(SimilarityMeasure.COSINE, CohereService.defaultSimilarity(CohereEmbeddingType.FLOAT)); + assertEquals(SimilarityMeasure.COSINE, CohereService.defaultSimilarity(CohereEmbeddingType.BYTE)); + assertEquals(SimilarityMeasure.COSINE, CohereService.defaultSimilarity(CohereEmbeddingType.INT8)); } public void testInfer_StreamRequest() throws Exception {