Commit a843f13

Add Multimodal RAG with Elasticsearch Gotham City tutorial
1 parent 7a8db64 commit a843f13

File tree: 13 files changed (+866 −0)
README.md (96 additions & 0 deletions)

# Building a Multimodal RAG Pipeline with Elasticsearch: The Story of Gotham City

This repository contains the code for implementing a Multimodal Retrieval-Augmented Generation (RAG) system using Elasticsearch. The system processes and analyzes different types of evidence (images, audio, text, and depth maps) to solve a crime in Gotham City.

## Overview

The pipeline demonstrates how to:
- Generate unified embeddings for multiple modalities using ImageBind
- Store and search vectors efficiently in Elasticsearch
- Analyze evidence using GPT-4 to generate forensic reports

## Prerequisites

- Python 3.10+
- Elasticsearch cluster (cloud or local)
- OpenAI API key
- 8GB+ RAM
- GPU (optional but recommended)

## Quick Start

1. **Set Up the Environment**
```bash
# Create and activate a virtual environment
python -m venv env_mmrag
source env_mmrag/bin/activate  # Unix/macOS
# or
.\env_mmrag\Scripts\activate  # Windows

# Install dependencies
pip install -r requirements.txt
```

2. **Configure Credentials**
Create a `.env` file:
```env
ELASTICSEARCH_ENDPOINT="your-elasticsearch-endpoint"
ELASTIC_API_KEY="your-elastic-api-key"
OPENAI_API_KEY="your-openai-api-key"
```

3. **Run the Demo**
```bash
# Verify file structure
python stages/01-stage/files_check.py

# Generate embeddings
python stages/02-stage/test_embedding_generation.py

# Index content
python stages/03-stage/index_all_modalities.py

# Search and analyze
python stages/04-stage/rag_crime_analyze.py
```

## Project Structure

```
├── README.md
├── requirements.txt
├── src/
│   ├── embedding_generator.py   # ImageBind wrapper
│   ├── elastic_manager.py       # Elasticsearch operations
│   └── llm_analyzer.py          # GPT-4 integration
├── stages/
│   ├── 01-stage/                # File organization
│   ├── 02-stage/                # Embedding generation
│   ├── 03-stage/                # Elasticsearch indexing/search
│   └── 04-stage/                # Evidence analysis
└── data/                        # Sample data
    ├── images/
    ├── audios/
    ├── texts/
    └── depths/
```

## Sample Data

The repository includes sample evidence files:
- Images: Crime scene photos and security camera footage
- Audio: Suspicious sound recordings
- Text: Mysterious notes and riddles
- Depth Maps: 3D scene captures

## How It Works

1. **Evidence Collection**: Files are organized by modality in the `data/` directory
2. **Embedding Generation**: ImageBind converts each piece of evidence into a 1024-dimensional vector
3. **Vector Storage**: Elasticsearch stores embeddings with metadata for efficient retrieval
4. **Similarity Search**: New evidence is compared against the database using k-NN search
5. **Analysis**: GPT-4 analyzes the connections between evidence to identify suspects

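A minimal end-to-end sketch of these steps (the evidence paths are hypothetical; the GPT-4 analysis step is handled by `src/llm_analyzer.py` and omitted here):

```python
from src.embedding_generator import EmbeddingGenerator
from src.elastic_manager import ElasticsearchManager

generator = EmbeddingGenerator()
es_manager = ElasticsearchManager()

# Steps 1-3: embed a piece of evidence and store it with its metadata
embedding = generator.generate_embedding(["data/images/crime_scene_01.jpg"], "vision")
es_manager.index_content(
    embedding=embedding,
    modality="vision",
    description="Crime scene photo",
    content_path="data/images/crime_scene_01.jpg",
)

# Step 4: compare new evidence against the database using k-NN search
query = generator.generate_embedding(["data/images/new_clue.jpg"], "vision")
for hit in es_manager.search_similar(query, modality="vision", k=5):
    print(hit["content_path"], hit["score"])
```
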
## License

This project is licensed under the Elastic License 2.0.

requirements.txt (12 additions & 0 deletions)

elasticsearch>=8.11.0
torch>=2.0.0
torchvision>=0.15.0
torchaudio>=2.0.0
imagebind @ git+https://github.com/facebookresearch/ImageBind.git
openai>=1.0.0
python-dotenv>=1.0.0
numpy>=1.24.0
pillow>=10.0.0
opencv-python>=4.8.0
librosa>=0.10.0
matplotlib>=3.7.0

src/elastic_manager.py (87 additions & 0 deletions)

1+
from elasticsearch import Elasticsearch, helpers
2+
import base64
3+
import os
4+
from dotenv import load_dotenv
5+
import numpy as np
6+
7+
class ElasticsearchManager:
8+
"""Manages multimodal operations in Elasticsearch"""
9+
10+
def __init__(self):
11+
load_dotenv() # Load variables from .env
12+
self.es = self._connect_elastic()
13+
self.index_name = "multimodal_content"
14+
self._setup_index()
15+
16+
def _connect_elastic(self):
17+
"""Connects to Elasticsearch"""
18+
return Elasticsearch(
19+
os.getenv("ELASTICSEARCH_ENDPOINT"), # Elasticsearch endpoint
20+
api_key=os.getenv("ELASTIC_API_KEY")
21+
)
22+
23+
def _setup_index(self):
24+
"""Sets up the index if it doesn't exist"""
25+
if not self.es.indices.exists(index=self.index_name):
26+
mapping = {
27+
"mappings": {
28+
"properties": {
29+
"embedding": {
30+
"type": "dense_vector",
31+
"dims": 1024,
32+
"index": True,
33+
"similarity": "cosine"
34+
},
35+
"modality": {"type": "keyword"},
36+
"content": {"type": "binary"},
37+
"description": {"type": "text"},
38+
"metadata": {"type": "object"},
39+
"content_path": {"type": "text"}
40+
}
41+
}
42+
}
43+
self.es.indices.create(index=self.index_name, body=mapping)
44+
45+
def index_content(self, embedding, modality, content=None, description="", metadata=None, content_path=None):
46+
"""Indexes multimodal content"""
47+
doc = {
48+
"embedding": embedding.tolist(),
49+
"modality": modality,
50+
"description": description,
51+
"metadata": metadata or {},
52+
"content_path": content_path
53+
}
54+
55+
if content:
56+
doc["content"] = base64.b64encode(content).decode() if isinstance(content, bytes) else content
57+
58+
return self.es.index(index=self.index_name, document=doc)
59+
60+
def search_similar(self, query_embedding, modality=None, k=5):
61+
"""Searches for similar contents"""
62+
query = {
63+
"knn": {
64+
"field": "embedding",
65+
"query_vector": query_embedding.tolist(),
66+
"k": k,
67+
"num_candidates": 100,
68+
"filter": [{"term": {"modality": modality}}] if modality else []
69+
}
70+
}
71+
72+
try:
73+
response = self.es.search(
74+
index=self.index_name,
75+
query=query,
76+
size=k
77+
)
78+
79+
# Return both source data and score for each hit
80+
return [{
81+
**hit["_source"],
82+
"score": hit["_score"]
83+
} for hit in response["hits"]["hits"]]
84+
85+
except Exception as e:
86+
print(f"Error: processing search_evidence: {str(e)}")
87+
return "Error generating search evidence"

src/embedding_generator.py (120 additions & 0 deletions)

1+
import os
2+
import cv2
3+
from io import BytesIO
4+
import logging
5+
from torch.hub import download_url_to_file
6+
7+
import torch
8+
import numpy as np
9+
from PIL import Image
10+
from imagebind import data
11+
from imagebind.models import imagebind_model
12+
13+
from torchvision import transforms
14+
15+
16+
logging.basicConfig(level=logging.INFO)
17+
logger = logging.getLogger(__name__)
18+
19+
class EmbeddingGenerator:
20+
"""Generates multimodal embeddings using ImageBind"""
21+
22+
def __init__(self, device="cpu"):
23+
self.device = device
24+
self.model = self._load_model()
25+
26+
def _load_model(self):
27+
"""Initialize and test the ImageBind model."""
28+
checkpoint_path = "~/.cache/torch/checkpoints/imagebind_huge.pth"
29+
os.makedirs(os.path.expanduser("~/.cache/torch/checkpoints"), exist_ok=True)
30+
31+
if not os.path.exists(os.path.expanduser(checkpoint_path)):
32+
print("Downloading ImageBind weights...")
33+
download_url_to_file(
34+
"https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
35+
os.path.expanduser(checkpoint_path)
36+
)
37+
38+
try:
39+
checkpoint_path = os.path.expanduser("~/.cache/torch/checkpoints/imagebind_huge.pth")
40+
41+
# Check if file exists
42+
if not os.path.exists(checkpoint_path):
43+
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
44+
45+
model = imagebind_model.imagebind_huge(pretrained=False)
46+
model.load_state_dict(torch.load(checkpoint_path))
47+
model.eval().to(self.device)
48+
49+
# Quick test with empty text input
50+
logger.info("Testing model with sample input...")
51+
test_input = data.load_and_transform_text([""], self.device)
52+
with torch.no_grad():
53+
_ = model({"text": test_input})
54+
55+
logger.info("🤖 ImageBind model initialized successfully")
56+
return model
57+
except Exception as e:
58+
logger.error(f"🚨 Model initialization failed: {str(e)}")
59+
raise
60+
61+
def generate_embedding(self, input_data, modality):
62+
"""Generates embedding for different modalities"""
63+
processors = {
64+
"vision": lambda x: data.load_and_transform_vision_data(x, self.device),
65+
"audio": lambda x: data.load_and_transform_audio_data(x, self.device),
66+
"text": lambda x: data.load_and_transform_text(x, self.device),
67+
"depth": self.process_depth
68+
}
69+
70+
try:
71+
# Input type verification
72+
if not isinstance(input_data, list):
73+
raise ValueError(f"Input data must be a list. Received: {type(input_data)}")
74+
75+
# Convert input data to a tensor format that the model can process
76+
# For images: [batch_size, channels, height, width]
77+
# For audio: [batch_size, channels, time]
78+
# For text: [batch_size, sequence_length]
79+
inputs = {modality: processors[modality](input_data)}
80+
with torch.no_grad():
81+
embedding = self.model(inputs)[modality]
82+
return embedding.squeeze(0).cpu().numpy()
83+
except Exception as e:
84+
logger.error(f"Error generating {modality} embedding: {str(e)}", exc_info=True)
85+
raise
86+
87+
88+
def process_vision(self, image_path):
89+
"""Processes image"""
90+
return data.load_and_transform_vision_data([image_path], self.device)
91+
92+
def process_audio(self, audio_path):
93+
"""Processes audio"""
94+
return data.load_and_transform_audio_data([audio_path], self.device)
95+
96+
def process_text(self, text):
97+
"""Processes text"""
98+
return data.load_and_transform_text([text], self.device)
99+
100+
def process_depth(self, depth_paths, device="cpu"):
101+
"""Custom processing for depth maps"""
102+
try:
103+
# Check file existence
104+
for path in depth_paths:
105+
if not os.path.exists(path):
106+
raise FileNotFoundError(f"Depth map file not found: {path}")
107+
108+
# Load and transform
109+
depth_images = [Image.open(path).convert("L") for path in depth_paths]
110+
111+
transform = transforms.Compose([
112+
transforms.Resize((224, 224)),
113+
transforms.ToTensor(),
114+
])
115+
116+
return torch.stack([transform(img) for img in depth_images]).to(device)
117+
118+
except Exception as e:
119+
logger.error(f"🚨 - Error processing depth map: {str(e)}")
120+
raise
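
A short usage sketch (the image path and riddle text are hypothetical):

```python
from src.embedding_generator import EmbeddingGenerator

generator = EmbeddingGenerator(device="cpu")

# Each call takes a list of inputs and returns a 1024-dimensional NumPy vector
image_vec = generator.generate_embedding(["data/images/crime_scene_01.jpg"], "vision")
text_vec = generator.generate_embedding(["Riddle me this: what has keys but no locks?"], "text")

print(image_vec.shape)  # (1024,)
```

Because ImageBind projects every modality into the same embedding space, the image and text vectors can be compared directly with cosine similarity, which is what makes cross-modal evidence search possible.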
