@@ -187,7 +187,7 @@ def analyze(self, _img_bytes: Any) -> type:
187187 dimension = self ._cached_model_data ["dimension" ]
188188 return Vector [Vector [np .float32 , Literal [dimension ]]] # type: ignore
189189
190- def __call__ (self , img_bytes : bytes ) -> list [ list [ float ]] :
190+ def __call__ (self , img_bytes : bytes ) -> Any :
191191 try :
192192 from PIL import Image
193193 import torch
@@ -218,12 +218,7 @@ def __call__(self, img_bytes: bytes) -> list[list[float]]:
218218 # Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
219219 patch_embeddings = embeddings [0 ] # Remove batch dimension
220220
221- # Convert to list of lists: [[patch1_embedding], [patch2_embedding], ...]
222- result = []
223- for patch in patch_embeddings :
224- result .append (patch .cpu ().numpy ().tolist ())
225-
226- return result
221+ return patch_embeddings .cpu ().numpy ()
227222
228223
229224class ColPaliEmbedQuery (op .FunctionSpec ):
@@ -264,7 +259,7 @@ def analyze(self, _query: Any) -> type:
264259 dimension = self ._cached_model_data ["dimension" ]
265260 return Vector [Vector [np .float32 , Literal [dimension ]]] # type: ignore
266261
267- def __call__ (self , query : str ) -> list [ list [ float ]] :
262+ def __call__ (self , query : str ) -> Any :
268263 try :
269264 import torch
270265 except ImportError as e :
@@ -292,9 +287,4 @@ def __call__(self, query: str) -> list[list[float]]:
292287 # Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
293288 token_embeddings = embeddings [0 ] # Remove batch dimension
294289
295- # Convert to list of lists: [[token1_embedding], [token2_embedding], ...]
296- result = []
297- for token in token_embeddings :
298- result .append (token .cpu ().numpy ().tolist ())
299-
300- return result
290+ return token_embeddings .cpu ().numpy ()
0 commit comments