Skip to content

Commit d111a3c

Browse files
committed
Implement text search
1 parent 3ef38d9 commit d111a3c

File tree

7 files changed

+94
-32
lines changed

7 files changed

+94
-32
lines changed

Makefile

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,14 +53,16 @@ test:
5353
# The echo commands informally compare the actual to the expected result.
5454
API_URL=localhost:4213
5555
integration-test:
56-
curl -X POST ${API_URL}/models/vit_b32/urls | cut -c 1-100
56+
curl -s -X POST ${API_URL}/models/vit_b32/urls | cut -c 1-100
5757
@echo =?=[]
5858
curl -H "Content-Type: application/json" -d '{"url": "https://iiif.itatti.harvard.edu/iiif/2/yashiro!letters-jp!letter_001.pdf/full/full/0/default.jpg"}' ${API_URL}/models/vit_b32/add
5959
@echo =?=
60-
curl -X POST ${API_URL}/models/vit_b32/urls | cut -c 1-100
60+
curl -s -X POST ${API_URL}/models/vit_b32/urls | cut -c 1-100
6161
@echo =?=[url]
6262
curl -H "Content-Type: application/json" -d '{"url": "https://iiif.itatti.harvard.edu/iiif/2/yashiro!letters-jp!letter_001.pdf/full/full/0/default.jpg"}' ${API_URL}/models/vit_b32/search
6363
@echo =?=[result]
64+
curl -H "Content-Type: application/json" -d '{"text": "a black and white text in japanese"}' ${API_URL}/models/vit_b32/search
65+
@echo =?=[result]
6466
curl -H "Content-Type: application/json" -d '{"url": "https://iiif.itatti.harvard.edu/iiif/2/yashiro!letters-jp!letter_001.pdf/full/full/0/default.jpg", "other": "https://iiif.itatti.harvard.edu/iiif/2/yashiro!letters-jp!letter_001.pdf/full/full/0/default.jpg"}' ${API_URL}/models/vit_b32/compare
6567
@echo =?=100
6668
curl ${API_URL}/models/vit_b32/count

