Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 29 additions & 34 deletions python/cocoindex/functions/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, Literal
from typing import Any, TYPE_CHECKING, Literal
import numpy as np

from .. import op
Expand All @@ -22,18 +22,11 @@ class ColPaliModelInfo:
dimension: int


@functools.lru_cache(maxsize=None)
@functools.cache
def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
"""Load and cache ColPali model and processor with shared device setup."""
try:
from colpali_engine import ( # type: ignore[import-untyped]
ColPali,
ColPaliProcessor,
ColQwen2,
ColQwen2Processor,
ColSmol,
ColSmolProcessor,
)
import colpali_engine as ce # type: ignore[import-untyped]
import torch
except ImportError as e:
raise ImportError(
Expand All @@ -42,29 +35,30 @@ def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
) from e

device = "cuda" if torch.cuda.is_available() else "cpu"
lower_model_name = model_name.lower()

# Determine model type from name
if "colpali" in model_name.lower():
model = ColPali.from_pretrained(
if lower_model_name.startswith("colpali"):
model = ce.ColPali.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map=device
)
processor = ColPaliProcessor.from_pretrained(model_name)
elif "colqwen" in model_name.lower():
model = ColQwen2.from_pretrained(
processor = ce.ColPaliProcessor.from_pretrained(model_name)
elif lower_model_name.startswith("colqwen2.5"):
model = ce.ColQwen2_5.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map=device
)
processor = ColQwen2Processor.from_pretrained(model_name)
elif "colsmol" in model_name.lower():
model = ColSmol.from_pretrained(
processor = ce.ColQwen2_5_Processor.from_pretrained(model_name)
elif lower_model_name.startswith("colqwen"):
model = ce.ColQwen2.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map=device
)
processor = ColSmolProcessor.from_pretrained(model_name)
processor = ce.ColQwen2Processor.from_pretrained(model_name)
else:
# Fallback to ColPali for backwards compatibility
model = ColPali.from_pretrained(
model = ce.ColPali.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map=device
)
processor = ColPaliProcessor.from_pretrained(model_name)
processor = ce.ColPaliProcessor.from_pretrained(model_name)

# Detect dimension
dimension = _detect_colpali_dimension(model, processor, device)
Expand Down Expand Up @@ -130,6 +124,7 @@ class ColPaliEmbedImage(op.FunctionSpec):
@op.executor_class(
gpu=True,
cache=True,
batching=True,
behavior_version=1,
)
class ColPaliEmbedImageExecutor:
Expand All @@ -146,7 +141,7 @@ def analyze(self) -> type:
dimension = self._model_info.dimension
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore

def __call__(self, img_bytes: bytes) -> Any:
def __call__(self, img_bytes_list: list[bytes]) -> Any:
try:
from PIL import Image
import torch
Expand All @@ -160,8 +155,11 @@ def __call__(self, img_bytes: bytes) -> Any:
processor = self._model_info.processor
device = self._model_info.device

pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
inputs = processor.process_images([pil_image]).to(device)
pil_images = [
Image.open(io.BytesIO(img_bytes)).convert("RGB")
for img_bytes in img_bytes_list
]
inputs = processor.process_images(pil_images).to(device)
with torch.no_grad():
embeddings = model(**inputs)

Expand All @@ -171,10 +169,8 @@ def __call__(self, img_bytes: bytes) -> Any:
f"Expected 3D tensor [batch, patches, hidden_dim], got shape {embeddings.shape}"
)

# Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
patch_embeddings = embeddings[0] # Remove batch dimension

return patch_embeddings.cpu().to(torch.float32).numpy()
# [patches, hidden_dim]
return embeddings.cpu().to(torch.float32).numpy()


class ColPaliEmbedQuery(op.FunctionSpec):
Expand Down Expand Up @@ -207,6 +203,7 @@ class ColPaliEmbedQuery(op.FunctionSpec):
gpu=True,
cache=True,
behavior_version=1,
batching=True,
)
class ColPaliEmbedQueryExecutor:
"""Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.)."""
Expand All @@ -222,7 +219,7 @@ def analyze(self) -> type:
dimension = self._model_info.dimension
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore

def __call__(self, query: str) -> Any:
def __call__(self, queries: list[str]) -> Any:
try:
import torch
except ImportError as e:
Expand All @@ -234,7 +231,7 @@ def __call__(self, query: str) -> Any:
processor = self._model_info.processor
device = self._model_info.device

inputs = processor.process_queries([query]).to(device)
inputs = processor.process_queries(queries).to(device)
with torch.no_grad():
embeddings = model(**inputs)

Expand All @@ -244,7 +241,5 @@ def __call__(self, query: str) -> Any:
f"Expected 3D tensor [batch, tokens, hidden_dim], got shape {embeddings.shape}"
)

# Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
token_embeddings = embeddings[0] # Remove batch dimension

return token_embeddings.cpu().to(torch.float32).numpy()
# [tokens, hidden_dim]
return embeddings.cpu().to(torch.float32).numpy()
Loading