
 import functools
 from dataclasses import dataclass
-from typing import Any, Optional, TYPE_CHECKING, Literal
+from typing import Any, TYPE_CHECKING, Literal
 import numpy as np

 from .. import op
@@ -22,18 +22,11 @@ class ColPaliModelInfo:
     dimension: int


-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
     """Load and cache ColPali model and processor with shared device setup."""
     try:
-        from colpali_engine import (  # type: ignore[import-untyped]
-            ColPali,
-            ColPaliProcessor,
-            ColQwen2,
-            ColQwen2Processor,
-            ColSmol,
-            ColSmolProcessor,
-        )
+        import colpali_engine as ce  # type: ignore[import-untyped]
         import torch
     except ImportError as e:
         raise ImportError(
@@ -42,29 +35,30 @@ def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
         ) from e

     device = "cuda" if torch.cuda.is_available() else "cpu"
+    lower_model_name = model_name.lower()

     # Determine model type from name
-    if "colpali" in model_name.lower():
-        model = ColPali.from_pretrained(
+    if lower_model_name.startswith("colpali"):
+        model = ce.ColPali.from_pretrained(
             model_name, torch_dtype=torch.bfloat16, device_map=device
         )
-        processor = ColPaliProcessor.from_pretrained(model_name)
-    elif "colqwen" in model_name.lower():
-        model = ColQwen2.from_pretrained(
+        processor = ce.ColPaliProcessor.from_pretrained(model_name)
+    elif lower_model_name.startswith("colqwen2.5"):
+        model = ce.ColQwen2_5.from_pretrained(
             model_name, torch_dtype=torch.bfloat16, device_map=device
         )
-        processor = ColQwen2Processor.from_pretrained(model_name)
-    elif "colsmol" in model_name.lower():
-        model = ColSmol.from_pretrained(
+        processor = ce.ColQwen2_5_Processor.from_pretrained(model_name)
+    elif lower_model_name.startswith("colqwen"):
+        model = ce.ColQwen2.from_pretrained(
             model_name, torch_dtype=torch.bfloat16, device_map=device
         )
-        processor = ColSmolProcessor.from_pretrained(model_name)
+        processor = ce.ColQwen2Processor.from_pretrained(model_name)
     else:
         # Fallback to ColPali for backwards compatibility
-        model = ColPali.from_pretrained(
+        model = ce.ColPali.from_pretrained(
             model_name, torch_dtype=torch.bfloat16, device_map=device
         )
-        processor = ColPaliProcessor.from_pretrained(model_name)
+        processor = ce.ColPaliProcessor.from_pretrained(model_name)

     # Detect dimension
     dimension = _detect_colpali_dimension(model, processor, device)
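
For context on the loader changes above: functools.cache is equivalent to lru_cache(maxsize=None), so each distinct model name is loaded once per process, and dispatch now keys off the prefix of the full lowercased model name rather than a substring match (an org-prefixed id such as "vidore/colpali-v1.2" does not start with "colpali" and takes the fallback branch). A minimal usage sketch with hypothetical model names:

# Hypothetical names, for illustration only.
info_a = _get_colpali_model_and_processor("colqwen2.5-v0.2")  # matches the ColQwen2_5 branch
info_b = _get_colpali_model_and_processor("colqwen2.5-v0.2")  # served from the cache
assert info_a is info_b
print(info_a.dimension)  # detected embedding dimension, e.g. 128
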
@@ -130,6 +124,7 @@ class ColPaliEmbedImage(op.FunctionSpec):
 @op.executor_class(
     gpu=True,
     cache=True,
+    batching=True,
     behavior_version=1,
 )
 class ColPaliEmbedImageExecutor:
@@ -146,7 +141,7 @@ def analyze(self) -> type:
         dimension = self._model_info.dimension
         return Vector[Vector[np.float32, Literal[dimension]]]  # type: ignore

-    def __call__(self, img_bytes: bytes) -> Any:
+    def __call__(self, img_bytes_list: list[bytes]) -> Any:
         try:
             from PIL import Image
             import torch
@@ -160,8 +155,11 @@ def __call__(self, img_bytes: bytes) -> Any:
         processor = self._model_info.processor
         device = self._model_info.device

-        pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
-        inputs = processor.process_images([pil_image]).to(device)
+        pil_images = [
+            Image.open(io.BytesIO(img_bytes)).convert("RGB")
+            for img_bytes in img_bytes_list
+        ]
+        inputs = processor.process_images(pil_images).to(device)
         with torch.no_grad():
             embeddings = model(**inputs)

@@ -171,10 +169,8 @@ def __call__(self, img_bytes: bytes) -> Any:
                 f"Expected 3D tensor [batch, patches, hidden_dim], got shape {embeddings.shape}"
             )

-        # Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
-        patch_embeddings = embeddings[0]  # Remove batch dimension
-
-        return patch_embeddings.cpu().to(torch.float32).numpy()
+        # [batch, patches, hidden_dim]: one [patches, hidden_dim] block per input image
+        return embeddings.cpu().to(torch.float32).numpy()


 class ColPaliEmbedQuery(op.FunctionSpec):
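
With batching enabled, the image executor above receives a list of encoded images per call and returns one multi-vector embedding per image. A standalone sketch of the same batched path, assuming ColPaliModelInfo exposes the model, processor and device the executor reads, with placeholder file and model names:

import io
import torch
from PIL import Image

info = _get_colpali_model_and_processor("vidore/colpali-v1.2")  # placeholder name
img_bytes_list = [open(p, "rb").read() for p in ("page_1.png", "page_2.png")]
images = [Image.open(io.BytesIO(b)).convert("RGB") for b in img_bytes_list]
inputs = info.processor.process_images(images).to(info.device)
with torch.no_grad():
    out = info.model(**inputs)              # [batch, patches, hidden_dim]
arr = out.cpu().to(torch.float32).numpy()
print(arr.shape, arr.dtype)                 # one [patches, hidden_dim] block per image
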
@@ -207,6 +203,7 @@ class ColPaliEmbedQuery(op.FunctionSpec):
     gpu=True,
     cache=True,
     behavior_version=1,
+    batching=True,
 )
 class ColPaliEmbedQueryExecutor:
     """Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.)."""
@@ -222,7 +219,7 @@ def analyze(self) -> type:
         dimension = self._model_info.dimension
         return Vector[Vector[np.float32, Literal[dimension]]]  # type: ignore

-    def __call__(self, query: str) -> Any:
+    def __call__(self, queries: list[str]) -> Any:
         try:
             import torch
         except ImportError as e:
@@ -234,7 +231,7 @@ def __call__(self, query: str) -> Any:
         processor = self._model_info.processor
         device = self._model_info.device

-        inputs = processor.process_queries([query]).to(device)
+        inputs = processor.process_queries(queries).to(device)
         with torch.no_grad():
             embeddings = model(**inputs)

@@ -244,7 +241,5 @@ def __call__(self, query: str) -> Any:
                 f"Expected 3D tensor [batch, tokens, hidden_dim], got shape {embeddings.shape}"
             )

-        # Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
-        token_embeddings = embeddings[0]  # Remove batch dimension
-
-        return token_embeddings.cpu().to(torch.float32).numpy()
+        # [batch, tokens, hidden_dim]: one [tokens, hidden_dim] block per query
+        return embeddings.cpu().to(torch.float32).numpy()
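
The query executor mirrors the same batched shape on the text side. A rough sketch with placeholder queries, using a plain NumPy MaxSim scorer in place of whatever scoring the caller applies, and random image embeddings just to keep the example self-contained:

import numpy as np
import torch

info = _get_colpali_model_and_processor("vidore/colpali-v1.2")  # placeholder name
queries = ["how is batching configured?", "total revenue in 2023"]
inputs = info.processor.process_queries(queries).to(info.device)
with torch.no_grad():
    q = info.model(**inputs).cpu().to(torch.float32).numpy()    # [n_queries, tokens, dim]

# Late interaction (MaxSim): best-matching patch per query token, summed per image.
imgs = np.random.rand(3, 1024, q.shape[-1]).astype(np.float32)  # stand-in for real patch embeddings
sims = np.einsum("qtd,ipd->qitp", q, imgs)                      # [queries, images, tokens, patches]
scores = sims.max(axis=-1).sum(axis=-1)                         # [queries, images]; higher is better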