api/commands.py

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,10 @@ def insert_images(
2525

2626
new_urls = [url for url in urls if url not in existing_urls]
2727
image_embeddings = (
28-
[embeddings[model_name].extract(load_image_from_url(url)) for url in new_urls]
28+
[
29+
embeddings[model_name].get_image_embedding(load_image_from_url(url))
30+
for url in new_urls
31+
]
2932
if image_embeddings is None
3033
else [
3134
embedding
@@ -53,8 +56,7 @@ def similarity_score(distance):
5356
return 100 * (1 - distance / 2)
5457

5558

56-
def search(model_name, url, limit=10):
57-
embedding = embeddings[model_name].extract(load_image_from_url(url))
59+
def search_by_embedding(model_name, embedding, limit=10):
5860
search_results = collections[model_name].search(
5961
data=[embedding],
6062
anns_field="embedding",
@@ -79,11 +81,21 @@ def search(model_name, url, limit=10):
7981
]
8082

8183

84+
def search_by_url(model_name, url, limit=10):
85+
embedding = embeddings[model_name].get_image_embedding(load_image_from_url(url))
86+
return search_by_embedding(model_name, embedding, limit)
87+
88+
89+
def search_by_text(model_name, text, limit=10):
90+
embedding = embeddings[model_name].get_text_embedding(text)
91+
return search_by_embedding(model_name, embedding, limit)
92+
93+
8294
def compare(model_name, url_left, url_right):
8395
# alternatively, we could first try to fetch the embeddings from milvus in
8496
# case their computation is significantly more expensive than a query
85-
left = embeddings[model_name].extract(load_image_from_url(url_left))
86-
right = embeddings[model_name].extract(load_image_from_url(url_right))
97+
left = embeddings[model_name].get_image_embedding(load_image_from_url(url_left))
98+
right = embeddings[model_name].get_image_embedding(load_image_from_url(url_right))
8799

88100
# calc_distance() has been removed from milvus
89101
# it's a bit overkill anyway if we don't compare with vectors from the db
@@ -140,7 +152,8 @@ def remove_images(model_name, urls):
140152

141153
commands = dict(
142154
insert_images=insert_images,
143-
search=search,
155+
search_by_url=search_by_url,
156+
search_by_text=search_by_text,
144157
compare=compare,
145158
list_images=list_images,
146159
count=count,

api/embeddings.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def get_text_embedding(self, text):
3737
text_embedding = self.model.get_text_features(**inputs)
3838
text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
3939
text_embedding = text_embedding.tolist()
40-
return text_embedding
40+
return text_embedding[0]
4141

4242
@torch.no_grad()
4343
def get_image_embedding(self, images):
@@ -46,10 +46,7 @@ def get_image_embedding(self, images):
4646
image_embedding = self.model.get_image_features(**inputs)
4747
image_embedding /= image_embedding.norm(dim=-1, keepdim=True)
4848
image_embedding = image_embedding.tolist()
49-
return image_embedding
50-
51-
def extract(self, image):
52-
return self.get_image_embedding(image)[0]
49+
return image_embedding[0]
5350

5451

5552
embeddings = {

api/main.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -240,9 +240,9 @@ async def restore(model_name: ModelName, images: list[DatabaseEntry]):
240240
)
241241
async def search(model_name: ModelName, params: SearchParameters):
242242
if params.url:
243-
return try_rpc("search", [model_name.value, params.url, params.limit])
243+
return try_rpc("search_by_url", [model_name.value, params.url, params.limit])
244244
elif params.text:
245-
return try_rpc("text_search", [model_name.value, params.text, params.limit])
245+
return try_rpc("search_by_text", [model_name.value, params.text, params.limit])
246246
else:
247247
raise HTTPException(
248248
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,

api/tests/test_commands.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,21 @@ def test_crud(mock_model):
4444
"metadata": metadata,
4545
"url": TEST_URLS[0],
4646
}
47-
] == commands["search"](mock_model, TEST_URLS[1])
47+
] == commands["search_by_url"](mock_model, TEST_URLS[1])
48+
assert [
49+
{
50+
"similarity": pytest.approx(29.19090986251831),
51+
"metadata": metadata,
52+
"url": TEST_URLS[0],
53+
}
54+
] == commands["search_by_text"](mock_model, "a black and white text in japanese")
55+
assert [
56+
{
57+
"similarity": pytest.approx(17.36249327659607),
58+
"metadata": metadata,
59+
"url": TEST_URLS[0],
60+
}
61+
] == commands["search_by_text"](mock_model, "a cute colorful cat")
4862

4963
commands["remove_images"](mock_model, [TEST_URLS[0]])
5064

@@ -65,9 +79,9 @@ def test_search_more_results(mock_model):
6579
[None] * 1000,
6680
[[0] * 512] * 1000,
6781
)
68-
results = commands["search"](mock_model, TEST_URLS[1])
82+
results = commands["search_by_url"](mock_model, TEST_URLS[1])
6983
assert 10 == len(results)
70-
results = commands["search"](mock_model, TEST_URLS[1], 100)
84+
results = commands["search_by_url"](mock_model, TEST_URLS[1], 100)
7185
assert 100 == len(results)
7286
commands["remove_images"](mock_model, urls)
7387

api/tests/test_embeddings.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -31,18 +31,26 @@ def squared_l2(v):
3131
return sum([x * x for x in v])
3232

3333

34-
def test_extract_vit_b32():
34+
def test_vit_b32_get_image_embedding():
3535
random.seed(2023)
3636
image_data = bytes([random.randint(0, 255) for _ in range(500 * 500 * 3)])
3737
image = Image.frombytes("RGB", (500, 500), image_data)
3838

3939
b32 = embeddings["vit_b32"]
40-
embedding = b32.extract(image)
40+
embedding = b32.get_image_embedding(image)
4141
assert pytest.approx(1) == squared_l2(embedding)
4242
assert 512 == len(embedding)
4343
assert pytest.approx(-1.00633253128035) == sum(embedding)
4444

4545

46+
def test_vit_b32_get_text_embedding():
47+
b32 = embeddings["vit_b32"]
48+
embedding = b32.get_text_embedding("some random text")
49+
assert pytest.approx(1) == squared_l2(embedding)
50+
assert 512 == len(embedding)
51+
assert pytest.approx(1.223632167381993) == sum(embedding)
52+
53+
4654
@functools.cache
4755
def batch_test_images():
4856
batch_test_ids = (
@@ -68,14 +76,14 @@ def batch_test_images():
6876
def test_individual_extract(benchmark):
6977
def extract_each():
7078
for image in batch_test_images():
71-
embeddings["vit_b32"].extract(image)
79+
embeddings["vit_b32"].get_image_embedding(image)
7280

7381
benchmark(extract_each)
7482

7583

7684
@pytest.mark.benchmark
7785
def test_batch_extract(benchmark):
78-
benchmark(embeddings["vit_b32"].extract, batch_test_images())
86+
benchmark(embeddings["vit_b32"].get_image_embedding, batch_test_images())
7987

8088

8189
def test_unresized_image():
@@ -84,4 +92,4 @@ def test_unresized_image():
8492
url2 = "https://artresearch-iiif.s3.eu-west-1.amazonaws.com/marburg/gm1159076.jpg"
8593

8694
images = [load_image_from_url(url) for url in [url0, url1, url2]]
87-
[embeddings["vit_b32"].extract(image) for image in images]
95+
[embeddings["vit_b32"].get_image_embedding(image) for image in images]

api/tests/test_main.py

Lines changed: 37 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,7 @@ def test_search_add_image_conflict(mock_rpc):
180180
},
181181
)
182182
assert response.status_code == 409
183+
assert response.json() == {"detail": "Image already inserted"}
183184
mock_rpc.assert_called_once_with(
184185
"insert_images",
185186
[
@@ -192,7 +193,7 @@ def test_search_add_image_conflict(mock_rpc):
192193
)
193194

194195

195-
def test_search_success(mock_rpc):
196+
def test_search_by_url_success(mock_rpc):
196197
mock_rpc.return_value = [
197198
{
198199
"url": "http://example.com/image1.jpg",
@@ -223,11 +224,11 @@ def test_search_success(mock_rpc):
223224
},
224225
]
225226
mock_rpc.assert_called_once_with(
226-
"search", ["vit_b32", "http://example.com/query.jpg", 10]
227+
"search_by_url", ["vit_b32", "http://example.com/query.jpg", 10]
227228
)
228229

229230

230-
def test_search_success_with_limit(mock_rpc):
231+
def test_search_by_url_success_with_limit(mock_rpc):
231232
mock_rpc.return_value = [
232233
{
233234
"url": "http://example.com/image1.jpg",
@@ -242,7 +243,7 @@ def test_search_success_with_limit(mock_rpc):
242243
assert response.status_code == 200
243244
assert len(response.json()) == 100
244245
mock_rpc.assert_called_once_with(
245-
"search", ["vit_b32", "http://example.com/query.jpg", 100]
246+
"search_by_url", ["vit_b32", "http://example.com/query.jpg", 100]
246247
)
247248

248249

@@ -255,20 +256,30 @@ def test_search_empty(mock_rpc):
255256
assert response.status_code == 200
256257
assert response.json() == []
257258
mock_rpc.assert_called_once_with(
258-
"search", ["vit_b32", "http://example.com/query.jpg", 10]
259+
"search_by_url", ["vit_b32", "http://example.com/query.jpg", 10]
259260
)
260261

261262

262-
def test_search_returns_422_when_invalid_url(mock_rpc):
263+
def test_search_by_url_returns_422_when_invalid_url(mock_rpc):
263264
response = client.post(
264265
"/models/vit_b32/search",
265266
json={"url": "not_a_url"},
266267
)
267268
assert response.status_code == 422
269+
assert response.json() == {
270+
"detail": [
271+
{
272+
"loc": ["body", "url"],
273+
"msg": "invalid or missing URL scheme",
274+
"type": "value_error.url.scheme",
275+
}
276+
]
277+
}
278+
268279
mock_rpc.assert_not_called()
269280

270281

271-
def test_search_text_success(mock_rpc):
282+
def test_search_by_text_success(mock_rpc):
272283
mock_rpc.return_value = [
273284
{
274285
"url": "http://example.com/image1.jpg",
@@ -298,7 +309,24 @@ def test_search_text_success(mock_rpc):
298309
"similarity": 20,
299310
},
300311
]
301-
mock_rpc.assert_called_once_with("text_search", ["vit_b32", "cute cat", 10])
312+
mock_rpc.assert_called_once_with("search_by_text", ["vit_b32", "cute cat", 10])
313+
314+
315+
def test_search_by_text_success_with_limit(mock_rpc):
316+
mock_rpc.return_value = [
317+
{
318+
"url": "http://example.com/image1.jpg",
319+
"metadata": {"tags": ["cat", "cute"]},
320+
"similarity": 10,
321+
}
322+
] * 100
323+
response = client.post(
324+
"/models/vit_b32/search",
325+
json={"text": "some text", "limit": 100},
326+
)
327+
assert response.status_code == 200
328+
assert len(response.json()) == 100
329+
mock_rpc.assert_called_once_with("search_by_text", ["vit_b32", "some text", 100])
302330

303331

304332
def test_search_returns_422_whithout_query(mock_rpc):
@@ -319,7 +347,7 @@ def test_search_returns_500_when_rpc_error(mock_rpc):
319347
)
320348
assert response.status_code == 500
321349
mock_rpc.assert_called_once_with(
322-
"search", ["vit_b32", "http://example.com/query.jpg", 10]
350+
"search_by_url", ["vit_b32", "http://example.com/query.jpg", 10]
323351
)
324352

325353

0 commit comments

Comments
 (0)