Skip to content

Commit 918b996

Browse files
authored
persistent database fix (#124)
* persistent database fix * changed to abs path * added faiss method and updated hybrid * abs path fix 2 * fixes subsequent similarity chains saving --------- Signed-off-by: Kevin Guan <[email protected]>
1 parent d0c189d commit 918b996

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

backend/src/chains/hybrid_retriever_chain.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import Optional, Union, Any
23

34
from langchain.retrievers import EnsembleRetriever
@@ -75,9 +76,21 @@ def create_hybrid_retriever(self) -> None:
7576
use_cuda=self.use_cuda,
7677
)
7778
if self.vector_db is None:
78-
similarity_retriever_chain.embed_docs(return_docs=True)
79+
cur_path = os.path.abspath(__file__)
80+
path = os.path.join(cur_path, "../../../", "faiss_db")
81+
path = os.path.abspath(path) # Ensure proper parent directory
82+
path_flag = os.path.isdir(path) # Checks if database already exists
83+
database_name = similarity_retriever_chain.name
84+
if path_flag and database_name in os.listdir(path):
85+
if database_name in os.listdir(path):
86+
similarity_retriever_chain.create_vector_db()
87+
similarity_retriever_chain.vector_db.load_db(database_name)
88+
self.vector_db = similarity_retriever_chain.vector_db
89+
self.vector_db.processed_docs = similarity_retriever_chain.vector_db.get_documents()
90+
else:
91+
similarity_retriever_chain.embed_docs(return_docs=True)
92+
self.vector_db = similarity_retriever_chain.vector_db
7993

80-
self.vector_db = similarity_retriever_chain.vector_db
8194
similarity_retriever_chain.create_similarity_retriever(search_k=self.search_k)
8295
similarity_retriever = similarity_retriever_chain.retriever
8396

backend/src/chains/similarity_retriever_chain.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313

1414
class SimilarityRetrieverChain(BaseChain):
15+
count = 0
1516
def __init__(
1617
self,
1718
llm_model: Optional[
@@ -33,6 +34,9 @@ def __init__(
3334
vector_db=vector_db,
3435
)
3536

37+
SimilarityRetrieverChain.count += 1
38+
self.name = f"similarity_INST{SimilarityRetrieverChain.count}"
39+
3640
self.embeddings_config: Optional[dict[str, str]] = embeddings_config
3741
self.use_cuda: bool = use_cuda
3842

@@ -96,6 +100,8 @@ def embed_docs(
96100
folder_paths=self.html_docs_path,
97101
return_docs=return_docs,
98102
)
103+
if self.vector_db is not None:
104+
self.vector_db.save_db(self.name)
99105

100106
return (
101107
self.processed_docs,

backend/src/vectorstores/faiss.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,25 @@ def add_documents(
190190

191191
return None
192192

193-
def save_db(self) -> None:
193+
def get_db_path(self) -> str:
194+
cur_path = os.path.abspath(__file__)
195+
path = os.path.join(cur_path, "../../../", "faiss_db")
196+
path = os.path.abspath(path) # Ensure proper parent directory
197+
return path
198+
199+
def save_db(self, name) -> None:
194200
if self._faiss_db is None:
195201
raise ValueError("No documents in FAISS database")
202+
else:
203+
save_path = f"{self.get_db_path()}/{name}"
204+
self._faiss_db.save_local(save_path)
196205

197-
self._faiss_db.save_local(os.getenv("FAISS_DB_PATH", "faiss_db"))
206+
def load_db(self, name) -> None:
207+
load_path = f"{self.get_db_path()}/{name}"
208+
self._faiss_db = FAISS.load_local(load_path, self.embedding_model, allow_dangerous_deserialization=True)
198209

199-
def load_db(self) -> None:
200-
self._faiss_db = FAISS.load_local(
201-
os.getenv("FAISS_DB_PATH", "faiss db"), self.embedding_model
202-
)
210+
def get_documents(self) -> list[Document]:
211+
return self._faiss_db.docstore._dict.values()
203212

204213
def process_json(self, folder_paths: list[str]) -> FAISS:
205214
logging.info("Processing json files...")

0 commit comments

Comments
 (0)