11"""All builtin functions."""
22
33import dataclasses
4+ import functools
45from typing import Annotated , Any , Literal
56
67import numpy as np
@@ -101,69 +102,51 @@ def __call__(self, text: str) -> NDArray[np.float32]:
        return result


-# Global ColPali model cache to avoid reloading models
-_COLPALI_MODEL_CACHE = {}
-
-
-class _ColPaliModelManager:
-    """Shared model manager for ColPali models to avoid duplicate loading."""
-
-    @staticmethod
-    def get_model_and_processor(model_name: str) -> dict[str, Any]:
-        """Get or load ColPali model and processor, with caching."""
-        if model_name not in _COLPALI_MODEL_CACHE:
-            try:
-                from colpali_engine.models import ColPali, ColPaliProcessor  # type: ignore[import-untyped]
-            except ImportError as e:
-                raise ImportError(
-                    "ColPali is not available. Make sure cocoindex is installed with ColPali support."
-                ) from e
-
-            model = ColPali.from_pretrained(model_name)
-            processor = ColPaliProcessor.from_pretrained(model_name)
-
-            # Get dimension from FastEmbed API
-            dimension = _ColPaliModelManager._detect_dimension()
-
-            _COLPALI_MODEL_CACHE[model_name] = {
-                "model": model,
-                "processor": processor,
-                "dimension": dimension,
-            }
-
-        return _COLPALI_MODEL_CACHE[model_name]
-
-    @staticmethod
-    def _detect_dimension() -> int:
-        """Detect ColPali embedding dimension using FastEmbed API."""
-        try:
-            from fastembed import LateInteractionMultimodalEmbedding
-
-            # Use the standard FastEmbed ColPali model for dimension detection
-            standard_colpali_model = "Qdrant/colpali-v1.3-fp16"
-
-            supported_models = (
-                LateInteractionMultimodalEmbedding.list_supported_models()
-            )
-            for supported_model in supported_models:
-                if supported_model["model"] == standard_colpali_model:
-                    dim = supported_model["dim"]
-                    if isinstance(dim, int):
-                        return dim
-                    else:
-                        raise ValueError(
-                            f"Expected integer dimension, got {type(dim)}: {dim}"
-                        )
-
-            raise ValueError(
-                f"Could not find dimension for ColPali model in FastEmbed supported models"
-            )
-
-        except ImportError:
-            raise ImportError(
-                "FastEmbed is required for ColPali dimension detection. "
-                "Install it with: pip install fastembed"
-            )
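+# functools.cache memoizes on model_name, so each ColPali model/processor pair is
+# loaded at most once per process and shared by all executors that request it.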
+@functools.cache
+def _get_colpali_model_and_processor(model_name: str) -> dict[str, Any]:
+    """Get or load ColPali model and processor, with caching."""
+    try:
+        from colpali_engine.models import ColPali, ColPaliProcessor  # type: ignore[import-untyped]
+    except ImportError as e:
+        raise ImportError(
+            "ColPali is not available. Make sure cocoindex is installed with ColPali support."
+        ) from e
+
+    model = ColPali.from_pretrained(model_name)
+    processor = ColPaliProcessor.from_pretrained(model_name)
+
+    # Get dimension from the actual model
+    dimension = _detect_colpali_dimension(model, processor)
+
+    return {
+        "model": model,
+        "processor": processor,
+        "dimension": dimension,
+    }
+
+
+def _detect_colpali_dimension(model: Any, processor: Any) -> int:
+    """Detect ColPali embedding dimension from the actual model config."""
+    # Try to access embedding dimension
+    if hasattr(model.config, "embedding_dim"):
+        dim = model.config.embedding_dim
+    else:
+        # Fallback: infer from output shape with dummy data
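+        # Probing with a dummy image is a one-time cost, paid when the model is
+        # first loaded into the cache.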
+        from PIL import Image
+        import numpy as np
+        import torch
+
+        dummy_img = Image.fromarray(np.zeros((224, 224, 3), np.uint8))
+        # Use the processor to process the dummy image
+        processed = processor.process_images([dummy_img])
+        with torch.no_grad():
+            output = model(**processed)
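+        # The output is a multi-vector embedding of shape (batch, tokens, dim);
+        # the last axis is the per-token embedding dimension.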
+        dim = int(output.shape[-1])
+        if isinstance(dim, int):
+            return dim
+        else:
+            raise ValueError(f"Expected integer dimension, got {type(dim)}: {dim}")
+    return dim


class ColPaliEmbedImage(op.FunctionSpec):
@@ -198,9 +181,7 @@ class ColPaliEmbedImageExecutor:
    def analyze(self, _img_bytes: Any) -> type:
        # Get shared model and dimension
        if self._cached_model_data is None:
-            self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
-                self.spec.model
-            )
+            self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)

        # Return multi-vector type: Variable patches x Fixed hidden dimension
        dimension = self._cached_model_data["dimension"]
@@ -218,9 +199,7 @@ def __call__(self, img_bytes: bytes) -> list[list[float]]:

        # Get shared model and processor
        if self._cached_model_data is None:
-            self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
-                self.spec.model
-            )
+            self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)

        model = self._cached_model_data["model"]
        processor = self._cached_model_data["processor"]
@@ -279,9 +258,7 @@ class ColPaliEmbedQueryExecutor:
    def analyze(self, _query: Any) -> type:
        # Get shared model and dimension
        if self._cached_model_data is None:
-            self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
-                self.spec.model
-            )
+            self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)

        # Return multi-vector type: Variable tokens x Fixed hidden dimension
        dimension = self._cached_model_data["dimension"]
@@ -297,9 +274,7 @@ def __call__(self, query: str) -> list[list[float]]:

        # Get shared model and processor
        if self._cached_model_data is None:
-            self._cached_model_data = _ColPaliModelManager.get_model_and_processor(
-                self.spec.model
-            )
+            self._cached_model_data = _get_colpali_model_and_processor(self.spec.model)

        model = self._cached_model_data["model"]
        processor = self._cached_model_data["processor"]