Skip to content

Commit 30cb441

Browse files
committed
update mem0_rag info
1 parent 977cb12 commit 30cb441

File tree

1 file changed

+40
-32
lines changed

1 file changed

+40
-32
lines changed

evaluation/scripts/locomo/mem0_rag.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
1-
'''
2-
Modify the code from the mem0 project, Original file link
3-
https://github.com/mem0ai/mem0/blob/main/evaluation/src/rag.py
4-
'''
1+
"""
2+
Modify the code from the mem0 project, Original file link is: https://github.com/mem0ai/mem0/blob/main/evaluation/src/rag.py
3+
"""
54

5+
import argparse
66
import json
77
import os
8-
import numpy as np
98
import time
9+
1010
from collections import defaultdict
1111

1212
import numpy as np
1313
import tiktoken
14-
import argparse
1514

1615
from dotenv import load_dotenv
1716
from jinja2 import Template
1817
from openai import OpenAI
1918
from tqdm import tqdm
2019

20+
2121
load_dotenv()
2222

2323
PROMPT = """
24-
# Question:
24+
# Question:
2525
{{QUESTION}}
2626
27-
# Context:
27+
# Context:
2828
{{CONTEXT}}
2929
3030
# Short answer:
@@ -33,6 +33,7 @@
3333
TECHNIQUES = ["mem0", "rag", "langmem", "zep", "openai"]
3434
METHODS = ["add", "search"]
3535

36+
3637
class RAGManager:
3738
def __init__(self, data_path="data/locomo/locomo10_rag.json", chunk_size=500, k=2):
3839
self.model = os.getenv("MODEL")
@@ -68,14 +69,13 @@ def generate_response(self, question, context):
6869
temperature=0,
6970
)
7071
t2 = time.time()
71-
# return response.choices[0].message.content.strip(), t2 - t1
7272
if response and response.choices:
7373
content = response.choices[0].message.content
7474
if content is not None:
7575
return content.strip(), t2 - t1
7676
else:
7777
return "No content returned", t2 - t1
78-
print(f"❎ No content returned!")
78+
print("❎ No content returned!")
7979
else:
8080
return "Empty response", t2 - t1
8181
except Exception as e:
@@ -87,7 +87,7 @@ def generate_response(self, question, context):
8787
def clean_chat_history(self, chat_history):
8888
cleaned_chat_history = ""
8989
for c in chat_history:
90-
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n"
90+
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: {c['text']}\n"
9191

9292
return cleaned_chat_history
9393

@@ -96,7 +96,9 @@ def calculate_embedding(self, document):
9696
return response.data[0].embedding
9797

9898
def calculate_similarity(self, embedding1, embedding2):
99-
return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
99+
return np.dot(embedding1, embedding2) / (
100+
np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
101+
)
100102

101103
def search(self, query, chunks, embeddings, k=1):
102104
"""
@@ -114,16 +116,12 @@ def search(self, query, chunks, embeddings, k=1):
114116
"""
115117
t1 = time.time()
116118
query_embedding = self.calculate_embedding(query)
117-
similarities = [self.calculate_similarity(query_embedding, embedding) for embedding in embeddings]
119+
similarities = [
120+
self.calculate_similarity(query_embedding, embedding) for embedding in embeddings
121+
]
118122

119123
# Get indices of top-k most similar chunks
120-
if k == 1:
121-
# Original behavior - just get the most similar chunk
122-
top_indices = [np.argmax(similarities)]
123-
else:
124-
# Get indices of top-k chunks
125-
top_indices = np.argsort(similarities)[-k:][::-1]
126-
124+
top_indices = [np.argmax(similarities)] if k == 1 else np.argsort(similarities)[-k:][::-1]
127125
# Combine the top-k chunks
128126
combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices])
129127

@@ -161,10 +159,10 @@ def create_chunks(self, chat_history, chunk_size=500):
161159
return chunks, embeddings
162160

163161
def process_all_conversations(self, output_file_path):
164-
with open(self.data_path, "r") as f:
162+
with open(self.data_path) as f:
165163
data = json.load(f)
166164

167-
FINAL_RESULTS = defaultdict(list)
165+
final_results = defaultdict(list)
168166
for key, value in tqdm(data.items(), desc="Processing conversations"):
169167
chat_history = value["conversation"]
170168
questions = value["question"]
@@ -183,7 +181,7 @@ def process_all_conversations(self, output_file_path):
183181
context, search_time = self.search(question, chunks, embeddings, k=self.k)
184182
response, response_time = self.generate_response(question, context)
185183

186-
FINAL_RESULTS[key].append(
184+
final_results[key].append(
187185
{
188186
"question": question,
189187
"answer": answer,
@@ -195,11 +193,11 @@ def process_all_conversations(self, output_file_path):
195193
}
196194
)
197195
with open(output_file_path, "w+") as f:
198-
json.dump(FINAL_RESULTS, f, indent=4)
196+
json.dump(final_results, f, indent=4)
199197

200198
# Save results
201199
with open(output_file_path, "w+") as f:
202-
json.dump(FINAL_RESULTS, f, indent=4)
200+
json.dump(final_results, f, indent=4)
203201

204202

205203
class Experiment:
@@ -208,26 +206,36 @@ def __init__(self, technique_type, chunk_size):
208206
self.chunk_size = chunk_size
209207

210208
def run(self):
211-
print(f"Running experiment with technique: {self.technique_type}, chunk size: {self.chunk_size}")
209+
print(
210+
f"Running experiment with technique: {self.technique_type}, chunk size: {self.chunk_size}"
211+
)
212212

213213

214214
def main():
215-
216215
parser = argparse.ArgumentParser(description="Run memory experiments")
217-
parser.add_argument("--technique_type", choices=TECHNIQUES, default="rag", help="Memory technique to use")
216+
parser.add_argument(
217+
"--technique_type", choices=TECHNIQUES, default="rag", help="Memory technique to use"
218+
)
218219
parser.add_argument("--chunk_size", type=int, default=500, help="Chunk size for processing")
219-
parser.add_argument("--output_folder", type=str, default="results/", help="Output path for results")
220+
parser.add_argument(
221+
"--output_folder", type=str, default="results/", help="Output path for results"
222+
)
220223
parser.add_argument("--top_k", type=int, default=30, help="Number of top memories to retrieve")
221224
parser.add_argument("--num_chunks", type=int, default=2, help="Number of chunks to process")
222225

223226
args = parser.parse_args()
224227

225228
if args.technique_type == "rag":
226-
output_file_path = os.path.join(args.output_folder, f"rag_results_{args.chunk_size}_k{args.num_chunks}.json")
227-
rag_manager = RAGManager(data_path="data/locomo/locomo10_rag.json", chunk_size=args.chunk_size, k=args.num_chunks)
229+
output_file_path = os.path.join(
230+
args.output_folder, f"rag_results_{args.chunk_size}_k{args.num_chunks}.json"
231+
)
232+
rag_manager = RAGManager(
233+
data_path="data/locomo/locomo10_rag.json", chunk_size=args.chunk_size, k=args.num_chunks
234+
)
228235
rag_manager.process_all_conversations(output_file_path)
229236

230-
if __name__ =="__main__":
237+
238+
if __name__ == "__main__":
231239
start = time.time()
232240
main()
233241
end = time.time()

0 commit comments

Comments (0)