Skip to content

Commit 81a60ad

Browse files
author
Eric Liu
committed
Review updates: add thread safety by making in-memory cache updates atomic
1 parent f60b7cd commit 81a60ad

File tree

1 file changed

+51
-46
lines changed

1 file changed

+51
-46
lines changed

app/rag_system.py

Lines changed: 51 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ class RAGSystem:
2222
def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
2323
self.knowledge_base_path = knowledge_base_path
2424

25+
# Lock for atomic updates of in-memory cache
26+
import threading
27+
28+
self._update_lock = threading.Lock()
29+
2530
knowledge_base = self.load_knowledge_base()
2631
self.model = SentenceTransformer("all-MiniLM-L6-v2")
2732

@@ -31,20 +36,25 @@ def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
3136
if os.path.exists(self.DOC_ABOUT_EMBEDDINGS_PATH) and os.path.exists(
3237
self.DOC_EMBEDDINGS_PATH
3338
):
34-
self.doc_about_embeddings = np.load(self.DOC_ABOUT_EMBEDDINGS_PATH)
35-
logging.info("Loaded existing about document about embeddings from disk.")
36-
self.doc_embeddings = np.load(self.DOC_EMBEDDINGS_PATH)
37-
logging.info("Loaded existing document embeddings from disk.")
38-
self.knowledge_base = knowledge_base
39+
with self._update_lock:
40+
self.doc_about_embeddings = np.load(self.DOC_ABOUT_EMBEDDINGS_PATH)
41+
logging.info(
42+
"Loaded existing about document about embeddings from disk."
43+
)
44+
self.doc_embeddings = np.load(self.DOC_EMBEDDINGS_PATH)
45+
logging.info("Loaded existing document embeddings from disk.")
46+
self.knowledge_base = knowledge_base
3947

40-
# Save file timestamps when loading cache
41-
self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
42-
self.doc_about_embeddings_timestamp = os.path.getmtime(
43-
self.DOC_ABOUT_EMBEDDINGS_PATH
44-
)
45-
logging.info(
46-
f"Cache loaded - doc_embeddings timestamp: {self.doc_embeddings_timestamp}, doc_about_embeddings timestamp: {self.doc_about_embeddings_timestamp}"
47-
)
48+
# Save file timestamps when loading cache
49+
self.doc_embeddings_timestamp = os.path.getmtime(
50+
self.DOC_EMBEDDINGS_PATH
51+
)
52+
self.doc_about_embeddings_timestamp = os.path.getmtime(
53+
self.DOC_ABOUT_EMBEDDINGS_PATH
54+
)
55+
logging.info(
56+
f"Cache loaded - doc_embeddings timestamp: {self.doc_embeddings_timestamp}, doc_about_embeddings timestamp: {self.doc_about_embeddings_timestamp}"
57+
)
4858
else:
4959
self.rebuild_embeddings(knowledge_base)
5060

@@ -61,24 +71,33 @@ def rebuild_embeddings(self, knowledge_base):
6171
new_doc_embeddings = self.embed_knowledge_base(knowledge_base)
6272
new_about_embeddings = self.embed_knowledge_base_about(knowledge_base)
6373

64-
# Atomic saves with guaranteed order
65-
self._atomic_save_numpy(
66-
self.DOC_EMBEDDINGS_PATH, new_doc_embeddings.cpu().numpy()
67-
)
68-
self._atomic_save_numpy(
69-
self.DOC_ABOUT_EMBEDDINGS_PATH, new_about_embeddings.cpu().numpy()
70-
)
71-
72-
# Update in-memory embeddings only after successful saves
73-
self.knowledge_base = knowledge_base
74-
self.doc_embeddings = new_doc_embeddings
75-
self.doc_about_embeddings = new_about_embeddings
74+
# Defensive check for size mismatches
75+
sizes = [
76+
len(new_about_embeddings),
77+
len(new_doc_embeddings),
78+
len(knowledge_base),
79+
]
80+
if len(set(sizes)) > 1: # Not all sizes are equal
81+
logging.error(
82+
f"rebuild embeddings Array size mismatch detected: text_similarities={sizes[0]}, about_similarities={sizes[1]}, knowledge_base={sizes[2]}"
83+
)
84+
return # Abandon update
7685

77-
# Update file timestamps after successful saves
78-
self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
79-
self.doc_about_embeddings_timestamp = os.path.getmtime(
80-
self.DOC_ABOUT_EMBEDDINGS_PATH
81-
)
86+
# Atomically update files, in-memory cache, and timestamps
87+
with self._update_lock:
88+
self._atomic_save_numpy(
89+
self.DOC_EMBEDDINGS_PATH, new_doc_embeddings.cpu().numpy()
90+
)
91+
self._atomic_save_numpy(
92+
self.DOC_ABOUT_EMBEDDINGS_PATH, new_about_embeddings.cpu().numpy()
93+
)
94+
self.knowledge_base = knowledge_base
95+
self.doc_embeddings = new_doc_embeddings
96+
self.doc_about_embeddings = new_about_embeddings
97+
self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
98+
self.doc_about_embeddings_timestamp = os.path.getmtime(
99+
self.DOC_ABOUT_EMBEDDINGS_PATH
100+
)
82101

83102
logging.info("Embeddings rebuilt successfully.")
84103

@@ -196,19 +215,7 @@ def compute_relevance_scores(
196215
):
197216
relevance_scores = []
198217

199-
# Defensive check for size mismatches
200-
sizes = [
201-
len(text_similarities),
202-
len(about_similarities),
203-
len(self.knowledge_base),
204-
]
205-
if len(set(sizes)) > 1: # Not all sizes are equal
206-
logging.warning(
207-
f"Array size mismatch detected: text_similarities={sizes[0]}, about_similarities={sizes[1]}, knowledge_base={sizes[2]}"
208-
)
209-
210-
max_index = min(sizes)
211-
for i in range(max_index):
218+
for i, _ in enumerate(self.knowledge_base):
212219
about_similarity = about_similarities[i]
213220
text_similarity = text_similarities[i]
214221
# If either about or text similarity is above the high match threshold, prioritize it
@@ -337,9 +344,7 @@ def rebuild(self):
337344
"""
338345
print("Rebuilding embeddings for the knowledge base...")
339346
knowledge_base = self.load_knowledge_base() # Reload the knowledge base
340-
self.doc_embeddings = self.rebuild_embeddings(
341-
knowledge_base
342-
) # Rebuild the embeddings
347+
self.rebuild_embeddings(knowledge_base) # Rebuild the embeddings
343348
print("Embeddings have been rebuilt.")
344349

345350
def get_citations(self, retrieved_docs):

0 commit comments

Comments
 (0)