1- '''
2- Modify the code from the mem0 project, Original file link
3- https://github.com/mem0ai/mem0/blob/main/evaluation/src/rag.py
4- '''
1+ """
2+ Modified from the mem0 project; original file: https://github.com/mem0ai/mem0/blob/main/evaluation/src/rag.py
3+ """
54
5+ import argparse
66import json
77import os
8- import numpy as np
98import time
9+
1010from collections import defaultdict
1111
1212import numpy as np
1313import tiktoken
14- import argparse
1514
1615from dotenv import load_dotenv
1716from jinja2 import Template
1817from openai import OpenAI
1918from tqdm import tqdm
2019
20+
2121load_dotenv ()
2222
2323PROMPT = """
24- # Question:
24+ # Question:
2525{{QUESTION}}
2626
27- # Context:
27+ # Context:
2828{{CONTEXT}}
2929
3030# Short answer:
# Memory back-ends this evaluation harness recognizes on the CLI.
TECHNIQUES = ["mem0", "rag", "langmem", "zep", "openai"]
# NOTE(review): METHODS is not referenced anywhere in this chunk — confirm
# it is used elsewhere before relying on it.
METHODS = ["add", "search"]
3535
36+
3637class RAGManager :
3738 def __init__ (self , data_path = "data/locomo/locomo10_rag.json" , chunk_size = 500 , k = 2 ):
3839 self .model = os .getenv ("MODEL" )
@@ -68,14 +69,13 @@ def generate_response(self, question, context):
6869 temperature = 0 ,
6970 )
7071 t2 = time .time ()
71- # return response.choices[0].message.content.strip(), t2 - t1
7272 if response and response .choices :
7373 content = response .choices [0 ].message .content
7474 if content is not None :
7575 return content .strip (), t2 - t1
7676 else :
7777 return "No content returned" , t2 - t1
78- print (f "❎ No content returned!" )
78+ print ("❎ No content returned!" )
7979 else :
8080 return "Empty response" , t2 - t1
8181 except Exception as e :
@@ -87,7 +87,7 @@ def generate_response(self, question, context):
8787 def clean_chat_history (self , chat_history ):
8888 cleaned_chat_history = ""
8989 for c in chat_history :
90- cleaned_chat_history += f"{ c ['timestamp' ]} | { c ['speaker' ]} : " f" { c ['text' ]} \n "
90+ cleaned_chat_history += f"{ c ['timestamp' ]} | { c ['speaker' ]} : { c ['text' ]} \n "
9191
9292 return cleaned_chat_history
9393
@@ -96,7 +96,9 @@ def calculate_embedding(self, document):
9696 return response .data [0 ].embedding
9797
9898 def calculate_similarity (self , embedding1 , embedding2 ):
99- return np .dot (embedding1 , embedding2 ) / (np .linalg .norm (embedding1 ) * np .linalg .norm (embedding2 ))
99+ return np .dot (embedding1 , embedding2 ) / (
100+ np .linalg .norm (embedding1 ) * np .linalg .norm (embedding2 )
101+ )
100102
101103 def search (self , query , chunks , embeddings , k = 1 ):
102104 """
@@ -114,16 +116,12 @@ def search(self, query, chunks, embeddings, k=1):
114116 """
115117 t1 = time .time ()
116118 query_embedding = self .calculate_embedding (query )
117- similarities = [self .calculate_similarity (query_embedding , embedding ) for embedding in embeddings ]
119+ similarities = [
120+ self .calculate_similarity (query_embedding , embedding ) for embedding in embeddings
121+ ]
118122
119123 # Get indices of top-k most similar chunks
120- if k == 1 :
121- # Original behavior - just get the most similar chunk
122- top_indices = [np .argmax (similarities )]
123- else :
124- # Get indices of top-k chunks
125- top_indices = np .argsort (similarities )[- k :][::- 1 ]
126-
124+ top_indices = [np .argmax (similarities )] if k == 1 else np .argsort (similarities )[- k :][::- 1 ]
127125 # Combine the top-k chunks
128126 combined_chunks = "\n <->\n " .join ([chunks [i ] for i in top_indices ])
129127
@@ -161,10 +159,10 @@ def create_chunks(self, chat_history, chunk_size=500):
161159 return chunks , embeddings
162160
163161 def process_all_conversations (self , output_file_path ):
164- with open (self .data_path , "r" ) as f :
162+ with open (self .data_path ) as f :
165163 data = json .load (f )
166164
167- FINAL_RESULTS = defaultdict (list )
165+ final_results = defaultdict (list )
168166 for key , value in tqdm (data .items (), desc = "Processing conversations" ):
169167 chat_history = value ["conversation" ]
170168 questions = value ["question" ]
@@ -183,7 +181,7 @@ def process_all_conversations(self, output_file_path):
183181 context , search_time = self .search (question , chunks , embeddings , k = self .k )
184182 response , response_time = self .generate_response (question , context )
185183
186- FINAL_RESULTS [key ].append (
184+ final_results [key ].append (
187185 {
188186 "question" : question ,
189187 "answer" : answer ,
@@ -195,11 +193,11 @@ def process_all_conversations(self, output_file_path):
195193 }
196194 )
197195 with open (output_file_path , "w+" ) as f :
198- json .dump (FINAL_RESULTS , f , indent = 4 )
196+ json .dump (final_results , f , indent = 4 )
199197
200198 # Save results
201199 with open (output_file_path , "w+" ) as f :
202- json .dump (FINAL_RESULTS , f , indent = 4 )
200+ json .dump (final_results , f , indent = 4 )
203201
204202
205203class Experiment :
@@ -208,26 +206,36 @@ def __init__(self, technique_type, chunk_size):
208206 self .chunk_size = chunk_size
209207
210208 def run (self ):
211- print (f"Running experiment with technique: { self .technique_type } , chunk size: { self .chunk_size } " )
209+ print (
210+ f"Running experiment with technique: { self .technique_type } , chunk size: { self .chunk_size } "
211+ )
212212
213213
def main():
    """Parse command-line arguments and run the selected memory experiment.

    Flags: ``--technique_type`` (one of TECHNIQUES), ``--chunk_size``,
    ``--output_folder``, ``--top_k``, ``--num_chunks``.  Only the "rag"
    technique is wired up in this script; it processes every conversation
    in the LoCoMo dataset and writes a JSON results file into the output
    folder.
    """
    parser = argparse.ArgumentParser(description="Run memory experiments")
    parser.add_argument(
        "--technique_type", choices=TECHNIQUES, default="rag", help="Memory technique to use"
    )
    parser.add_argument("--chunk_size", type=int, default=500, help="Chunk size for processing")
    parser.add_argument(
        "--output_folder", type=str, default="results/", help="Output path for results"
    )
    parser.add_argument("--top_k", type=int, default=30, help="Number of top memories to retrieve")
    parser.add_argument("--num_chunks", type=int, default=2, help="Number of chunks to process")

    args = parser.parse_args()

    if args.technique_type == "rag":
        # Ensure the destination directory exists up front; otherwise the
        # result-file open() in process_all_conversations raises
        # FileNotFoundError after all the (expensive) work is queued.
        os.makedirs(args.output_folder, exist_ok=True)
        output_file_path = os.path.join(
            args.output_folder, f"rag_results_{args.chunk_size}_k{args.num_chunks}.json"
        )
        rag_manager = RAGManager(
            data_path="data/locomo/locomo10_rag.json",
            chunk_size=args.chunk_size,
            k=args.num_chunks,
        )
        rag_manager.process_all_conversations(output_file_path)
229236
230- if __name__ == "__main__" :
237+
238+ if __name__ == "__main__" :
231239 start = time .time ()
232240 main ()
233241 end = time .time ()
0 commit comments