Commit cf03bcd

example(img-search): use CLIP for both image and text (#538)

* example(img-search): use CLIP for both image and text
* refactor: make the code cleaner

1 parent 71395f1 commit cf03bcd

File tree (3 files changed: 46 additions, 58 deletions)

- examples/image_search/README.md
- examples/image_search/main.py
- examples/image_search/pyproject.toml

examples/image_search/README.md

Lines changed: 1 addition & 12 deletions

````diff
@@ -3,8 +3,7 @@
 ![image](https://github.com/user-attachments/assets/3a696344-c9b4-46e8-9413-6229dbb8672a)
 
 - Qdrant for Vector Storage
-- Ollama Gemma3 (Image to Text)
-- CLIP ViT-L/14 - Embeddings Model
+- CLIP ViT-L/14 - Embeddings Model for both images and text
 - Live Update
 
 ## Make sure Postgres and Qdrant are running
@@ -27,16 +26,6 @@ curl -X PUT 'http://localhost:6333/collections/image_search' \
 }'
 ```
 
-## Run Ollama
-```
-ollama pull gemma3
-ollama serve
-```
-
-### Place your images in the `img` directory.
-- No need to update manually. CocoIndex will automatically update the index as new images are added to the directory.
-
-
 ## Run Backend
 - Install dependencies:
 ```
````
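
One practical detail when recreating the Qdrant collection from the snippet above: its vector size has to match the width of the CLIP features now used for both indexing and querying. A quick standalone check (not part of the commit; it only reads the checkpoint's config) could look like this:

```python
# Standalone sketch: print the embedding width of the CLIP checkpoint used in
# this example, so the Qdrant collection's vector size can be set to match.
from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
# Text and image features are projected to the same dimension.
print(model.config.projection_dim)
```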

examples/image_search/main.py

Lines changed: 38 additions & 45 deletions

```diff
@@ -1,62 +1,57 @@
 from dotenv import load_dotenv
+
 import cocoindex
 import datetime
+import functools
+import io
 import os
-import requests
-import base64
+import torch
+
+from typing import Literal
 from fastapi import FastAPI, Query
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from qdrant_client import QdrantClient
 
-OLLAMA_URL = "http://localhost:11434/api/generate"
-OLLAMA_MODEL = "gemma3"
+from PIL import Image
+from transformers import CLIPModel, CLIPProcessor
+
+
 QDRANT_GRPC_URL = os.getenv("QDRANT_GRPC_URL", "http://localhost:6334/")
+CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
+
+@functools.cache
+def get_clip_model() -> tuple[CLIPModel, CLIPProcessor]:
+    model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
+    processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
+    return model, processor
 
-# 1. Extract caption from image using Ollama vision model
-@cocoindex.op.function(cache=True, behavior_version=1)
-def get_image_caption(img_bytes: bytes) -> str:
+
+def embed_query(text: str) -> list[float]:
     """
-    Use Ollama's gemma3 model to extract a detailed caption from an image.
-    Returns a full-sentence natural language description of the image.
+    Embed the caption using CLIP model.
     """
-    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
-    prompt = (
-        "Describe this image in one detailed, natural language sentence. "
-        "Always explicitly name every visible animal species, object, and the main scene. "
-        "Be specific about the type, color, and any distinguishing features. "
-        "Avoid generic words like 'animal' or 'creature'—always use the most precise name (e.g., 'elephant', 'cat', 'lion', 'zebra'). "
-        "If an animal is present, mention its species and what it is doing. "
-        "For example: 'A large grey elephant standing in a grassy savanna, with trees in the background.'"
-    )
-    payload = {
-        "model": OLLAMA_MODEL,
-        "prompt": prompt,
-        "images": [img_b64],
-        "stream": False,
-    }
-    resp = requests.post(OLLAMA_URL, json=payload)
-    resp.raise_for_status()
-    result = resp.json()
-    text = result.get("response", "")
-    text = text.strip().replace("\n", "").rstrip(".")
-    return text
+    model, processor = get_clip_model()
+    inputs = processor(text=[text], return_tensors="pt", padding=True)
+    with torch.no_grad():
+        features = model.get_text_features(**inputs)
+    return features[0].tolist()
 
 
-# 2. Embed the caption string
-@cocoindex.transform_flow()
-def caption_to_embedding(caption: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]:
+@cocoindex.op.function(cache=True, behavior_version=1, gpu=True)
+def embed_image(img_bytes: bytes) -> cocoindex.Vector[cocoindex.Float32, Literal[384]]:
     """
-    Embed the caption using a CLIP model.
-    This is shared logic between indexing and querying.
+    Convert image to embedding using CLIP model.
     """
-    return caption.transform(
-        cocoindex.functions.SentenceTransformerEmbed(
-            model="clip-ViT-L-14",
-        )
-    )
+    model, processor = get_clip_model()
+    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
+    inputs = processor(images=image, return_tensors="pt")
+    with torch.no_grad():
+        features = model.get_image_features(**inputs)
+    return features[0].tolist()
+
 
-# 3. CocoIndex flow: Ingest images, extract captions, embed, export to Qdrant
+# CocoIndex flow: Ingest images, extract captions, embed, export to Qdrant
 @cocoindex.flow_def(name="ImageObjectEmbedding")
 def image_object_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
     data_scope["images"] = flow_builder.add_source(
@@ -65,12 +60,10 @@ def image_object_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope:
     )
     img_embeddings = data_scope.add_collector()
     with data_scope["images"].row() as img:
-        img["caption"] = img["content"].transform(get_image_caption)
-        img["embedding"] = caption_to_embedding(img["caption"])
+        img["embedding"] = img["content"].transform(embed_image)
         img_embeddings.collect(
             id=cocoindex.GeneratedField.UUID,
             filename=img["filename"],
-            caption=img["caption"],
             embedding=img["embedding"],
         )
     img_embeddings.export(
@@ -111,7 +104,7 @@ def startup_event():
 @app.get("/search")
 def search(q: str = Query(..., description="Search query"), limit: int = Query(5, description="Number of results")):
     # Get the embedding for the query
-    query_embedding = caption_to_embedding.eval(q)
+    query_embedding = embed_query(q)
 
     # Search in Qdrant
     search_results = app.state.qdrant_client.search(
```
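
For context on why the caption step can go away: CLIP projects text and images into a shared embedding space, so a query string can be compared directly against image vectors. A minimal standalone sketch of that idea, with a hypothetical image path and query text that are not part of the commit:

```python
# Standalone sketch: embed one query string and one image with the same CLIP
# checkpoint and compare them directly. Image path and query are hypothetical.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("img/elephant.jpg").convert("RGB")  # hypothetical sample image

with torch.no_grad():
    text_features = model.get_text_features(
        **processor(text=["an elephant on a grassy savanna"], return_tensors="pt", padding=True)
    )
    image_features = model.get_image_features(**processor(images=image, return_tensors="pt"))

# Cosine similarity between the two vectors; this is the score Qdrant reports
# when the collection is configured with the Cosine distance.
similarity = torch.nn.functional.cosine_similarity(text_features, image_features)
print(similarity.item())
```

Because both calls go through the same checkpoint, no intermediate caption model sits between the stored image vectors and the text query.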

examples/image_search/pyproject.toml

Lines changed: 7 additions & 1 deletion

```diff
@@ -3,7 +3,13 @@ name = "image-search"
 version = "0.1.0"
 description = "Simple example for cocoindex: build embedding index based on images."
 requires-python = ">=3.11"
-dependencies = ["cocoindex>=0.1.42", "python-dotenv>=1.0.1", "fastapi>=0.100.0"]
+dependencies = [
+    "cocoindex>=0.1.42",
+    "python-dotenv>=1.0.1",
+    "fastapi>=0.100.0",
+    "torch>=2.0.0",
+    "transformers>=4.29.0",
+]
 
 [tool.setuptools]
 packages = []
```
