Skip to content

Commit 0d5cbb5

Browse files
authored
feat(batching): support batching for ColPali functions (#1233)
1 parent ca0efea commit 0d5cbb5

File tree

1 file changed

+29
-34
lines changed

1 file changed

+29
-34
lines changed

python/cocoindex/functions/colpali.py

Lines changed: 29 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import functools
44
from dataclasses import dataclass
5-
from typing import Any, Optional, TYPE_CHECKING, Literal
5+
from typing import Any, TYPE_CHECKING, Literal
66
import numpy as np
77

88
from .. import op
@@ -22,18 +22,11 @@ class ColPaliModelInfo:
2222
dimension: int
2323

2424

25-
@functools.lru_cache(maxsize=None)
25+
@functools.cache
2626
def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
2727
"""Load and cache ColPali model and processor with shared device setup."""
2828
try:
29-
from colpali_engine import ( # type: ignore[import-untyped]
30-
ColPali,
31-
ColPaliProcessor,
32-
ColQwen2,
33-
ColQwen2Processor,
34-
ColSmol,
35-
ColSmolProcessor,
36-
)
29+
import colpali_engine as ce # type: ignore[import-untyped]
3730
import torch
3831
except ImportError as e:
3932
raise ImportError(
@@ -42,29 +35,30 @@ def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
4235
) from e
4336

4437
device = "cuda" if torch.cuda.is_available() else "cpu"
38+
lower_model_name = model_name.lower()
4539

4640
# Determine model type from name
47-
if "colpali" in model_name.lower():
48-
model = ColPali.from_pretrained(
41+
if lower_model_name.startswith("colpali"):
42+
model = ce.ColPali.from_pretrained(
4943
model_name, torch_dtype=torch.bfloat16, device_map=device
5044
)
51-
processor = ColPaliProcessor.from_pretrained(model_name)
52-
elif "colqwen" in model_name.lower():
53-
model = ColQwen2.from_pretrained(
45+
processor = ce.ColPaliProcessor.from_pretrained(model_name)
46+
elif lower_model_name.startswith("colqwen2.5"):
47+
model = ce.ColQwen2_5.from_pretrained(
5448
model_name, torch_dtype=torch.bfloat16, device_map=device
5549
)
56-
processor = ColQwen2Processor.from_pretrained(model_name)
57-
elif "colsmol" in model_name.lower():
58-
model = ColSmol.from_pretrained(
50+
processor = ce.ColQwen2_5_Processor.from_pretrained(model_name)
51+
elif lower_model_name.startswith("colqwen"):
52+
model = ce.ColQwen2.from_pretrained(
5953
model_name, torch_dtype=torch.bfloat16, device_map=device
6054
)
61-
processor = ColSmolProcessor.from_pretrained(model_name)
55+
processor = ce.ColQwen2Processor.from_pretrained(model_name)
6256
else:
6357
# Fallback to ColPali for backwards compatibility
64-
model = ColPali.from_pretrained(
58+
model = ce.ColPali.from_pretrained(
6559
model_name, torch_dtype=torch.bfloat16, device_map=device
6660
)
67-
processor = ColPaliProcessor.from_pretrained(model_name)
61+
processor = ce.ColPaliProcessor.from_pretrained(model_name)
6862

6963
# Detect dimension
7064
dimension = _detect_colpali_dimension(model, processor, device)
@@ -130,6 +124,7 @@ class ColPaliEmbedImage(op.FunctionSpec):
130124
@op.executor_class(
131125
gpu=True,
132126
cache=True,
127+
batching=True,
133128
behavior_version=1,
134129
)
135130
class ColPaliEmbedImageExecutor:
@@ -146,7 +141,7 @@ def analyze(self) -> type:
146141
dimension = self._model_info.dimension
147142
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore
148143

149-
def __call__(self, img_bytes: bytes) -> Any:
144+
def __call__(self, img_bytes_list: list[bytes]) -> Any:
150145
try:
151146
from PIL import Image
152147
import torch
@@ -160,8 +155,11 @@ def __call__(self, img_bytes: bytes) -> Any:
160155
processor = self._model_info.processor
161156
device = self._model_info.device
162157

163-
pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
164-
inputs = processor.process_images([pil_image]).to(device)
158+
pil_images = [
159+
Image.open(io.BytesIO(img_bytes)).convert("RGB")
160+
for img_bytes in img_bytes_list
161+
]
162+
inputs = processor.process_images(pil_images).to(device)
165163
with torch.no_grad():
166164
embeddings = model(**inputs)
167165

@@ -171,10 +169,8 @@ def __call__(self, img_bytes: bytes) -> Any:
171169
f"Expected 3D tensor [batch, patches, hidden_dim], got shape {embeddings.shape}"
172170
)
173171

174-
# Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
175-
patch_embeddings = embeddings[0] # Remove batch dimension
176-
177-
return patch_embeddings.cpu().to(torch.float32).numpy()
172+
# [batch, patches, hidden_dim]: one [patches, hidden_dim] matrix per input image
173+
return embeddings.cpu().to(torch.float32).numpy()
178174

179175

180176
class ColPaliEmbedQuery(op.FunctionSpec):
@@ -207,6 +203,7 @@ class ColPaliEmbedQuery(op.FunctionSpec):
207203
gpu=True,
208204
cache=True,
209205
behavior_version=1,
206+
batching=True,
210207
)
211208
class ColPaliEmbedQueryExecutor:
212209
"""Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.)."""
@@ -222,7 +219,7 @@ def analyze(self) -> type:
222219
dimension = self._model_info.dimension
223220
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore
224221

225-
def __call__(self, query: str) -> Any:
222+
def __call__(self, queries: list[str]) -> Any:
226223
try:
227224
import torch
228225
except ImportError as e:
@@ -234,7 +231,7 @@ def __call__(self, query: str) -> Any:
234231
processor = self._model_info.processor
235232
device = self._model_info.device
236233

237-
inputs = processor.process_queries([query]).to(device)
234+
inputs = processor.process_queries(queries).to(device)
238235
with torch.no_grad():
239236
embeddings = model(**inputs)
240237

@@ -244,7 +241,5 @@ def __call__(self, query: str) -> Any:
244241
f"Expected 3D tensor [batch, tokens, hidden_dim], got shape {embeddings.shape}"
245242
)
246243

247-
# Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
248-
token_embeddings = embeddings[0] # Remove batch dimension
249-
250-
return token_embeddings.cpu().to(torch.float32).numpy()
244+
# [batch, tokens, hidden_dim]: one [tokens, hidden_dim] matrix per input query
245+
return embeddings.cpu().to(torch.float32).numpy()

0 commit comments

Comments
 (0)