6 | 6 | from typing import Any, Literal |
7 | 7 |
8 | 8 | import cocoindex |
9 | | -import torch |
| 9 | +import numpy as np |
10 | 10 | from dotenv import load_dotenv |
11 | | -from fastapi import FastAPI, Query |
| 11 | +from fastapi import FastAPI, Query, HTTPException |
12 | 12 | from fastapi.middleware.cors import CORSMiddleware |
13 | 13 | from fastapi.staticfiles import StaticFiles |
14 | 14 | from PIL import Image |
15 | 15 | from qdrant_client import QdrantClient |
16 | | -from transformers import CLIPModel, CLIPProcessor |
| 16 | +from colpali_engine.models import ColPali, ColPaliProcessor |
| 17 | + |
| 18 | + |
| 19 | +# --- Config --- |
| 20 | + |
| 21 | +# Default: talk to Qdrant over gRPC (port 6334) |
| 22 | +QDRANT_URL = os.getenv("QDRANT_URL", "localhost:6334") |
| 23 | +PREFER_GRPC = os.getenv("QDRANT_PREFER_GRPC", "true").lower() == "true" |
| 24 | + |
| 25 | +# Alternative: talk to Qdrant over HTTP (port 6333) |
| 26 | +# QDRANT_URL = os.getenv("QDRANT_URL", "localhost:6333") |
| 27 | +# PREFER_GRPC = os.getenv("QDRANT_PREFER_GRPC", "false").lower() == "true" |
17 | 28 |
18 | 29 | OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/") |
19 | | -QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6334/") |
20 | | -QDRANT_COLLECTION = "ImageSearch" |
21 | | -CLIP_MODEL_NAME = "openai/clip-vit-large-patch14" |
22 | | -CLIP_MODEL_DIMENSION = 768 |
| 30 | +QDRANT_COLLECTION = "ImageSearchColpali" |
| 31 | +COLPALI_MODEL_NAME = os.getenv("COLPALI_MODEL", "vidore/colpali-v1.2") |
| 32 | +COLPALI_MODEL_DIMENSION = 1031  # Length of the pooled vector: one value per ColPali image-token position |
| 33 | + |
| 34 | +# --- ColPali model cache and embedding functions --- |
| 35 | +_colpali_model_cache = {} |
| 36 | + |
| 37 | + |
| 38 | +def get_colpali_model(model: str = COLPALI_MODEL_NAME): |
| 39 | + global _colpali_model_cache |
| 40 | + if model not in _colpali_model_cache: |
| 41 | + print(f"Loading ColPali model: {model}") |
| 42 | + _colpali_model_cache[model] = { |
| 43 | + "model": ColPali.from_pretrained(model), |
| 44 | + "processor": ColPaliProcessor.from_pretrained(model), |
| 45 | + } |
| 46 | + return _colpali_model_cache[model]["model"], _colpali_model_cache[model][ |
| 47 | + "processor" |
| 48 | + ] |
| 49 | + |
23 | 50 |
| 51 | +def colpali_embed_image( |
| 52 | + img_bytes: bytes, model: str = COLPALI_MODEL_NAME |
| 53 | +) -> list[float]: |
| 54 | + from PIL import Image |
| 55 | + import torch |
| 56 | + import io |
24 | 57 |
25 | | -@functools.cache |
26 | | -def get_clip_model() -> tuple[CLIPModel, CLIPProcessor]: |
27 | | - model = CLIPModel.from_pretrained(CLIP_MODEL_NAME) |
28 | | - processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME) |
29 | | - return model, processor |
| 58 | + colpali_model, processor = get_colpali_model(model) |
| 59 | + pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB") |
| 60 | + inputs = processor.process_images([pil_image]) |
| 61 | + with torch.no_grad(): |
| 62 | + embeddings = colpali_model(**inputs) |
| 63 | + pooled_embedding = embeddings.mean(dim=-1) |
| 64 | + result = pooled_embedding[0].cpu().numpy() # [1031] |
| 65 | + return result.tolist() |
| 66 | + |
| 67 | + |
| 68 | +def colpali_embed_query(query: str, model: str = COLPALI_MODEL_NAME) -> list[float]: |
| 69 | + import torch |
| 70 | + import numpy as np |
| 71 | + |
| 72 | + colpali_model, processor = get_colpali_model(model) |
| 73 | + inputs = processor.process_queries([query]) |
| 74 | + with torch.no_grad(): |
| 75 | + embeddings = colpali_model(**inputs) |
| 76 | + pooled_embedding = embeddings.mean(dim=-1) |
| 77 | + query_tokens = pooled_embedding[0].cpu().numpy() # [15] |
| 78 | + target_length = COLPALI_MODEL_DIMENSION |
| 79 | + result = np.zeros(target_length, dtype=np.float32) |
| 80 | + result[: min(len(query_tokens), target_length)] = query_tokens[:target_length] |
| 81 | + return result.tolist() |
| 82 | + |
| 83 | + |
| 84 | +# --- End ColPali embedding functions --- |
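A note on the pooling scheme above: ColPali natively emits one 128-dimensional vector per token (a multi-vector, late-interaction representation), and `embeddings.mean(dim=-1)` collapses that 128-dim axis, leaving a single scalar per token position. An image yields about 1031 positions, while a short text query yields only a handful, which is why `colpali_embed_query` zero-pads to `COLPALI_MODEL_DIMENSION`. This trades ColPali's late-interaction scoring for one fixed-size vector per image that fits a single-vector Qdrant schema. A minimal sanity-check sketch, assuming the module above is saved as `main.py`, the model weights can be downloaded, and the image path is a placeholder:

```python
# Sketch: confirm both embedding helpers return vectors of the collection's size.
# "img/example.jpg" is a placeholder; the model runs on CPU here and may be slow.
from main import colpali_embed_image, colpali_embed_query, COLPALI_MODEL_DIMENSION

with open("img/example.jpg", "rb") as f:
    image_vec = colpali_embed_image(f.read())
query_vec = colpali_embed_query("a dog playing in the park")

assert len(image_vec) == COLPALI_MODEL_DIMENSION
assert len(query_vec) == COLPALI_MODEL_DIMENSION  # shorter queries are zero-padded
```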
30 | 85 |
31 | 86 |
32 | 87 | def embed_query(text: str) -> list[float]: |
33 | 88 | """ |
34 | | - Embed the caption using CLIP model. |
| 89 | +    Embed text (e.g. the search query) using the ColPali model. |
35 | 90 | """ |
36 | | - model, processor = get_clip_model() |
37 | | - inputs = processor(text=[text], return_tensors="pt", padding=True) |
38 | | - with torch.no_grad(): |
39 | | - features = model.get_text_features(**inputs) |
40 | | - return features[0].tolist() |
| 91 | + return colpali_embed_query(text, model=COLPALI_MODEL_NAME) |
41 | 92 |
42 | 93 |
43 | 94 | @cocoindex.op.function(cache=True, behavior_version=1, gpu=True) |
44 | 95 | def embed_image( |
45 | 96 | img_bytes: bytes, |
46 | | -) -> cocoindex.Vector[cocoindex.Float32, Literal[CLIP_MODEL_DIMENSION]]: |
| 97 | +) -> cocoindex.Vector[cocoindex.Float32, Literal[COLPALI_MODEL_DIMENSION]]: |
47 | 98 | """ |
48 | | - Convert image to embedding using CLIP model. |
| 99 | +    Convert an image to an embedding using the ColPali model. |
49 | 100 | """ |
50 | | - model, processor = get_clip_model() |
51 | | - image = Image.open(io.BytesIO(img_bytes)).convert("RGB") |
52 | | - inputs = processor(images=image, return_tensors="pt") |
53 | | - with torch.no_grad(): |
54 | | - features = model.get_image_features(**inputs) |
55 | | - return features[0].tolist() |
| 101 | + return colpali_embed_image(img_bytes, model=COLPALI_MODEL_NAME) |
56 | 102 |
57 | 103 |
58 | | -# CocoIndex flow: Ingest images, extract captions, embed, export to Qdrant |
59 | | -@cocoindex.flow_def(name="ImageObjectEmbedding") |
| 104 | +@cocoindex.flow_def(name="ImageObjectEmbeddingColpali") |
60 | 105 | def image_object_embedding_flow( |
61 | 106 | flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope |
62 | 107 | ) -> None: |
63 | 108 | data_scope["images"] = flow_builder.add_source( |
64 | 109 | cocoindex.sources.LocalFile( |
65 | 110 | path="img", included_patterns=["*.jpg", "*.jpeg", "*.png"], binary=True |
66 | 111 | ), |
67 | | - refresh_interval=datetime.timedelta( |
68 | | - minutes=1 |
69 | | - ), # Poll for changes every 1 minute |
| 112 | + refresh_interval=datetime.timedelta(minutes=1), |
70 | 113 | ) |
71 | 114 | img_embeddings = data_scope.add_collector() |
72 | 115 | with data_scope["images"].row() as img: |
@@ -117,7 +160,7 @@ async def lifespan(app: FastAPI) -> None: |
117 | 160 | cocoindex.init() |
118 | 161 | image_object_embedding_flow.setup(report_to_stdout=True) |
119 | 162 |
120 | | - app.state.qdrant_client = QdrantClient(url=QDRANT_URL, prefer_grpc=True) |
| 163 | + app.state.qdrant_client = QdrantClient(url=QDRANT_URL, prefer_grpc=PREFER_GRPC) |
121 | 164 |
122 | 165 | # Start updater |
123 | 166 | app.state.live_updater = cocoindex.FlowLiveUpdater(image_object_embedding_flow) |
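The client above talks gRPC (port 6334) or HTTP (port 6333) depending on the `QDRANT_URL` / `QDRANT_PREFER_GRPC` settings from the config block, and `setup()` is what normally provisions the `ImageSearchColpali` collection. For checking the target outside the app, here is a hedged sketch; the vector layout and distance metric are assumptions, since the flow's export settings are not shown in this diff:

```python
# Sketch: check connectivity and inspect (or pre-create) the collection the flow writes to.
# Normally image_object_embedding_flow.setup() provisions it; the vector config below is an assumption.
import os
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(
    url=os.getenv("QDRANT_URL", "localhost:6334"),
    prefer_grpc=os.getenv("QDRANT_PREFER_GRPC", "true").lower() == "true",
)
if not client.collection_exists("ImageSearchColpali"):
    client.create_collection(
        collection_name="ImageSearchColpali",
        vectors_config=VectorParams(size=1031, distance=Distance.COSINE),
    )
print(client.get_collection("ImageSearchColpali").status)
```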
@@ -162,9 +205,7 @@ def search( |
162 | 205 | { |
163 | 206 | "filename": result.payload["filename"], |
164 | 207 | "score": result.score, |
165 | | - "caption": result.payload.get( |
166 | | - "caption" |
167 | | - ), # Include caption if available |
| 208 | + "caption": result.payload.get("caption"), |
168 | 209 | } |
169 | 210 | for result in search_results |
170 | 211 | ] |
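With the app running (for example via `uvicorn main:app`, assuming the module is `main.py` and the FastAPI instance is named `app`), the search route can be exercised directly. The path, parameter names, and response shape below are assumptions, since the route decorator and response wrapper fall outside this hunk:

```python
# Sketch: query the search endpoint and print the hits.
# "/search", "q", "limit" and the bare-list response shape are assumptions (not shown in this diff).
import requests

resp = requests.get(
    "http://localhost:8000/search",
    params={"q": "a dog playing in the park", "limit": 5},
)
resp.raise_for_status()
for hit in resp.json():  # each hit carries filename, score and (optionally) caption
    print(f"{hit['score']:.3f}  {hit['filename']}  caption={hit.get('caption')}")
```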