 import os
 import sys
 import logging
+import threading
 from datetime import date
 from sentence_transformers import SentenceTransformer
 import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
 import traceback
 from atomicwrites import atomic_write
 
+
 openai.api_base = os.getenv("OPENAI_BASE_URL")
 openai.api_key = os.getenv("OPENAI_API_KEY")
 
@@ -20,9 +22,10 @@ class RAGSystem:
     DOC_ABOUT_EMBEDDINGS_PATH = "./data/doc_about_embeddings.npy"
 
     def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
+        self._update_lock = threading.Lock()
         self.knowledge_base_path = knowledge_base_path
 
-        self.knowledge_base = self.load_knowledge_base()
+        knowledge_base = self.load_knowledge_base()
         self.model = SentenceTransformer("all-MiniLM-L6-v2")
 
         # load existing embeddings if available
@@ -31,21 +34,27 @@ def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
         if os.path.exists(self.DOC_ABOUT_EMBEDDINGS_PATH) and os.path.exists(
             self.DOC_EMBEDDINGS_PATH
         ):
-            self.doc_about_embeddings = np.load(self.DOC_ABOUT_EMBEDDINGS_PATH)
-            logging.info("Loaded existing about document about embeddings from disk.")
-            self.doc_embeddings = np.load(self.DOC_EMBEDDINGS_PATH)
-            logging.info("Loaded existing document embeddings from disk.")
+            with self._update_lock:
+                self.doc_about_embeddings = np.load(self.DOC_ABOUT_EMBEDDINGS_PATH)
+                logging.info(
+                    "Loaded existing document 'about' embeddings from disk."
+                )
+                self.doc_embeddings = np.load(self.DOC_EMBEDDINGS_PATH)
+                logging.info("Loaded existing document embeddings from disk.")
+                self.knowledge_base = knowledge_base
 
-            # Save file timestamps when loading cache
-            self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
-            self.doc_about_embeddings_timestamp = os.path.getmtime(
-                self.DOC_ABOUT_EMBEDDINGS_PATH
-            )
-            logging.info(
-                f"Cache loaded - doc_embeddings timestamp: {self.doc_embeddings_timestamp}, doc_about_embeddings timestamp: {self.doc_about_embeddings_timestamp}"
-            )
+                # Save file timestamps when loading cache
+                self.doc_embeddings_timestamp = os.path.getmtime(
+                    self.DOC_EMBEDDINGS_PATH
+                )
+                self.doc_about_embeddings_timestamp = os.path.getmtime(
+                    self.DOC_ABOUT_EMBEDDINGS_PATH
+                )
+                logging.info(
+                    f"Cache loaded - doc_embeddings timestamp: {self.doc_embeddings_timestamp}, doc_about_embeddings timestamp: {self.doc_about_embeddings_timestamp}"
+                )
         else:
-            self.rebuild_embeddings()
+            self.rebuild_embeddings(knowledge_base)
 
         logging.info("Knowledge base embeddings created")
         self.conversation_history = []
@@ -54,43 +63,53 @@ def _atomic_save_numpy(self, file_path, data):
         with atomic_write(file_path, mode="wb", overwrite=True) as f:
             np.save(f, data)
 
-    def rebuild_embeddings(self):
+    def rebuild_embeddings(self, knowledge_base):
         logging.info("Rebuilding document embeddings...")
 
-        new_doc_embeddings = self.embed_knowledge_base()
-        new_about_embeddings = self.embed_knowledge_base_about()
-
-        # Atomic saves with guaranteed order
-        self._atomic_save_numpy(
-            self.DOC_EMBEDDINGS_PATH, new_doc_embeddings.cpu().numpy()
-        )
-        self._atomic_save_numpy(
-            self.DOC_ABOUT_EMBEDDINGS_PATH, new_about_embeddings.cpu().numpy()
-        )
+        new_doc_embeddings = self.embed_knowledge_base(knowledge_base)
+        new_about_embeddings = self.embed_knowledge_base_about(knowledge_base)
 
-        # Update in-memory embeddings only after successful saves
-        self.doc_embeddings = new_doc_embeddings
-        self.doc_about_embeddings = new_about_embeddings
+        # Defensive check for size mismatches
+        sizes = [
+            len(new_about_embeddings),
+            len(new_doc_embeddings),
+            len(knowledge_base),
+        ]
+        if len(set(sizes)) > 1:  # Not all sizes are equal
+            logging.error(
+                f"rebuild_embeddings: array size mismatch detected: about_embeddings={sizes[0]}, doc_embeddings={sizes[1]}, knowledge_base={sizes[2]}"
+            )
+            return  # Abandon update
 
-        # Update file timestamps after successful saves
-        self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
-        self.doc_about_embeddings_timestamp = os.path.getmtime(
-            self.DOC_ABOUT_EMBEDDINGS_PATH
-        )
+        # Atomically update files, in-memory cache, and timestamps
+        with self._update_lock:
+            self._atomic_save_numpy(
+                self.DOC_EMBEDDINGS_PATH, new_doc_embeddings.cpu().numpy()
+            )
+            self._atomic_save_numpy(
+                self.DOC_ABOUT_EMBEDDINGS_PATH, new_about_embeddings.cpu().numpy()
+            )
+            self.knowledge_base = knowledge_base
+            self.doc_embeddings = new_doc_embeddings
+            self.doc_about_embeddings = new_about_embeddings
+            self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
+            self.doc_about_embeddings_timestamp = os.path.getmtime(
+                self.DOC_ABOUT_EMBEDDINGS_PATH
+            )
 
         logging.info("Embeddings rebuilt successfully.")
 
     def load_knowledge_base(self):
         with open(self.knowledge_base_path, "r") as kb_file:
             return json.load(kb_file)
 
-    def embed_knowledge_base(self):
-        docs = [f"{doc['about']}. {doc['text']}" for doc in self.knowledge_base]
+    def embed_knowledge_base(self, knowledge_base):
+        docs = [f"{doc['about']}. {doc['text']}" for doc in knowledge_base]
         return self.model.encode(docs, convert_to_tensor=True)
 
-    def embed_knowledge_base_about(self):
+    def embed_knowledge_base_about(self, knowledge_base):
         return self.model.encode(
-            [doc["about"] for doc in self.knowledge_base], convert_to_tensor=True
+            [doc["about"] for doc in knowledge_base], convert_to_tensor=True
         )
 
     def normalize_query(self, query):
@@ -193,6 +212,7 @@ def compute_relevance_scores(
         self, text_similarities, about_similarities, high_match_threshold
     ):
         relevance_scores = []
+
         for i, _ in enumerate(self.knowledge_base):
             about_similarity = about_similarities[i]
             text_similarity = text_similarities[i]
@@ -321,8 +341,8 @@ def rebuild(self):
         Rebuild the embeddings for the knowledge base. This should be called whenever the knowledge base is updated.
         """
         print("Rebuilding embeddings for the knowledge base...")
-        self.knowledge_base = self.load_knowledge_base()  # Reload the knowledge base
-        self.doc_embeddings = self.rebuild_embeddings()  # Rebuild the embeddings
+        knowledge_base = self.load_knowledge_base()  # Reload the knowledge base
+        self.rebuild_embeddings(knowledge_base)  # Rebuild the embeddings
         print("Embeddings have been rebuilt.")
 
     def get_citations(self, retrieved_docs):