Skip to content

Commit 5f10f91

Browse files
committed
clean up for Colpali functions
1 parent a4122b9 commit 5f10f91

File tree

1 file changed

+4
-14
lines changed

1 file changed

+4
-14
lines changed

python/cocoindex/functions.py

Lines changed: 4 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,7 @@ def analyze(self, _img_bytes: Any) -> type:
187187
dimension = self._cached_model_data["dimension"]
188188
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore
189189

190-
def __call__(self, img_bytes: bytes) -> list[list[float]]:
190+
def __call__(self, img_bytes: bytes) -> Any:
191191
try:
192192
from PIL import Image
193193
import torch
@@ -218,12 +218,7 @@ def __call__(self, img_bytes: bytes) -> list[list[float]]:
218218
# Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
219219
patch_embeddings = embeddings[0] # Remove batch dimension
220220

221-
# Convert to list of lists: [[patch1_embedding], [patch2_embedding], ...]
222-
result = []
223-
for patch in patch_embeddings:
224-
result.append(patch.cpu().numpy().tolist())
225-
226-
return result
221+
return patch_embeddings.cpu().numpy()
227222

228223

229224
class ColPaliEmbedQuery(op.FunctionSpec):
@@ -264,7 +259,7 @@ def analyze(self, _query: Any) -> type:
264259
dimension = self._cached_model_data["dimension"]
265260
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore
266261

267-
def __call__(self, query: str) -> list[list[float]]:
262+
def __call__(self, query: str) -> Any:
268263
try:
269264
import torch
270265
except ImportError as e:
@@ -292,9 +287,4 @@ def __call__(self, query: str) -> list[list[float]]:
292287
# Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
293288
token_embeddings = embeddings[0] # Remove batch dimension
294289

295-
# Convert to list of lists: [[token1_embedding], [token2_embedding], ...]
296-
result = []
297-
for token in token_embeddings:
298-
result.append(token.cpu().numpy().tolist())
299-
300-
return result
290+
return token_embeddings.cpu().numpy()

0 commit comments

Comments
 (0)