Skip to content

Commit e04b063

Browse files
authored
add faiss local saving/loading (#676)
- This uses the faiss built-in `write_index` and `load_index` to save and load faiss indexes locally - Also fixes #674 - The save/load functions also use the faiss library, so I refactored the dependency into a function
1 parent e45f7e4 commit e04b063

File tree

2 files changed

+54
-11
lines changed

2 files changed

+54
-11
lines changed

langchain/vectorstores/faiss.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,19 @@
1414
from langchain.vectorstores.utils import maximal_marginal_relevance
1515

1616

17+
def dependable_faiss_import() -> Any:
18+
"""Import faiss if available, otherwise raise error."""
19+
try:
20+
import faiss
21+
except ImportError:
22+
raise ValueError(
23+
"Could not import faiss python package. "
24+
"Please it install it with `pip install faiss` "
25+
"or `pip install faiss-cpu` (depending on Python version)."
26+
)
27+
return faiss
28+
29+
1730
class FAISS(VectorStore):
1831
"""Wrapper around FAISS vector database.
1932
@@ -174,14 +187,7 @@ def from_texts(
174187
embeddings = OpenAIEmbeddings()
175188
faiss = FAISS.from_texts(texts, embeddings)
176189
"""
177-
try:
178-
import faiss
179-
except ImportError:
180-
raise ValueError(
181-
"Could not import faiss python package. "
182-
"Please it install it with `pip install faiss` "
183-
"or `pip install faiss-cpu` (depending on Python version)."
184-
)
190+
faiss = dependable_faiss_import()
185191
embeddings = embedding.embed_documents(texts)
186192
index = faiss.IndexFlatL2(len(embeddings[0]))
187193
index.add(np.array(embeddings, dtype=np.float32))
@@ -194,3 +200,21 @@ def from_texts(
194200
{index_to_id[i]: doc for i, doc in enumerate(documents)}
195201
)
196202
return cls(embedding.embed_query, index, docstore, index_to_id)
203+
204+
def save_local(self, path: str) -> None:
205+
"""Save FAISS index to disk.
206+
207+
Args:
208+
path: Path to save FAISS index to.
209+
"""
210+
faiss = dependable_faiss_import()
211+
faiss.write_index(self.index, path)
212+
213+
def load_local(self, path: str) -> None:
214+
"""Load FAISS index from disk.
215+
216+
Args:
217+
path: Path to load FAISS index from.
218+
"""
219+
faiss = dependable_faiss_import()
220+
self.index = faiss.read_index(path)

tests/integration_tests/vectorstores/test_faiss.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test FAISS functionality."""
2+
import tempfile
23
from typing import List
34

45
import pytest
@@ -46,9 +47,15 @@ def test_faiss_with_metadatas() -> None:
4647
docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
4748
expected_docstore = InMemoryDocstore(
4849
{
49-
"0": Document(page_content="foo", metadata={"page": 0}),
50-
"1": Document(page_content="bar", metadata={"page": 1}),
51-
"2": Document(page_content="baz", metadata={"page": 2}),
50+
docsearch.index_to_docstore_id[0]: Document(
51+
page_content="foo", metadata={"page": 0}
52+
),
53+
docsearch.index_to_docstore_id[1]: Document(
54+
page_content="bar", metadata={"page": 1}
55+
),
56+
docsearch.index_to_docstore_id[2]: Document(
57+
page_content="baz", metadata={"page": 2}
58+
),
5259
}
5360
)
5461
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
@@ -82,3 +89,15 @@ def test_faiss_add_texts_not_supported() -> None:
8289
docsearch = FAISS(FakeEmbeddings().embed_query, None, Wikipedia(), {})
8390
with pytest.raises(ValueError):
8491
docsearch.add_texts(["foo"])
92+
93+
94+
def test_faiss_local_save_load() -> None:
95+
"""Test end to end serialization."""
96+
texts = ["foo", "bar", "baz"]
97+
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
98+
99+
with tempfile.NamedTemporaryFile() as temp_file:
100+
docsearch.save_local(temp_file.name)
101+
docsearch.index = None
102+
docsearch.load_local(temp_file.name)
103+
assert docsearch.index is not None

0 commit comments

Comments
 (0)