[Example]: Multimodal RAG with LanceDB #2498
# LanceDB Multimodal (E‑commerce RAG)

This example turns LanceDB into a simple, local product catalog for multimodal RAG. It fetches products from a live API, stores images and vector embeddings together, and exposes a tool the agent uses to perform semantic search with SQL-like filters. Results render as a quick collage for visual feedback.
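The collage is plain Pillow: result images are pasted side by side onto a canvas whose width is the sum of the tile widths and whose height is the tallest tile. A minimal sketch of that layout logic (the tile sizes and colors here are arbitrary placeholders):

```python
from PIL import Image

# Two placeholder tiles standing in for downloaded product images
tiles = [Image.new('RGB', (30, 40), 'red'), Image.new('RGB', (50, 20), 'blue')]

# Canvas: total width = sum of tile widths, height = tallest tile
collage = Image.new('RGB', (sum(t.width for t in tiles), max(t.height for t in tiles)))

x_offset = 0
for tile in tiles:
    collage.paste(tile, (x_offset, 0))  # paste left to right, top-aligned
    x_offset += tile.width

print(collage.size)  # → (80, 40)
```

The example then calls `collage.show()` to open the result in the default image viewer.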

- Object storage: store images next to vectors
- Embedded: fast local prototyping with transactional storage

Demonstrates:

- [tools](../tools.md)
- [agent dependencies](../dependencies.md)
- Vector search with LanceDB (CLIP embeddings)
- Object storage of product images in LanceDB
- SQL-like metadata filtering (category, price)
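The SQL-like filtering is implemented in the example's `find_products` tool by joining optional conditions into a `where` clause, doubling any single quotes in the category so the SQL string literal stays valid. A standalone sketch of that clause-building logic (`build_where_clause` is a name made up for this sketch; in the example the logic lives inline in the tool):

```python
from typing import Optional


def build_where_clause(
    category: Optional[str] = None,
    min_price: Optional[float] = None,
    max_price: Optional[float] = None,
) -> str:
    """Join optional filters into a LanceDB-style SQL where clause."""
    conditions = []
    if category:
        safe_category = category.replace("'", "''")  # escape single quotes
        conditions.append(f"category = '{safe_category}'")
    if min_price is not None:
        conditions.append(f'price >= {min_price}')
    if max_price is not None:
        conditions.append(f'price <= {max_price}')
    return ' AND '.join(conditions)


print(build_where_clause(category="men's clothing", max_price=20))
# → category = 'men''s clothing' AND price <= 20
```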

## Installation

```bash
pip install lancedb sentence-transformers torch httpx pandas Pillow
```

Set your Google API key (used for the agent's text generation):

```bash
export GOOGLE_API_KEY=your_api_key_here
```

## Usage

Build the product database from the live API:

```bash
uv run -m pydantic_ai_examples.lancedb_multimodal build
```

Search for products (hybrid search):

```bash
uv run -m pydantic_ai_examples.lancedb_multimodal search "a cool t-shirt in men's clothing under 20 dollars"
```

```bash
uv run -m pydantic_ai_examples.lancedb_multimodal search "An external SSD with 1TB or more storage"
```

## Architecture

### Data Schema (Pydantic model)

```python {test="skip"}
from lancedb.pydantic import LanceModel, Vector


class ProductVector(LanceModel):
    id: int
    title: str
    price: float
    description: str
    category: str
    image: bytes
    embedding: Vector(512)  # CLIP 'clip-ViT-B-32' embedding size
```
"""Multimodal RAG with LanceDB.

What this does:
- Fetches a product catalog from a live API (fakestoreapi.com).
- Embeds product descriptions using a CLIP model (clip-ViT-B-32) and stores them in LanceDB.
- Implements a RAG agent with a `find_products` tool that can:
    1. Perform semantic search (e.g., "something for the cold weather").
    2. Perform metadata filtering (e.g., "show me all electronics").
    3. Combine both for hybrid search (e.g., "a cool t-shirt" from the "men's clothing" category under $20).
- Generates and displays an image collage of the results for instant visual feedback.
- Sends traces to Logfire if a Logfire token is configured.

Install dependencies:
    pip install lancedb sentence-transformers torch httpx pandas Pillow logfire[httpx]

Set your Google API key (for the agent's text generation):
    export GOOGLE_API_KEY=your_api_key_here

Usage:
    # First, build the product database from the live API
    python lancedb_multimodal.py build

    # Then, ask for a recommendation:
    python lancedb_multimodal.py search "a cool t-shirt in men's clothing under 20 dollars"
"""

import asyncio
import io
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, cast

import httpx

from pydantic_ai import Agent, RunContext

try:
    import lancedb
    import logfire
    from lancedb.pydantic import LanceModel, Vector
    from PIL import Image
    from sentence_transformers import SentenceTransformer
except ImportError:
    print(
        """Missing dependencies. To run this example, please install the required packages by running:
    pip install lancedb sentence-transformers torch httpx pandas Pillow logfire[httpx]"""
    )
    sys.exit(0)


# ruff: noqa
# pyright: reportMissingImports=false
# pyright: reportInvalidTypeForm=false
# pyright: reportUnknownVariableType=false
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUntypedBaseClass=false

# 'if-token-present' means nothing will be sent if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_pydantic_ai()
logfire.instrument_httpx(capture_all=True)

DATA_DIR = Path('./lancedb_data/products')


# LanceDB schema
class ProductVector(LanceModel):
    id: int
    title: str
    price: float
    description: str
    category: str
    image: bytes
    embedding: Vector(512)  # CLIP 'clip-ViT-B-32' model produces 512-dim vectors


@dataclass
class Deps:
    db: lancedb.DBConnection
    embedding_model: SentenceTransformer


agent = Agent(
    'google-gla:gemini-2.5-flash',
    deps_type=Deps,
    system_prompt=(
        'You are a helpful AI Shopping Assistant. Your goal is to help users find the perfect product '
        'by using the `find_products` tool. You can search by a text query, filter by category and '
        'price, or combine all three for a powerful search. '
        'After getting the results, present them clearly to the user and mention that you are '
        'displaying a collage of the findings.'
    ),
)


async def _generate_image_collage(image_bytes_list: list[bytes], title: str):
    if not image_bytes_list:
        return

    with logfire.span(
        'generate image collage', num_images=len(image_bytes_list), title=title
    ):
        images = [
            Image.open(io.BytesIO(image_bytes)) for image_bytes in image_bytes_list
        ]

        if not images:
            print('Could not create any images from bytes to create a collage.')
            return

        widths, heights = zip(*(i.size for i in images))
        total_width = sum(widths)
        max_height = max(heights)
        collage = Image.new('RGB', (total_width, max_height))
        x_offset = 0
        for img in images:
            collage.paste(img, (x_offset, 0))
            x_offset += img.width

        collage.show(title=title)


@agent.tool
async def find_products(
    context: RunContext[Deps],
    query: Optional[str] = None,
    category: Optional[str] = None,
    min_price: Optional[float] = None,
    max_price: Optional[float] = None,
    top_k: int = 4,
) -> str:
    """Finds products using semantic search, metadata filtering, or both (hybrid search)."""
    table = context.deps.db.open_table('products')

    query_embedding = None
    if query and query.strip():
        with logfire.span('encode semantic query', query=query):
            query_embedding = context.deps.embedding_model.encode(
                query, convert_to_tensor=False
            )

    # Use vector search when a semantic query is provided, otherwise fall back to metadata-only search
    searcher = (
        table.search(query_embedding) if query_embedding is not None else table.search()
    )

    conditions = []
    if category:
        safe_category = category.replace("'", "''")
        conditions.append(f"category = '{safe_category}'")
    if min_price is not None:
        conditions.append(f'price >= {min_price}')
    if max_price is not None:
        conditions.append(f'price <= {max_price}')

    if conditions:
        where_clause = ' AND '.join(conditions)
        searcher = searcher.where(where_clause)

    with logfire.span('search products', top_k=top_k, conditions=conditions):
        results_df = searcher.limit(top_k).to_pandas()

    if results_df.empty:
        return 'No products found matching your criteria.'

    logfire.info('Displaying collage of results ({n})', n=len(results_df))
    await _generate_image_collage(
        cast(list[bytes], results_df['image'].tolist()), title='Product Results'
    )

    # Don't return image byte string in the text response
    results_df_no_image = results_df.drop(columns=['image'])
    results_json = results_df_no_image.to_json(orient='records')
    return f'Found {len(results_df)} products.\nProduct details: {results_json}'


async def build_product_database():
    logfire.info('Building product database from Fake Store API...')
    DATA_DIR.mkdir(parents=True, exist_ok=True)

    # Fetch product data from the API
    with logfire.span('fetch product catalog', url='https://fakestoreapi.com/products'):
        async with httpx.AsyncClient() as client:
            response = await client.get('https://fakestoreapi.com/products', timeout=30)
            response.raise_for_status()
            products_data = response.json()

    # Initialize LanceDB and the embedding model
    with logfire.span('initialize db and model'):
        db = lancedb.connect(DATA_DIR)
        embedding_model = SentenceTransformer('clip-ViT-B-32')

    # Create embeddings for product descriptions and fetch image bytes
    logfire.info(
        'Creating embeddings and fetching images for {n} products...',
        n=len(products_data),
    )
    product_vectors: list[ProductVector] = []
    async with httpx.AsyncClient() as client:
        for p_data in products_data:
            with logfire.span(
                'embed + download image',
                product_id=p_data.get('id'),
                category=p_data.get('category'),
            ):
                content_to_embed = f'Product Name: {p_data["title"]}\nCategory: {p_data["category"]}\nDescription: {p_data["description"]}'
                embedding = embedding_model.encode(
                    content_to_embed, convert_to_tensor=False
                )

                try:
                    image_response = await client.get(p_data['image'], timeout=30)
                    image_response.raise_for_status()
                    image_bytes = image_response.content
                    p_data['image'] = image_bytes
                    product_vectors.append(ProductVector(**p_data, embedding=embedding))
                except httpx.HTTPStatusError as e:
                    logfire.warning(
                        'Skipping product due to image download error: {e}', e=str(e)
                    )

    # Create a LanceDB table and add the data
    with logfire.span(
        'create lancedb table and add rows', num_rows=len(product_vectors)
    ):
        table = db.create_table('products', schema=ProductVector, mode='overwrite')
        table.add(product_vectors)

    logfire.info(
        'Successfully indexed {n} products into LanceDB.', n=len(product_vectors)
    )


async def run_search(query: str):
    db = lancedb.connect(DATA_DIR)
    embedding_model = SentenceTransformer('clip-ViT-B-32')
    deps = Deps(db=db, embedding_model=embedding_model)

    logfire.info('User Query: {query}', query=query)
    result = await agent.run(query, deps=deps)
    print(result.output)


def main():
    if 'search' in sys.argv and not os.getenv('GOOGLE_API_KEY'):
        raise ValueError(
            "GOOGLE_API_KEY environment variable is required for the 'search' action."
        )

    if len(sys.argv) < 2:
        print(
            'Usage:\n'
            '  python lancedb_multimodal.py build\n'
            '  python lancedb_multimodal.py search <query>',
            file=sys.stderr,
        )
        sys.exit(1)

    action = sys.argv[1]
    if action == 'build':
        asyncio.run(build_product_database())
    elif action == 'search':
        search_query = (
            ' '.join(sys.argv[2:])
            if len(sys.argv) > 2
            else 'An external SSD with 1TB or more storage'
        )
        asyncio.run(run_search(search_query))
    else:
        print(f'Unknown action: {action}', file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()