'''
Adapted from the mem0 project. Original file:
https://github.com/mem0ai/mem0/blob/main/evaluation/src/rag.py
'''

import argparse
import json
import os
import time
from collections import defaultdict

import numpy as np
import tiktoken
from dotenv import load_dotenv
from jinja2 import Template
from openai import OpenAI
from tqdm import tqdm

load_dotenv()
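
# Configuration is read from the environment (typically the .env file loaded above).
# This script expects at least:
#   MODEL            - chat model name used for answer generation
#   EMBEDDING_MODEL  - embedding model name used for chunk/query embeddings
#   OPENAI_API_KEY   - picked up implicitly by the OpenAI() client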

PROMPT = """
# Question:
{{QUESTION}}

# Context:
{{CONTEXT}}

# Short answer:
"""

TECHNIQUES = ["mem0", "rag", "langmem", "zep", "openai"]
METHODS = ["add", "search"]

class RAGManager:
    def __init__(self, data_path="data/locomo/locomo10_rag.json", chunk_size=500, k=2):
        self.model = os.getenv("MODEL")
        self.client = OpenAI()
        self.data_path = data_path
        self.chunk_size = chunk_size
        self.k = k

    def generate_response(self, question, context):
        template = Template(PROMPT)
        prompt = template.render(CONTEXT=context, QUESTION=question)

        max_retries = 3
        retries = 0

        while retries <= max_retries:
            try:
                t1 = time.time()
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a helpful assistant that can answer "
                            "questions based on the provided context. "
                            "If the question involves timing, use the conversation date for reference. "
                            "Provide the shortest possible answer. "
                            "Use words directly from the conversation when possible. "
                            "Avoid using subjects in your answer.",
                        },
                        {"role": "user", "content": prompt},
                    ],
                    temperature=0,
                )
                t2 = time.time()
                if response and response.choices:
                    content = response.choices[0].message.content
                    if content is not None:
                        return content.strip(), t2 - t1
                    print("❎ No content returned!")
                    return "No content returned", t2 - t1
                return "Empty response", t2 - t1
            except Exception as e:
                retries += 1
                if retries > max_retries:
                    raise e
                time.sleep(1)  # Wait before retrying

    def clean_chat_history(self, chat_history):
        cleaned_chat_history = ""
        for c in chat_history:
            cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: {c['text']}\n"

        return cleaned_chat_history

    def calculate_embedding(self, document):
        # Embed a document with the configured OpenAI embedding model.
        response = self.client.embeddings.create(model=os.getenv("EMBEDDING_MODEL"), input=document)
        return response.data[0].embedding

    def calculate_similarity(self, embedding1, embedding2):
        # Cosine similarity between two embedding vectors.
        return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))

    def search(self, query, chunks, embeddings, k=1):
        """
        Search for the top-k most similar chunks to the query.

        Args:
            query: The query string
            chunks: List of text chunks
            embeddings: List of embeddings for each chunk
            k: Number of top chunks to return (default: 1)

        Returns:
            combined_chunks: The combined text of the top-k chunks
            search_time: Time taken for the search
        """
        t1 = time.time()
        query_embedding = self.calculate_embedding(query)
        similarities = [self.calculate_similarity(query_embedding, embedding) for embedding in embeddings]

        # Get indices of top-k most similar chunks
        if k == 1:
            # Original behavior - just get the most similar chunk
            top_indices = [np.argmax(similarities)]
        else:
            # Get indices of top-k chunks, most similar first
            top_indices = np.argsort(similarities)[-k:][::-1]

        # Combine the top-k chunks
        combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices])

        t2 = time.time()
        return combined_chunks, t2 - t1

    def create_chunks(self, chat_history, chunk_size=500):
        """
        Create chunks using tiktoken for more accurate token counting
        """
        # Get the encoding for the embedding model; fall back to cl100k_base
        # if tiktoken does not recognize the model name.
        try:
            encoding = tiktoken.encoding_for_model(os.getenv("EMBEDDING_MODEL"))
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")

        documents = self.clean_chat_history(chat_history)

        # chunk_size == -1 means "no chunking": use the full conversation as context.
        if chunk_size == -1:
            return [documents], []

        chunks = []

        # Encode the document
        tokens = encoding.encode(documents)

        # Split into chunks based on token count
        for i in range(0, len(tokens), chunk_size):
            chunk_tokens = tokens[i : i + chunk_size]
            chunk = encoding.decode(chunk_tokens)
            chunks.append(chunk)

        embeddings = []
        for chunk in chunks:
            embedding = self.calculate_embedding(chunk)
            embeddings.append(embedding)

        return chunks, embeddings

    def process_all_conversations(self, output_file_path):
        with open(self.data_path, "r") as f:
            data = json.load(f)

        FINAL_RESULTS = defaultdict(list)
        for key, value in tqdm(data.items(), desc="Processing conversations"):
            chat_history = value["conversation"]
            questions = value["question"]

            chunks, embeddings = self.create_chunks(chat_history, self.chunk_size)

            for item in tqdm(questions, desc="Answering questions", leave=False):
                question = item["question"]
                answer = item.get("answer", "")
                category = item["category"]

                if self.chunk_size == -1:
                    context = chunks[0]
                    search_time = 0
                else:
                    context, search_time = self.search(question, chunks, embeddings, k=self.k)
                response, response_time = self.generate_response(question, context)

                FINAL_RESULTS[key].append(
                    {
                        "question": question,
                        "answer": answer,
                        "category": category,
                        "context": context,
                        "response": response,
                        "search_time": search_time,
                        "response_time": response_time,
                    }
                )
                # Checkpoint the results collected so far, so progress is not lost
                # if the run is interrupted.
                with open(output_file_path, "w+") as f:
                    json.dump(FINAL_RESULTS, f, indent=4)

        # Save final results
        with open(output_file_path, "w+") as f:
            json.dump(FINAL_RESULTS, f, indent=4)


class Experiment:
    def __init__(self, technique_type, chunk_size):
        self.technique_type = technique_type
        self.chunk_size = chunk_size

    def run(self):
        print(f"Running experiment with technique: {self.technique_type}, chunk size: {self.chunk_size}")


def main():
    parser = argparse.ArgumentParser(description="Run memory experiments")
    parser.add_argument("--technique_type", choices=TECHNIQUES, default="rag", help="Memory technique to use")
    parser.add_argument("--chunk_size", type=int, default=500, help="Chunk size for processing")
    parser.add_argument("--output_folder", type=str, default="results/", help="Output path for results")
    parser.add_argument("--top_k", type=int, default=30, help="Number of top memories to retrieve")
    parser.add_argument("--num_chunks", type=int, default=2, help="Number of chunks to retrieve per question")

    args = parser.parse_args()

    # Make sure the output folder exists before writing results.
    os.makedirs(args.output_folder, exist_ok=True)

    if args.technique_type == "rag":
        output_file_path = os.path.join(args.output_folder, f"rag_results_{args.chunk_size}_k{args.num_chunks}.json")
        rag_manager = RAGManager(data_path="data/locomo/locomo10_rag.json", chunk_size=args.chunk_size, k=args.num_chunks)
        rag_manager.process_all_conversations(output_file_path)


if __name__ == "__main__":
    start = time.time()
    main()
    end = time.time()
    print(f"Execution time: {end - start:.2f} seconds")
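
# Example invocation (script name and data paths are assumptions; adjust to your setup):
#   python rag.py --technique_type rag --chunk_size 500 --num_chunks 2 --output_folder results/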