Skip to content

Commit dc496d4

Browse files
authored
Merge pull request #4 from lklic/bulk_add
Add batch indexing API
2 parents 56df920 + 529cfa7 commit dc496d4

File tree

6 files changed

+91
-28
lines changed

6 files changed

+91
-28
lines changed

api/commands.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def format_url_list(urls):
1111

1212

1313
def insert_images(
14-
model_name, urls, metadatas, image_embeddings=None, replace_existing=True
14+
model_name, urls, metadatas, image_embeddings=None, replace_existing=True, fail_on_error=True
1515
):
1616
existing_urls = []
1717
if not replace_existing:
@@ -24,30 +24,51 @@ def insert_images(
2424
]
2525

2626
new_urls = [url for url in urls if url not in existing_urls]
27-
image_embeddings = (
28-
[
29-
embeddings[model_name].get_image_embedding(load_image_from_url(url))
30-
for url in new_urls
31-
]
32-
if image_embeddings is None
33-
else [
34-
embedding
35-
for url, embedding in zip(urls, image_embeddings)
36-
if url not in existing_urls
37-
]
38-
)
39-
metadatas = [
40-
json.dumps(metadata)
41-
for url, metadata in zip(urls, metadatas)
42-
if url not in existing_urls
43-
]
44-
45-
if len(new_urls) > 0:
46-
collections[model_name].insert([new_urls, image_embeddings, metadatas])
27+
28+
# Handle individual image failures
29+
successful_urls = []
30+
failed_urls = []
31+
computed_embeddings = []
32+
processed_metadatas = []
33+
34+
if image_embeddings is None:
35+
# Process each image individually
36+
for i, url in enumerate(new_urls):
37+
try:
38+
embedding = embeddings[model_name].get_image_embedding(load_image_from_url(url))
39+
computed_embeddings.append(embedding)
40+
processed_metadatas.append(json.dumps(metadatas[urls.index(url)]))
41+
successful_urls.append(url)
42+
except Exception as e:
43+
if fail_on_error:
44+
# Original behavior: propagate the exception
45+
raise
46+
else:
47+
# New behavior: collect the error and continue
48+
failed_urls.append({"url": url, "error": str(e)})
49+
else:
50+
# Use provided embeddings
51+
for url, embedding in zip(urls, image_embeddings):
52+
if url not in existing_urls:
53+
try:
54+
computed_embeddings.append(embedding)
55+
processed_metadatas.append(json.dumps(metadatas[urls.index(url)]))
56+
successful_urls.append(url)
57+
except Exception as e:
58+
if fail_on_error:
59+
# Original behavior: propagate the exception
60+
raise
61+
else:
62+
# New behavior: collect the error and continue
63+
failed_urls.append({"url": url, "error": str(e)})
64+
65+
if len(successful_urls) > 0:
66+
collections[model_name].insert([successful_urls, computed_embeddings, processed_metadatas])
4767

4868
return {
49-
"added": new_urls,
69+
"added": successful_urls,
5070
"found": existing_urls,
71+
"failed": failed_urls,
5172
}
5273

5374

api/embeddings.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def get_text_embedding(self, text):
4141

4242
@torch.no_grad()
4343
def get_image_embedding(self, images):
44-
inputs = self.processor(images=images, return_tensors="pt")
44+
# Ensure image is a list since processor expects a batch
45+
if not isinstance(images, list):
46+
images= [images]
47+
inputs = self.processor(images=images, return_tensors="pt", padding=True)
4548
inputs = inputs.to(self.device)
4649
image_embedding = self.model.get_image_features(**inputs)
4750
image_embedding /= image_embedding.norm(dim=-1, keepdim=True)

api/main.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,38 @@ async def restore(model_name: ModelName, images: list[DatabaseEntry]):
232232
)
233233

234234

235+
class BulkInsertResult(BaseModel):
    """Response body for the bulk-add endpoint (/models/{model_name}/add_bulk).

    Mirrors the dict returned by the worker's insert_images command.
    """

    # URLs that were newly embedded and inserted into the index.
    added: list[ImageUrl]
    # URLs that were already present in the index before this call.
    found: list[ImageUrl]
    # One entry per image that could not be processed; each entry has the
    # shape {"url": <url>, "error": <message>} as built by insert_images.
    failed: list[dict]
239+
240+
241+
@app.post(
    "/models/{model_name}/add_bulk",
    status_code=status.HTTP_200_OK,
    tags=["model"],
    summary="""
    Adds multiple images to the index at once.
    Existing urls will have their metadata replaced with the provided ones.
    Returns lists of successfully added URLs, existing URLs, and failed URLs with error messages.
    """.strip(),
    response_model=BulkInsertResult
)
async def add_bulk(model_name: ModelName, images: list[ImageAndMetada]):
    """Batch-index a list of images via the RPC worker.

    Delegates to the worker-side insert_images command with fail_on_error
    disabled, so one bad image does not abort the whole batch; failures are
    reported per-URL in the "failed" field of the response instead.

    NOTE(review): the RPC call is synchronous under the hood — for large
    batches the worker may need longer than the client timeout; confirm
    against rpc_client's process_data_events time limit.
    """
    result = try_rpc(
        "insert_images",
        [
            model_name.value,
            [image.url for image in images],
            # Metadata is length-checked here so oversized payloads fail fast
            # at the API boundary rather than inside the worker.
            [check_json_string_length(image.metadata) for image in images],
            None,  # No pre-computed embeddings
            True,  # Replace existing
            False,  # Don't fail on error, continue processing
        ],
    )
    return result
265+
266+
235267
@app.post(
236268
"/models/{model_name}/search",
237269
tags=["model"],

api/requirements-worker.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ pymilvus==2.2.5
77
torch==2.0.0
88
transformers==4.28.1
99
Pillow==9.5.0
10+
numpy<2.0

api/rpc_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __call__(self, command, args):
4848
),
4949
body=json.dumps([command, args]),
5050
)
51-
self.connection.process_data_events(time_limit=60) # seconds
51+
self.connection.process_data_events(time_limit=200) # seconds
5252
if self.response_data is None:
5353
raise RuntimeError("No response (timeout?)")
5454
response = json.loads(self.response_data)

docker/docker-compose.dev.yml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,19 @@ services:
6969
depends_on:
7070
rabbitmq:
7171
condition: service_healthy
72-
worker:
72+
worker_1:
7373
condition: service_healthy
74+
worker_2:
75+
condition: service_healthy
76+
7477
environment:
75-
WEB_CONCURRENCY: 16
78+
WEB_CONCURRENCY: 3
7679
RABBITMQ_URL: amqp://guest:guest@rabbitmq:5672
7780
volumes:
7881
- ../api:/app
7982
ports:
8083
- 4213:4213
81-
command: "uvicorn main:app --reload --host 0.0.0.0 --port 4213"
84+
command: "uvicorn main:app --workers 5 --host 0.0.0.0 --port 4213"
8285
logging:
8386
options:
8487
max-size: "10M"
@@ -103,7 +106,7 @@ services:
103106
max-size: "10M"
104107
max-file: "10"
105108

106-
worker:
109+
worker_1: &worker_1
107110
build:
108111
context: ../api
109112
dockerfile: ../docker/Dockerfile
@@ -122,6 +125,9 @@ services:
122125
max-size: "10M"
123126
max-file: "10"
124127

128+
worker_2:
129+
<<: *worker_1
130+
125131
dev:
126132
build:
127133
context: ../api

0 commit comments

Comments
 (0)