Skip to content

Commit 3ef38d9

Browse files
committed
Fix similarity score
1 parent e44440c commit 3ef38d9

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

api/commands.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def insert_images(
4848
}
4949

5050

51+
def similarity_score(distance):
    """Convert a distance between two embeddings into a similarity percentage.

    The score is ``100 * (1 - distance / 2)``: a distance of 0 maps to 100,
    a distance of 2 maps to 0, and the score falls linearly in between.

    NOTE(review): 2 is the maximum *plain* L2 distance between normalised
    vectors. The callers in this file appear to pass *squared* L2 distances
    (see the comment in ``compare``). For unit-length vectors the squared L2
    distance equals ``2 - 2 * cos(angle)``, so it can reach 4; the value
    returned here is then ``100 * cos(angle)`` and becomes negative for
    distances above 2. Confirm that this range is intended.
    """
    return 100 * (1 - distance / 2)
54+
55+
5156
def search(model_name, url, limit=10):
5257
embedding = embeddings[model_name].extract(load_image_from_url(url))
5358
search_results = collections[model_name].search(
@@ -68,7 +73,7 @@ def search(model_name, url, limit=10):
6873
{
6974
"url": hit.id,
7075
"metadata": json.loads(hit.entity.get("metadata")),
71-
"similarity": 100 * (1 - hit.distance),
76+
"similarity": similarity_score(hit.distance),
7277
}
7378
for hit in search_results[0]
7479
]
@@ -84,7 +89,7 @@ def compare(model_name, url_left, url_right):
8489
# it's a bit overkill anyway if we don't compare with vectors from the db
8590
if metrics[model_name] == "L2":
8691
# _squared_ L2, to be consistent with the distances in milvus' search
87-
return 100 * (1 - np.sum(np.square(np.array(left) - np.array(right))))
92+
return similarity_score(np.sum(np.square(np.array(left) - np.array(right))))
8893

8994
raise RuntimeError(
9095
"Distance calculation has not been implemented in the API. "

api/tests/test_commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_crud(mock_model):
4040

4141
assert [
4242
{
43-
"similarity": pytest.approx(11.650939784262505),
43+
"similarity": pytest.approx(55.82546989213125),
4444
"metadata": metadata,
4545
"url": TEST_URLS[0],
4646
}
@@ -118,7 +118,7 @@ def test_insert_without_replacing(mock_model):
118118

119119

120120
def test_compare():
121-
assert pytest.approx(11.650939784262505) == commands["compare"](
121+
assert pytest.approx(55.82546989213125) == commands["compare"](
122122
"vit_b32", TEST_URLS[0], TEST_URLS[1]
123123
)
124124

api/tests/test_embeddings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import random
55
from PIL import Image
6+
import math
67

78
from ..embeddings import load_image_from_url, embeddings
89

@@ -26,13 +27,18 @@ def test_load_image_from_url_and_reject_small_ones():
2627
)
2728

2829

30+
def squared_l2(v):
    """Return the squared L2 norm of ``v``, i.e. the sum of its squared elements.

    ``v`` may be any iterable of numbers; an empty iterable yields ``0``.
    """
    # A generator expression avoids building a throwaway list just to sum it.
    return sum(x * x for x in v)
32+
33+
2934
def test_extract_vit_b32():
    """Check the ``vit_b32`` embedding extracted from a fixed pseudo-random image."""
    # Seeding makes the generated pixel bytes, and so the expected values
    # asserted below, reproducible from run to run.
    random.seed(2023)
    side = 500
    channels = 3
    pixels = bytes(random.randint(0, 255) for _ in range(side * side * channels))
    picture = Image.frombytes("RGB", (side, side), pixels)

    vector = embeddings["vit_b32"].extract(picture)

    # The embedding is expected to have unit squared L2 norm and 512 components.
    assert pytest.approx(1) == squared_l2(vector)
    assert 512 == len(vector)
    assert pytest.approx(-1.00633253128035) == sum(vector)
3844

0 commit comments

Comments
 (0)