Skip to content

Commit 20b0b67

Browse files
committed
feat(batching): support batching for ColPali functions
1 parent ca0efea commit 20b0b67

File tree

1 file changed

+30
-34
lines changed

1 file changed

+30
-34
lines changed

python/cocoindex/functions/colpali.py

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import functools
44
from dataclasses import dataclass
5-
from typing import Any, Optional, TYPE_CHECKING, Literal
5+
from typing import Any, TYPE_CHECKING, Literal
6+
from typing import Any, TYPE_CHECKING, Literal
67
import numpy as np
78

89
from .. import op
@@ -22,18 +23,11 @@ class ColPaliModelInfo:
2223
dimension: int
2324

2425

25-
@functools.lru_cache(maxsize=None)
26+
@functools.cache
2627
def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
2728
"""Load and cache ColPali model and processor with shared device setup."""
2829
try:
29-
from colpali_engine import ( # type: ignore[import-untyped]
30-
ColPali,
31-
ColPaliProcessor,
32-
ColQwen2,
33-
ColQwen2Processor,
34-
ColSmol,
35-
ColSmolProcessor,
36-
)
30+
import colpali_engine as ce # type: ignore[import-untyped]
3731
import torch
3832
except ImportError as e:
3933
raise ImportError(
@@ -42,29 +36,30 @@ def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
4236
) from e
4337

4438
device = "cuda" if torch.cuda.is_available() else "cpu"
39+
lower_model_name = model_name.lower()
4540

4641
# Determine model type from name
47-
if "colpali" in model_name.lower():
48-
model = ColPali.from_pretrained(
42+
if lower_model_name.startswith("colpali"):
43+
model = ce.ColPali.from_pretrained(
4944
model_name, torch_dtype=torch.bfloat16, device_map=device
5045
)
51-
processor = ColPaliProcessor.from_pretrained(model_name)
52-
elif "colqwen" in model_name.lower():
53-
model = ColQwen2.from_pretrained(
46+
processor = ce.ColPaliProcessor.from_pretrained(model_name)
47+
elif lower_model_name.startswith("colqwen2.5"):
48+
model = ce.ColQwen2_5.from_pretrained(
5449
model_name, torch_dtype=torch.bfloat16, device_map=device
5550
)
56-
processor = ColQwen2Processor.from_pretrained(model_name)
57-
elif "colsmol" in model_name.lower():
58-
model = ColSmol.from_pretrained(
51+
processor = ce.ColQwen2_5_Processor.from_pretrained(model_name)
52+
elif lower_model_name.startswith("colqwen"):
53+
model = ce.ColQwen2.from_pretrained(
5954
model_name, torch_dtype=torch.bfloat16, device_map=device
6055
)
61-
processor = ColSmolProcessor.from_pretrained(model_name)
56+
processor = ce.ColQwen2Processor.from_pretrained(model_name)
6257
else:
6358
# Fallback to ColPali for backwards compatibility
64-
model = ColPali.from_pretrained(
59+
model = ce.ColPali.from_pretrained(
6560
model_name, torch_dtype=torch.bfloat16, device_map=device
6661
)
67-
processor = ColPaliProcessor.from_pretrained(model_name)
62+
processor = ce.ColPaliProcessor.from_pretrained(model_name)
6863

6964
# Detect dimension
7065
dimension = _detect_colpali_dimension(model, processor, device)
@@ -130,6 +125,7 @@ class ColPaliEmbedImage(op.FunctionSpec):
130125
@op.executor_class(
131126
gpu=True,
132127
cache=True,
128+
batching=True,
133129
behavior_version=1,
134130
)
135131
class ColPaliEmbedImageExecutor:
@@ -146,7 +142,7 @@ def analyze(self) -> type:
146142
dimension = self._model_info.dimension
147143
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore
148144

149-
def __call__(self, img_bytes: bytes) -> Any:
145+
def __call__(self, img_bytes_list: list[bytes]) -> Any:
150146
try:
151147
from PIL import Image
152148
import torch
@@ -160,8 +156,11 @@ def __call__(self, img_bytes: bytes) -> Any:
160156
processor = self._model_info.processor
161157
device = self._model_info.device
162158

163-
pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
164-
inputs = processor.process_images([pil_image]).to(device)
159+
pil_images = [
160+
Image.open(io.BytesIO(img_bytes)).convert("RGB")
161+
for img_bytes in img_bytes_list
162+
]
163+
inputs = processor.process_images(pil_images).to(device)
165164
with torch.no_grad():
166165
embeddings = model(**inputs)
167166

@@ -171,10 +170,8 @@ def __call__(self, img_bytes: bytes) -> Any:
171170
f"Expected 3D tensor [batch, patches, hidden_dim], got shape {embeddings.shape}"
172171
)
173172

174-
# Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
175-
patch_embeddings = embeddings[0] # Remove batch dimension
176-
177-
return patch_embeddings.cpu().to(torch.float32).numpy()
173+
# [patches, hidden_dim]
174+
return embeddings.cpu().to(torch.float32).numpy()
178175

179176

180177
class ColPaliEmbedQuery(op.FunctionSpec):
@@ -207,6 +204,7 @@ class ColPaliEmbedQuery(op.FunctionSpec):
207204
gpu=True,
208205
cache=True,
209206
behavior_version=1,
207+
batching=True,
210208
)
211209
class ColPaliEmbedQueryExecutor:
212210
"""Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.)."""
@@ -222,7 +220,7 @@ def analyze(self) -> type:
222220
dimension = self._model_info.dimension
223221
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore
224222

225-
def __call__(self, query: str) -> Any:
223+
def __call__(self, queries: list[str]) -> Any:
226224
try:
227225
import torch
228226
except ImportError as e:
@@ -234,7 +232,7 @@ def __call__(self, query: str) -> Any:
234232
processor = self._model_info.processor
235233
device = self._model_info.device
236234

237-
inputs = processor.process_queries([query]).to(device)
235+
inputs = processor.process_queries(queries).to(device)
238236
with torch.no_grad():
239237
embeddings = model(**inputs)
240238

@@ -244,7 +242,5 @@ def __call__(self, query: str) -> Any:
244242
f"Expected 3D tensor [batch, tokens, hidden_dim], got shape {embeddings.shape}"
245243
)
246244

247-
# Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
248-
token_embeddings = embeddings[0] # Remove batch dimension
249-
250-
return token_embeddings.cpu().to(torch.float32).numpy()
245+
# [tokens, hidden_dim]
246+
return embeddings.cpu().to(torch.float32).numpy()

0 commit comments

Comments
 (0)