Skip to content

Commit a46c3d8

Browse files
author
Eric Liu
committed
refresh if cache has changed.
1 parent a130251 commit a46c3d8

File tree

2 files changed

+136
-6
lines changed

2 files changed

+136
-6
lines changed

app/rag_system.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616

1717
class RAGSystem:
18+
# Cache file paths
19+
DOC_EMBEDDINGS_PATH = "./data/doc_embeddings.npy"
20+
DOC_ABOUT_EMBEDDINGS_PATH = "./data/doc_about_embeddings.npy"
21+
1822
def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
1923
self.knowledge_base_path = knowledge_base_path
2024

@@ -24,13 +28,22 @@ def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
2428
# load existing embeddings if available
2529
logging.info("Embedding knowledge base...")
2630

27-
if os.path.exists("./data/doc_about_embeddings.npy") and os.path.exists(
28-
"./data/doc_embeddings.npy"
31+
if os.path.exists(self.DOC_ABOUT_EMBEDDINGS_PATH) and os.path.exists(
32+
self.DOC_EMBEDDINGS_PATH
2933
):
30-
self.doc_about_embeddings = np.load("./data/doc_about_embeddings.npy")
34+
self.doc_about_embeddings = np.load(self.DOC_ABOUT_EMBEDDINGS_PATH)
3135
logging.info("Loaded existing about document about embeddings from disk.")
32-
self.doc_embeddings = np.load("./data/doc_embeddings.npy")
36+
self.doc_embeddings = np.load(self.DOC_EMBEDDINGS_PATH)
3337
logging.info("Loaded existing document embeddings from disk.")
38+
39+
# Save file timestamps when loading cache
40+
self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
41+
self.doc_about_embeddings_timestamp = os.path.getmtime(
42+
self.DOC_ABOUT_EMBEDDINGS_PATH
43+
)
44+
logging.info(
45+
f"Cache loaded - doc_embeddings timestamp: {self.doc_embeddings_timestamp}, doc_about_embeddings timestamp: {self.doc_about_embeddings_timestamp}"
46+
)
3447
else:
3548
self.rebuild_embeddings()
3649

@@ -49,16 +62,22 @@ def rebuild_embeddings(self):
4962

5063
# Atomic saves with guaranteed order
5164
self._atomic_save_numpy(
52-
"./data/doc_embeddings.npy", new_doc_embeddings.cpu().numpy()
65+
self.DOC_EMBEDDINGS_PATH, new_doc_embeddings.cpu().numpy()
5366
)
5467
self._atomic_save_numpy(
55-
"./data/doc_about_embeddings.npy", new_about_embeddings.cpu().numpy()
68+
self.DOC_ABOUT_EMBEDDINGS_PATH, new_about_embeddings.cpu().numpy()
5669
)
5770

5871
# Update in-memory embeddings only after successful saves
5972
self.doc_embeddings = new_doc_embeddings
6073
self.doc_about_embeddings = new_about_embeddings
6174

75+
# Update file timestamps after successful saves
76+
self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
77+
self.doc_about_embeddings_timestamp = os.path.getmtime(
78+
self.DOC_ABOUT_EMBEDDINGS_PATH
79+
)
80+
6281
logging.info("Embeddings rebuilt successfully.")
6382

6483
def load_knowledge_base(self):
@@ -117,6 +136,43 @@ def compute_document_scores(
117136

118137
return result
119138

139+
def cache_check(func):
140+
"""Decorator to automatically check cache consistency"""
141+
142+
def wrapper(self, *args, **kwargs):
143+
try:
144+
current_times = [
145+
os.path.getmtime(self.DOC_EMBEDDINGS_PATH),
146+
os.path.getmtime(self.DOC_ABOUT_EMBEDDINGS_PATH),
147+
]
148+
stored_times = [
149+
self.doc_embeddings_timestamp,
150+
self.doc_about_embeddings_timestamp,
151+
]
152+
153+
# update cache if timestamps are different from out last load
154+
if current_times != stored_times:
155+
self._reload_cache()
156+
157+
except (OSError, FileNotFoundError, PermissionError):
158+
logging.warning("Cache files inaccessible, rebuilding...")
159+
self.rebuild_embeddings()
160+
161+
return func(self, *args, **kwargs)
162+
163+
return wrapper
164+
165+
def _reload_cache(self):
166+
self.doc_embeddings = np.load(self.DOC_EMBEDDINGS_PATH)
167+
self.doc_about_embeddings = np.load(self.DOC_ABOUT_EMBEDDINGS_PATH)
168+
169+
# update our timestamps of the cached files
170+
self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
171+
self.doc_about_embeddings_timestamp = os.path.getmtime(
172+
self.DOC_ABOUT_EMBEDDINGS_PATH
173+
)
174+
175+
@cache_check
120176
def retrieve(
121177
self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5
122178
):

app/test_rag_system.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22

33
from rag_system import RAGSystem
4+
import os
45

56

67
class TestRAGSystem(unittest.TestCase):
@@ -123,6 +124,79 @@ def test_compute_document_scores(self):
123124

124125
print("Test for compute_document_scores passed successfully!")
125126

127+
def test_cache_check_reload_cache(self):
128+
# Simulate cache file timestamp change to trigger _reload_cache
129+
original_doc_embeddings_timestamp = self.rag_system.doc_embeddings_timestamp
130+
original_doc_about_embeddings_timestamp = (
131+
self.rag_system.doc_about_embeddings_timestamp
132+
)
133+
134+
# Patch os.path.getmtime to return different timestamps
135+
def fake_getmtime(path):
136+
if path == self.rag_system.DOC_EMBEDDINGS_PATH:
137+
return original_doc_embeddings_timestamp + 1
138+
if path == self.rag_system.DOC_ABOUT_EMBEDDINGS_PATH:
139+
return original_doc_about_embeddings_timestamp + 1
140+
return 0
141+
142+
self.rag_system._reload_cache_called = False
143+
144+
def fake_reload_cache():
145+
self.rag_system._reload_cache_called = True
146+
real_reload_cache()
147+
148+
real_getmtime = os.path.getmtime
149+
os.path.getmtime = fake_getmtime
150+
151+
# Patch _reload_cache to set a flag
152+
real_reload_cache = self.rag_system._reload_cache
153+
self.rag_system._reload_cache = fake_reload_cache
154+
155+
# Call a cache_check-decorated method
156+
self.rag_system.retrieve("test query")
157+
158+
self.assertTrue(
159+
self.rag_system._reload_cache_called,
160+
"Cache reload was not triggered when timestamps changed.",
161+
)
162+
163+
# Restore patched methods
164+
os.path.getmtime = real_getmtime
165+
self.rag_system._reload_cache = real_reload_cache
166+
print("Test for cache_check reload_cache passed successfully!")
167+
168+
def test_cache_check_rebuild_embeddings_on_error(self):
169+
# Patch os.path.getmtime to raise OSError
170+
real_getmtime = os.path.getmtime
171+
172+
def raise_oserror(path):
173+
raise OSError("Simulated error")
174+
175+
os.path.getmtime = raise_oserror
176+
177+
self.rag_system._rebuild_embeddings_called = False
178+
179+
def fake_rebuild_embeddings():
180+
self.rag_system._rebuild_embeddings_called = True
181+
return real_rebuild_embeddings()
182+
183+
self.rag_system.rebuild_embeddings = fake_rebuild_embeddings
184+
# Patch rebuild_embeddings to set a flag
185+
real_rebuild_embeddings = self.rag_system.rebuild_embeddings
186+
187+
# Call a cache_check-decorated method
188+
self.rag_system.retrieve("test query")
189+
190+
self.assertTrue(
191+
self.rag_system._rebuild_embeddings_called,
192+
"rebuild_embeddings was not triggered on cache access error.",
193+
)
194+
195+
# Restore patched methods
196+
os.path.getmtime = real_getmtime
197+
self.rag_system.rebuild_embeddings = real_rebuild_embeddings
198+
print("Test for cache_check rebuild_embeddings on error passed successfully!")
199+
126200

127201
if __name__ == "__main__":
128202
unittest.main()

0 commit comments

Comments
 (0)