@@ -22,7 +22,7 @@ class RAGSystem:
2222 def __init__ (self , knowledge_base_path = "./data/knowledge_base.json" ):
2323 self .knowledge_base_path = knowledge_base_path
2424
25- self . knowledge_base = self .load_knowledge_base ()
25+ knowledge_base = self .load_knowledge_base ()
2626 self .model = SentenceTransformer ("all-MiniLM-L6-v2" )
2727
2828 # load existing embeddings if available
@@ -35,6 +35,7 @@ def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
3535 logging .info ("Loaded existing about document about embeddings from disk." )
3636 self .doc_embeddings = np .load (self .DOC_EMBEDDINGS_PATH )
3737 logging .info ("Loaded existing document embeddings from disk." )
38+ self .knowledge_base = knowledge_base
3839
3940 # Save file timestamps when loading cache
4041 self .doc_embeddings_timestamp = os .path .getmtime (self .DOC_EMBEDDINGS_PATH )
@@ -45,7 +46,7 @@ def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
4546 f"Cache loaded - doc_embeddings timestamp: { self .doc_embeddings_timestamp } , doc_about_embeddings timestamp: { self .doc_about_embeddings_timestamp } "
4647 )
4748 else :
48- self .rebuild_embeddings ()
49+ self .rebuild_embeddings (knowledge_base )
4950
5051 logging .info ("Knowledge base embeddings created" )
5152 self .conversation_history = []
@@ -54,11 +55,11 @@ def _atomic_save_numpy(self, file_path, data):
5455 with atomic_write (file_path , mode = "wb" , overwrite = True ) as f :
5556 np .save (f , data )
5657
57- def rebuild_embeddings (self ):
58+ def rebuild_embeddings (self , knowledge_base ):
5859 logging .info ("Rebuilding document embeddings..." )
5960
60- new_doc_embeddings = self .embed_knowledge_base ()
61- new_about_embeddings = self .embed_knowledge_base_about ()
61+ new_doc_embeddings = self .embed_knowledge_base (knowledge_base )
62+ new_about_embeddings = self .embed_knowledge_base_about (knowledge_base )
6263
6364 # Atomic saves with guaranteed order
6465 self ._atomic_save_numpy (
@@ -69,6 +70,7 @@ def rebuild_embeddings(self):
6970 )
7071
7172 # Update in-memory embeddings only after successful saves
73+ self .knowledge_base = knowledge_base
7274 self .doc_embeddings = new_doc_embeddings
7375 self .doc_about_embeddings = new_about_embeddings
7476
@@ -84,13 +86,13 @@ def load_knowledge_base(self):
8486 with open (self .knowledge_base_path , "r" ) as kb_file :
8587 return json .load (kb_file )
8688
87- def embed_knowledge_base (self ):
88- docs = [f"{ doc ['about' ]} . { doc ['text' ]} " for doc in self . knowledge_base ]
89+ def embed_knowledge_base (self , knowledge_base ):
90+ docs = [f"{ doc ['about' ]} . { doc ['text' ]} " for doc in knowledge_base ]
8991 return self .model .encode (docs , convert_to_tensor = True )
9092
91- def embed_knowledge_base_about (self ):
93+ def embed_knowledge_base_about (self , knowledge_base ):
9294 return self .model .encode (
93- [doc ["about" ] for doc in self . knowledge_base ], convert_to_tensor = True
95+ [doc ["about" ] for doc in knowledge_base ], convert_to_tensor = True
9496 )
9597
9698 def normalize_query (self , query ):
@@ -193,7 +195,20 @@ def compute_relevance_scores(
193195 self , text_similarities , about_similarities , high_match_threshold
194196 ):
195197 relevance_scores = []
196- for i , _ in enumerate (self .knowledge_base ):
198+
199+ # Defensive check for size mismatches
200+ sizes = [
201+ len (text_similarities ),
202+ len (about_similarities ),
203+ len (self .knowledge_base ),
204+ ]
205+ if len (set (sizes )) > 1 : # Not all sizes are equal
206+ logging .warning (
207+ f"Array size mismatch detected: text_similarities={ sizes [0 ]} , about_similarities={ sizes [1 ]} , knowledge_base={ sizes [2 ]} "
208+ )
209+
210+ max_index = min (sizes )
211+ for i in range (max_index ):
197212 about_similarity = about_similarities [i ]
198213 text_similarity = text_similarities [i ]
199214 # If either about or text similarity is above the high match threshold, prioritize it
@@ -321,8 +336,10 @@ def rebuild(self):
321336 Rebuild the embeddings for the knowledge base. This should be called whenever the knowledge base is updated.
322337 """
323338 print ("Rebuilding embeddings for the knowledge base..." )
324- self .knowledge_base = self .load_knowledge_base () # Reload the knowledge base
325- self .doc_embeddings = self .rebuild_embeddings () # Rebuild the embeddings
339+ knowledge_base = self .load_knowledge_base () # Reload the knowledge base
340+ self .doc_embeddings = self .rebuild_embeddings (
341+ knowledge_base
342+ ) # Rebuild the embeddings
326343 print ("Embeddings have been rebuilt." )
327344
328345 def get_citations (self , retrieved_docs ):
0 commit comments