Commit 1c4caca

committed
add mem0 generate rag file
1 parent 409cab4 commit 1c4caca

File tree

1 file changed: 234 additions, 0 deletions

'''
Modified from the mem0 project. Original file:
https://github.com/mem0ai/mem0/blob/main/evaluation/src/rag.py
'''

import argparse
import json
import os
import time
from collections import defaultdict

import numpy as np
import tiktoken
from dotenv import load_dotenv
from jinja2 import Template
from openai import OpenAI
from tqdm import tqdm

load_dotenv()

PROMPT = """
# Question:
{{QUESTION}}

# Context:
{{CONTEXT}}

# Short answer:
"""

TECHNIQUES = ["mem0", "rag", "langmem", "zep", "openai"]
METHODS = ["add", "search"]


class RAGManager:
    def __init__(self, data_path="data/locomo/locomo10_rag.json", chunk_size=500, k=2):
        self.model = os.getenv("MODEL")
        self.client = OpenAI()
        self.data_path = data_path
        self.chunk_size = chunk_size
        self.k = k

    def generate_response(self, question, context):
        template = Template(PROMPT)
        prompt = template.render(CONTEXT=context, QUESTION=question)

        max_retries = 3
        retries = 0

        while retries <= max_retries:
            try:
                t1 = time.time()
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a helpful assistant that can answer "
                            "questions based on the provided context. "
                            "If the question involves timing, use the conversation date for reference. "
                            "Provide the shortest possible answer. "
                            "Use words directly from the conversation when possible. "
                            "Avoid using subjects in your answer.",
                        },
                        {"role": "user", "content": prompt},
                    ],
                    temperature=0,
                )
                t2 = time.time()
                if response and response.choices:
                    content = response.choices[0].message.content
                    if content is not None:
                        return content.strip(), t2 - t1
                    return "No content returned", t2 - t1
                return "Empty response", t2 - t1
            except Exception as e:
                retries += 1
                if retries > max_retries:
                    raise e
                time.sleep(1)  # Wait before retrying

    def clean_chat_history(self, chat_history):
        cleaned_chat_history = ""
        for c in chat_history:
            cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: {c['text']}\n"

        return cleaned_chat_history

    def calculate_embedding(self, document):
        response = self.client.embeddings.create(model=os.getenv("EMBEDDING_MODEL"), input=document)
        return response.data[0].embedding

    def calculate_similarity(self, embedding1, embedding2):
        # Cosine similarity between two embedding vectors
        return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))

    def search(self, query, chunks, embeddings, k=1):
        """
        Search for the top-k most similar chunks to the query.

        Args:
            query: The query string
            chunks: List of text chunks
            embeddings: List of embeddings for each chunk
            k: Number of top chunks to return (default: 1)

        Returns:
            combined_chunks: The combined text of the top-k chunks
            search_time: Time taken for the search
        """
        t1 = time.time()
        query_embedding = self.calculate_embedding(query)
        similarities = [self.calculate_similarity(query_embedding, embedding) for embedding in embeddings]

        # Get indices of the top-k most similar chunks
        if k == 1:
            # Original behavior - just get the most similar chunk
            top_indices = [np.argmax(similarities)]
        else:
            # Get indices of top-k chunks, most similar first
            top_indices = np.argsort(similarities)[-k:][::-1]

        # Combine the top-k chunks
        combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices])

        t2 = time.time()
        return combined_chunks, t2 - t1

    def create_chunks(self, chat_history, chunk_size=500):
        """
        Create chunks using tiktoken for more accurate token counting
        """
        # Get the encoding for the embedding model
        encoding = tiktoken.encoding_for_model(os.getenv("EMBEDDING_MODEL"))

        documents = self.clean_chat_history(chat_history)

        # chunk_size == -1 means no chunking: use the whole conversation as a single chunk
        if chunk_size == -1:
            return [documents], []

        chunks = []

        # Encode the document
        tokens = encoding.encode(documents)

        # Split into chunks based on token count
        for i in range(0, len(tokens), chunk_size):
            chunk_tokens = tokens[i : i + chunk_size]
            chunk = encoding.decode(chunk_tokens)
            chunks.append(chunk)

        embeddings = []
        for chunk in chunks:
            embedding = self.calculate_embedding(chunk)
            embeddings.append(embedding)

        return chunks, embeddings

    def process_all_conversations(self, output_file_path):
        with open(self.data_path, "r") as f:
            data = json.load(f)

        FINAL_RESULTS = defaultdict(list)
        for key, value in tqdm(data.items(), desc="Processing conversations"):
            chat_history = value["conversation"]
            questions = value["question"]

            chunks, embeddings = self.create_chunks(chat_history, self.chunk_size)

            for item in tqdm(questions, desc="Answering questions", leave=False):
                question = item["question"]
                answer = item.get("answer", "")
                category = item["category"]

                if self.chunk_size == -1:
                    context = chunks[0]
                    search_time = 0
                else:
                    context, search_time = self.search(question, chunks, embeddings, k=self.k)
                response, response_time = self.generate_response(question, context)

                FINAL_RESULTS[key].append(
                    {
                        "question": question,
                        "answer": answer,
                        "category": category,
                        "context": context,
                        "response": response,
                        "search_time": search_time,
                        "response_time": response_time,
                    }
                )

            # Save intermediate results after each conversation
            with open(output_file_path, "w+") as f:
                json.dump(FINAL_RESULTS, f, indent=4)

        # Save final results
        with open(output_file_path, "w+") as f:
            json.dump(FINAL_RESULTS, f, indent=4)


class Experiment:
    def __init__(self, technique_type, chunk_size):
        self.technique_type = technique_type
        self.chunk_size = chunk_size

    def run(self):
        print(f"Running experiment with technique: {self.technique_type}, chunk size: {self.chunk_size}")


def main():
    parser = argparse.ArgumentParser(description="Run memory experiments")
    parser.add_argument("--technique_type", choices=TECHNIQUES, default="rag", help="Memory technique to use")
    parser.add_argument("--chunk_size", type=int, default=500, help="Chunk size for processing")
    parser.add_argument("--output_folder", type=str, default="results/", help="Output path for results")
    parser.add_argument("--top_k", type=int, default=30, help="Number of top memories to retrieve")
    parser.add_argument("--num_chunks", type=int, default=2, help="Number of chunks to process")

    args = parser.parse_args()

    if args.technique_type == "rag":
        output_file_path = os.path.join(args.output_folder, f"rag_results_{args.chunk_size}_k{args.num_chunks}.json")
        rag_manager = RAGManager(data_path="data/locomo/locomo10_rag.json", chunk_size=args.chunk_size, k=args.num_chunks)
        rag_manager.process_all_conversations(output_file_path)


if __name__ == "__main__":
    start = time.time()
    main()
    end = time.time()
    print(f"Execution time is: {end - start}")
