Skip to content

Commit ca90d62

Browse files
committed
Setting default similarity to L2 norm for binary embedding type
1 parent 6b3d4f1 commit ca90d62

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.elasticsearch.xpack.inference.services.SenderService;
3939
import org.elasticsearch.xpack.inference.services.ServiceComponents;
4040
import org.elasticsearch.xpack.inference.services.ServiceUtils;
41+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
4142
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
4243
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
4344
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;
@@ -294,7 +295,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
294295
if (model instanceof JinaAIEmbeddingsModel embeddingsModel) {
295296
var serviceSettings = embeddingsModel.getServiceSettings();
296297
var similarityFromModel = serviceSettings.similarity();
297-
var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
298+
var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel;
298299
var maxInputTokens = serviceSettings.maxInputTokens();
299300

300301
var updatedServiceSettings = new JinaAIEmbeddingsServiceSettings(
@@ -323,7 +324,10 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
323324
*
324325
* @return The default similarity.
325326
*/
326-
static SimilarityMeasure defaultSimilarity() {
327+
static SimilarityMeasure defaultSimilarity(JinaAIEmbeddingType embeddingType) {
328+
if (embeddingType == JinaAIEmbeddingType.BINARY || embeddingType == JinaAIEmbeddingType.BIT) {
329+
return SimilarityMeasure.L2_NORM;
330+
}
327331
return SimilarityMeasure.DOT_PRODUCT;
328332
}
329333

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
996996

997997
try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
998998
var embeddingSize = randomNonNegativeInt();
999+
var embeddingType = randomFrom(JinaAIEmbeddingType.values());
9991000
var model = JinaAIEmbeddingsModelTests.createModel(
10001001
randomAlphaOfLength(10),
10011002
randomAlphaOfLength(10),
@@ -1004,12 +1005,14 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
10041005
randomNonNegativeInt(),
10051006
randomAlphaOfLength(10),
10061007
similarityMeasure,
1007-
JinaAIEmbeddingType.FLOAT
1008+
embeddingType
10081009
);
10091010

10101011
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
10111012

1012-
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? JinaAIService.defaultSimilarity() : similarityMeasure;
1013+
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null
1014+
? JinaAIService.defaultSimilarity(embeddingType)
1015+
: similarityMeasure;
10131016
assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
10141017
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
10151018
}
@@ -1866,8 +1869,13 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode
18661869
}
18671870
}
18681871

1869-
public void testDefaultSimilarity() {
1870-
assertEquals(SimilarityMeasure.DOT_PRODUCT, JinaAIService.defaultSimilarity());
1872+
public void testDefaultSimilarity_BinaryEmbedding() {
1873+
assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BINARY));
1874+
assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BIT));
1875+
}
1876+
1877+
public void testDefaultSimilarity_NotBinaryEmbedding() {
1878+
assertEquals(SimilarityMeasure.DOT_PRODUCT, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.FLOAT));
18711879
}
18721880

18731881
@SuppressWarnings("checkstyle:LineLength")

0 commit comments

Comments
 (0)