1 | 1 | import datetime |
2 | | -import functools |
3 | | -import io |
4 | 2 | import os |
5 | 3 | from contextlib import asynccontextmanager |
6 | | -from typing import Any, Literal |
| 4 | +from typing import Any |
7 | 5 |
|
8 | 6 | import cocoindex |
9 | 7 | import numpy as np |
|
13 | 11 | from fastapi.staticfiles import StaticFiles |
14 | 12 | from PIL import Image |
15 | 13 | from qdrant_client import QdrantClient |
16 | | -from colpali_engine.models import ColPali, ColPaliProcessor |
17 | 14 |
|
18 | 15 |
|
19 | 16 | # --- Config --- |
|
29 | 26 | OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/") |
30 | 27 | QDRANT_COLLECTION = "ImageSearchColpali" |
31 | 28 | COLPALI_MODEL_NAME = os.getenv("COLPALI_MODEL", "vidore/colpali-v1.2") |
32 | | -COLPALI_MODEL_DIMENSION = 1031 # Set to match ColPali's output |
| 29 | +print(f"📐 Using ColPali model {COLPALI_MODEL_NAME}") |
33 | 30 |
|
34 | | -# --- ColPali model cache and embedding functions --- |
35 | | -_colpali_model_cache = {} |
36 | 31 |
|
| 32 | +# Create ColPali embedding function using the class-based pattern |
| 33 | +colpali_embed = cocoindex.functions.ColPaliEmbedImage(model=COLPALI_MODEL_NAME) |
37 | 34 |
|
38 | | -def get_colpali_model(model: str = COLPALI_MODEL_NAME): |
39 | | - global _colpali_model_cache |
40 | | - if model not in _colpali_model_cache: |
41 | | - print(f"Loading ColPali model: {model}") |
42 | | - _colpali_model_cache[model] = { |
43 | | - "model": ColPali.from_pretrained(model), |
44 | | - "processor": ColPaliProcessor.from_pretrained(model), |
45 | | - } |
46 | | - return _colpali_model_cache[model]["model"], _colpali_model_cache[model][ |
47 | | - "processor" |
48 | | - ] |
49 | | - |
50 | | - |
51 | | -def colpali_embed_image( |
52 | | - img_bytes: bytes, model: str = COLPALI_MODEL_NAME |
53 | | -) -> list[float]: |
54 | | - from PIL import Image |
55 | | - import torch |
56 | | - import io |
57 | | - |
58 | | - colpali_model, processor = get_colpali_model(model) |
59 | | - pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB") |
60 | | - inputs = processor.process_images([pil_image]) |
61 | | - with torch.no_grad(): |
62 | | - embeddings = colpali_model(**inputs) |
63 | | - pooled_embedding = embeddings.mean(dim=-1) |
64 | | - result = pooled_embedding[0].cpu().numpy() # [1031] |
65 | | - return result.tolist() |
66 | | - |
67 | | - |
68 | | -def colpali_embed_query(query: str, model: str = COLPALI_MODEL_NAME) -> list[float]: |
69 | | - import torch |
70 | | - import numpy as np |
71 | | - |
72 | | - colpali_model, processor = get_colpali_model(model) |
73 | | - inputs = processor.process_queries([query]) |
74 | | - with torch.no_grad(): |
75 | | - embeddings = colpali_model(**inputs) |
76 | | - pooled_embedding = embeddings.mean(dim=-1) |
77 | | - query_tokens = pooled_embedding[0].cpu().numpy() # [15] |
78 | | - target_length = COLPALI_MODEL_DIMENSION |
79 | | - result = np.zeros(target_length, dtype=np.float32) |
80 | | - result[: min(len(query_tokens), target_length)] = query_tokens[:target_length] |
81 | | - return result.tolist() |
82 | | - |
83 | | - |
84 | | -# --- End ColPali embedding functions --- |
85 | 35 |
|
86 | | - |
87 | | -def embed_query(text: str) -> list[float]: |
88 | | - """ |
89 | | - Embed the caption using ColPali model. |
90 | | - """ |
91 | | - return colpali_embed_query(text, model=COLPALI_MODEL_NAME) |
92 | | - |
93 | | - |
94 | | -@cocoindex.op.function(cache=True, behavior_version=1, gpu=True) |
95 | | -def embed_image( |
96 | | - img_bytes: bytes, |
97 | | -) -> cocoindex.Vector[cocoindex.Float32, Literal[COLPALI_MODEL_DIMENSION]]: |
| 36 | +@cocoindex.transform_flow() |
| 37 | +def text_to_colpali_embedding( |
| 38 | + text: cocoindex.DataSlice[str], |
| 39 | +) -> cocoindex.DataSlice[list[list[float]]]: |
98 | 40 | """ |
99 | | - Convert image to embedding using ColPali model. |
| 41 | + Embed text using a ColPali model, returning multi-vector format. |
| 42 | + This is shared logic between indexing and querying, ensuring consistent embeddings. |
100 | 43 | """ |
101 | | - return colpali_embed_image(img_bytes, model=COLPALI_MODEL_NAME) |
| 44 | + return text.transform( |
| 45 | + cocoindex.functions.ColPaliEmbedQuery(model=COLPALI_MODEL_NAME) |
| 46 | + ) |
102 | 47 |
|
103 | 48 |
|
104 | 49 | @cocoindex.flow_def(name="ImageObjectEmbeddingColpali") |
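Both ColPaliEmbedImage and the ColPaliEmbedQuery transform above return multi-vector embeddings (one vector per image patch or query token, as list[list[float]]) rather than a single pooled vector, and retrieval scores them with late-interaction MaxSim. A minimal NumPy sketch of that scoring, with purely illustrative shapes (the per-token dimension and token counts depend on the model and input, and are not taken from this diff):

import numpy as np

def maxsim_score(query_vecs: np.ndarray, doc_vecs: np.ndarray) -> float:
    # query_vecs: (num_query_tokens, dim); doc_vecs: (num_image_patches, dim).
    # For each query token, keep its best-matching document vector, then sum
    # those maxima -- the idea behind Qdrant's MAX_SIM comparator for
    # multi-vector fields (shown here with plain dot products).
    sims = query_vecs @ doc_vecs.T          # (tokens, patches) similarity matrix
    return float(sims.max(axis=1).sum())    # best match per token, summed

# Illustrative shapes only: e.g. 15 query tokens vs. ~1030 image patches.
query = np.random.randn(15, 128).astype(np.float32)
image = np.random.randn(1030, 128).astype(np.float32)
print(maxsim_score(query, image))

This is also why the old mean-pooling and zero-padding to a fixed COLPALI_MODEL_DIMENSION could be dropped: the token-level vectors are stored as-is and compared pairwise at query time.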
@@ -131,7 +76,7 @@ def image_object_embedding_flow( |
131 | 76 | ), |
132 | 77 | image=img["content"], |
133 | 78 | ) |
134 | | - img["embedding"] = img["content"].transform(embed_image) |
| 79 | + img["embedding"] = img["content"].transform(colpali_embed) |
135 | 80 |
|
136 | 81 | collect_fields = { |
137 | 82 | "id": cocoindex.GeneratedField.UUID, |
@@ -189,24 +134,30 @@ def search( |
189 | 134 | q: str = Query(..., description="Search query"), |
190 | 135 | limit: int = Query(5, description="Number of results"), |
191 | 136 | ) -> Any: |
192 | | - # Get the embedding for the query |
193 | | - query_embedding = embed_query(q) |
| 137 | + # Get the multi-vector embedding for the query |
| 138 | + query_embedding = text_to_colpali_embedding.eval(q) |
| 139 | + print( |
| 140 | + f"🔍 Query multi-vector shape: {len(query_embedding)} tokens x {len(query_embedding[0]) if query_embedding else 0} dims" |
| 141 | + ) |
194 | 142 |
|
195 | | - # Search in Qdrant |
196 | | - search_results = app.state.qdrant_client.search( |
| 143 | + # Search in Qdrant with multi-vector MaxSim scoring using query_points API |
| 144 | + search_results = app.state.qdrant_client.query_points( |
197 | 145 | collection_name=QDRANT_COLLECTION, |
198 | | - query_vector=("embedding", query_embedding), |
| 146 | + query=query_embedding, # Multi-vector format: list[list[float]] |
| 147 | + using="embedding", # Specify the vector field name |
199 | 148 | limit=limit, |
200 | 149 | with_payload=True, |
201 | 150 | ) |
202 | 151 |
|
| 152 | + print(f"📈 Found {len(search_results.points)} results with MaxSim scoring") |
| 153 | + |
203 | 154 | return { |
204 | 155 | "results": [ |
205 | 156 | { |
206 | 157 | "filename": result.payload["filename"], |
207 | 158 | "score": result.score, |
208 | 159 | "caption": result.payload.get("caption"), |
209 | 160 | } |
210 | | - for result in search_results |
| 161 | + for result in search_results.points |
211 | 162 | ] |
212 | 163 | } |
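For query_points to apply MaxSim over the list[list[float]] query, the "embedding" vector field of the ImageSearchColpali collection has to be created with a multi-vector configuration. That setup is not part of this diff (it presumably happens wherever the collection is created, e.g. by the flow's Qdrant target), but a hedged qdrant-client sketch of such a config looks like this; the 128-dim per-token size is an assumption based on typical ColPali checkpoints:

from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # assumed local Qdrant instance
client.create_collection(
    collection_name="ImageSearchColpali",
    vectors_config={
        # Named vector field matching `using="embedding"` in query_points above.
        "embedding": models.VectorParams(
            size=128,  # assumed per-token dimension for vidore/colpali-v1.2
            distance=models.Distance.COSINE,
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            ),
        )
    },
)

Without the multivector_config, Qdrant expects a single flat vector per point, so the multi-vector query would not be scored with MaxSim.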