 
 import functools
 from dataclasses import dataclass
-from typing import Any, Optional, TYPE_CHECKING, Literal
+from typing import Any, TYPE_CHECKING, Literal
 import numpy as np
 
 from .. import op
@@ -22,18 +23,11 @@ class ColPaliModelInfo:
     dimension: int
 
 
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
     """Load and cache ColPali model and processor with shared device setup."""
     try:
-        from colpali_engine import (  # type: ignore[import-untyped]
-            ColPali,
-            ColPaliProcessor,
-            ColQwen2,
-            ColQwen2Processor,
-            ColSmol,
-            ColSmolProcessor,
-        )
+        import colpali_engine as ce  # type: ignore[import-untyped]
         import torch
     except ImportError as e:
         raise ImportError(
@@ -42,29 +36,30 @@ def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
         ) from e
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
+    lower_model_name = model_name.lower()
 
     # Determine model type from name
-    if "colpali" in model_name.lower():
-        model = ColPali.from_pretrained(
+    if lower_model_name.startswith("colpali"):
+        model = ce.ColPali.from_pretrained(
             model_name, torch_dtype=torch.bfloat16, device_map=device
         )
-        processor = ColPaliProcessor.from_pretrained(model_name)
-    elif "colqwen" in model_name.lower():
-        model = ColQwen2.from_pretrained(
+        processor = ce.ColPaliProcessor.from_pretrained(model_name)
+    elif lower_model_name.startswith("colqwen2.5"):
+        model = ce.ColQwen2_5.from_pretrained(
             model_name, torch_dtype=torch.bfloat16, device_map=device
         )
-        processor = ColQwen2Processor.from_pretrained(model_name)
-    elif "colsmol" in model_name.lower():
-        model = ColSmol.from_pretrained(
+        processor = ce.ColQwen2_5_Processor.from_pretrained(model_name)
+    elif lower_model_name.startswith("colqwen"):
+        model = ce.ColQwen2.from_pretrained(
             model_name, torch_dtype=torch.bfloat16, device_map=device
         )
-        processor = ColSmolProcessor.from_pretrained(model_name)
+        processor = ce.ColQwen2Processor.from_pretrained(model_name)
     else:
         # Fallback to ColPali for backwards compatibility
-        model = ColPali.from_pretrained(
+        model = ce.ColPali.from_pretrained(
             model_name, torch_dtype=torch.bfloat16, device_map=device
         )
-        processor = ColPaliProcessor.from_pretrained(model_name)
+        processor = ce.ColPaliProcessor.from_pretrained(model_name)
 
     # Detect dimension
     dimension = _detect_colpali_dimension(model, processor, device)
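
Note on the dispatch above: the prefix checks are order-sensitive ("colqwen2.5" has to be matched before the broader "colqwen"), and model names that previously hit the removed ColSmol branch now take the ColPali fallback. A minimal, dependency-free sketch of that routing, assuming a hypothetical helper that returns colpali_engine class names as strings instead of loading anything (the example names are illustrative):

def resolve_colvision_classes(model_name: str) -> tuple[str, str]:
    lower_model_name = model_name.lower()
    if lower_model_name.startswith("colpali"):
        return "ColPali", "ColPaliProcessor"
    # "colqwen2.5" must be tested before the broader "colqwen" prefix.
    if lower_model_name.startswith("colqwen2.5"):
        return "ColQwen2_5", "ColQwen2_5_Processor"
    if lower_model_name.startswith("colqwen"):
        return "ColQwen2", "ColQwen2Processor"
    # Everything else (including former ColSmol names) falls back to ColPali.
    return "ColPali", "ColPaliProcessor"

assert resolve_colvision_classes("colqwen2.5-v0.2") == ("ColQwen2_5", "ColQwen2_5_Processor")
assert resolve_colvision_classes("colqwen2-v1.0") == ("ColQwen2", "ColQwen2Processor")
assert resolve_colvision_classes("colsmol-500m") == ("ColPali", "ColPaliProcessor")
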
@@ -130,6 +125,7 @@ class ColPaliEmbedImage(op.FunctionSpec):
 @op.executor_class(
     gpu=True,
     cache=True,
+    batching=True,
     behavior_version=1,
 )
 class ColPaliEmbedImageExecutor:
@@ -146,7 +142,7 @@ def analyze(self) -> type:
         dimension = self._model_info.dimension
         return Vector[Vector[np.float32, Literal[dimension]]]  # type: ignore
 
-    def __call__(self, img_bytes: bytes) -> Any:
+    def __call__(self, img_bytes_list: list[bytes]) -> Any:
         try:
             from PIL import Image
             import torch
@@ -160,8 +156,11 @@ def __call__(self, img_bytes: bytes) -> Any:
         processor = self._model_info.processor
         device = self._model_info.device
 
-        pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
-        inputs = processor.process_images([pil_image]).to(device)
+        pil_images = [
+            Image.open(io.BytesIO(img_bytes)).convert("RGB")
+            for img_bytes in img_bytes_list
+        ]
+        inputs = processor.process_images(pil_images).to(device)
         with torch.no_grad():
             embeddings = model(**inputs)
 
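
The batched decode step above does not depend on the model itself; a self-contained sketch of the same pattern using two in-memory PNGs (Pillow only, no colpali_engine or torch):

import io
from PIL import Image

# Stand-ins for the incoming img_bytes_list.
img_bytes_list = []
for color in ("red", "blue"):
    buf = io.BytesIO()
    Image.new("RGB", (32, 32), color).save(buf, format="PNG")
    img_bytes_list.append(buf.getvalue())

# Same pattern as the executor: bytes -> RGB PIL images, one per input.
pil_images = [
    Image.open(io.BytesIO(img_bytes)).convert("RGB")
    for img_bytes in img_bytes_list
]
assert len(pil_images) == len(img_bytes_list)
# processor.process_images(pil_images) would then prepare the whole batch in one call.
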
@@ -171,10 +170,8 @@ def __call__(self, img_bytes: bytes) -> Any:
171170 f"Expected 3D tensor [batch, patches, hidden_dim], got shape { embeddings .shape } "
172171 )
173172
174- # Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
175- patch_embeddings = embeddings [0 ] # Remove batch dimension
176-
177- return patch_embeddings .cpu ().to (torch .float32 ).numpy ()
+        # Batched patch-level embeddings: [batch, patches, hidden_dim]
+        return embeddings.cpu().to(torch.float32).numpy()
 
 
 class ColPaliEmbedQuery(op.FunctionSpec):
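
With batching enabled, each call now returns one stacked array for the whole batch rather than a single item's matrix; row i is presumably split back out by the runtime as the [patches, hidden_dim] multi-vector for input i. A numpy-only sketch of that shape contract (dimensions invented for illustration):

import numpy as np

# Stand-in for a batched model output: 2 images, 1030 patches each, 128-dim vectors.
batched = np.zeros((2, 1030, 128), dtype=np.float32)

per_image = [batched[i] for i in range(batched.shape[0])]
assert per_image[0].shape == (1030, 128)  # one [patches, hidden_dim] matrix per image
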
@@ -207,6 +204,7 @@ class ColPaliEmbedQuery(op.FunctionSpec):
     gpu=True,
     cache=True,
     behavior_version=1,
+    batching=True,
 )
 class ColPaliEmbedQueryExecutor:
     """Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.)."""
@@ -222,7 +220,7 @@ def analyze(self) -> type:
         dimension = self._model_info.dimension
         return Vector[Vector[np.float32, Literal[dimension]]]  # type: ignore
 
-    def __call__(self, query: str) -> Any:
+    def __call__(self, queries: list[str]) -> Any:
         try:
             import torch
         except ImportError as e:
@@ -234,7 +232,7 @@ def __call__(self, query: str) -> Any:
         processor = self._model_info.processor
         device = self._model_info.device
 
-        inputs = processor.process_queries([query]).to(device)
+        inputs = processor.process_queries(queries).to(device)
         with torch.no_grad():
             embeddings = model(**inputs)
 
@@ -244,7 +242,5 @@ def __call__(self, query: str) -> Any:
244242 f"Expected 3D tensor [batch, tokens, hidden_dim], got shape { embeddings .shape } "
245243 )
246244
247- # Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
248- token_embeddings = embeddings [0 ] # Remove batch dimension
249-
250- return token_embeddings .cpu ().to (torch .float32 ).numpy ()
+        # Batched token-level embeddings: [batch, tokens, hidden_dim]
+        return embeddings.cpu().to(torch.float32).numpy()
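
Not part of this change, but for context: ColVision-style multi-vector outputs are typically compared with late interaction (MaxSim), where each query token vector is matched against its best image patch vector. A numpy sketch with random stand-in data:

import numpy as np

def maxsim_score(query_tokens: np.ndarray, image_patches: np.ndarray) -> float:
    """Late-interaction score: sum over query tokens of the best patch similarity."""
    sims = query_tokens @ image_patches.T  # [tokens, patches]
    return float(sims.max(axis=1).sum())

rng = np.random.default_rng(0)
query_tokens = rng.normal(size=(20, 128)).astype(np.float32)     # [tokens, hidden_dim]
image_patches = rng.normal(size=(1030, 128)).astype(np.float32)  # [patches, hidden_dim]
print(maxsim_score(query_tokens, image_patches))
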