From cfbb819a9d60921bdde0336ae623c06c7c66dfb0 Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Fri, 31 Oct 2025 08:33:27 -0700 Subject: [PATCH] feat(batching): support batching for ColPali functions --- python/cocoindex/functions/colpali.py | 63 ++++++++++++--------------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/python/cocoindex/functions/colpali.py b/python/cocoindex/functions/colpali.py index c45e3c1e..35d04e20 100644 --- a/python/cocoindex/functions/colpali.py +++ b/python/cocoindex/functions/colpali.py @@ -2,7 +2,7 @@ import functools from dataclasses import dataclass -from typing import Any, Optional, TYPE_CHECKING, Literal +from typing import Any, TYPE_CHECKING, Literal import numpy as np from .. import op @@ -22,18 +22,11 @@ class ColPaliModelInfo: dimension: int -@functools.lru_cache(maxsize=None) +@functools.cache def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo: """Load and cache ColPali model and processor with shared device setup.""" try: - from colpali_engine import ( # type: ignore[import-untyped] - ColPali, - ColPaliProcessor, - ColQwen2, - ColQwen2Processor, - ColSmol, - ColSmolProcessor, - ) + import colpali_engine as ce # type: ignore[import-untyped] import torch except ImportError as e: raise ImportError( @@ -42,29 +35,30 @@ def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo: ) from e device = "cuda" if torch.cuda.is_available() else "cpu" + lower_model_name = model_name.lower() # Determine model type from name - if "colpali" in model_name.lower(): - model = ColPali.from_pretrained( + if lower_model_name.startswith("colpali"): + model = ce.ColPali.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map=device ) - processor = ColPaliProcessor.from_pretrained(model_name) - elif "colqwen" in model_name.lower(): - model = ColQwen2.from_pretrained( + processor = ce.ColPaliProcessor.from_pretrained(model_name) + elif lower_model_name.startswith("colqwen2.5"): + model = 
ce.ColQwen2_5.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map=device ) - processor = ColQwen2Processor.from_pretrained(model_name) - elif "colsmol" in model_name.lower(): - model = ColSmol.from_pretrained( + processor = ce.ColQwen2_5_Processor.from_pretrained(model_name) + elif lower_model_name.startswith("colqwen"): + model = ce.ColQwen2.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map=device ) - processor = ColSmolProcessor.from_pretrained(model_name) + processor = ce.ColQwen2Processor.from_pretrained(model_name) else: # Fallback to ColPali for backwards compatibility - model = ColPali.from_pretrained( + model = ce.ColPali.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map=device ) - processor = ColPaliProcessor.from_pretrained(model_name) + processor = ce.ColPaliProcessor.from_pretrained(model_name) # Detect dimension dimension = _detect_colpali_dimension(model, processor, device) @@ -130,6 +124,7 @@ class ColPaliEmbedImage(op.FunctionSpec): @op.executor_class( gpu=True, cache=True, + batching=True, behavior_version=1, ) class ColPaliEmbedImageExecutor: @@ -146,7 +141,7 @@ def analyze(self) -> type: dimension = self._model_info.dimension return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore - def __call__(self, img_bytes: bytes) -> Any: + def __call__(self, img_bytes_list: list[bytes]) -> Any: try: from PIL import Image import torch @@ -160,8 +155,11 @@ def __call__(self, img_bytes: bytes) -> Any: processor = self._model_info.processor device = self._model_info.device - pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB") - inputs = processor.process_images([pil_image]).to(device) + pil_images = [ + Image.open(io.BytesIO(img_bytes)).convert("RGB") + for img_bytes in img_bytes_list + ] + inputs = processor.process_images(pil_images).to(device) with torch.no_grad(): embeddings = model(**inputs) @@ -171,10 +169,8 @@ def __call__(self, img_bytes: bytes) -> Any: f"Expected 3D tensor 
[batch, patches, hidden_dim], got shape {embeddings.shape}" ) - # Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim] - patch_embeddings = embeddings[0] # Remove batch dimension - - return patch_embeddings.cpu().to(torch.float32).numpy() + # [batch, patches, hidden_dim]: one [patches, hidden_dim] matrix per input image + return embeddings.cpu().to(torch.float32).numpy() class ColPaliEmbedQuery(op.FunctionSpec): @@ -207,6 +203,7 @@ class ColPaliEmbedQuery(op.FunctionSpec): gpu=True, cache=True, behavior_version=1, + batching=True, ) class ColPaliEmbedQueryExecutor: """Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.).""" @@ -222,7 +219,7 @@ def analyze(self) -> type: dimension = self._model_info.dimension return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore - def __call__(self, query: str) -> Any: + def __call__(self, queries: list[str]) -> Any: try: import torch except ImportError as e: @@ -234,7 +231,7 @@ def __call__(self, query: str) -> Any: processor = self._model_info.processor device = self._model_info.device - inputs = processor.process_queries([query]).to(device) + inputs = processor.process_queries(queries).to(device) with torch.no_grad(): embeddings = model(**inputs) @@ -244,7 +241,5 @@ def __call__(self, query: str) -> Any: f"Expected 3D tensor [batch, tokens, hidden_dim], got shape {embeddings.shape}" ) - # Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim] - token_embeddings = embeddings[0] # Remove batch dimension - - return token_embeddings.cpu().to(torch.float32).numpy() + # [batch, tokens, hidden_dim]: one [tokens, hidden_dim] matrix per input query + return embeddings.cpu().to(torch.float32).numpy()