Commit 0f40dc5

Extract API for coarse knowledge graph (#452)
* Extract API for coarse knowledge graph
* Add traversal search_type to search method
* Remove links from API
* Add async methods
* Changes from review
* Check iterator sizes in conversion to nodes
1 parent fa98828 commit 0f40dc5
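
The headline additions are the abstract KnowledgeStore API and a new "traversal" search_type on the standard search method. A minimal sketch of the intended call, assuming `store` is any concrete KnowledgeStore (the variable name and query are illustrative, not from this commit):

    # Hypothetical call; `store` is any concrete KnowledgeStore instance.
    docs = store.search(
        "my query",
        search_type="traversal",  # new in this commit
        k=4,       # initial vector-search results
        depth=2,   # then follow edges up to 2 hops from those results
    )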

File tree

10 files changed: +511 -150 lines changed
Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
-from .knowledge_store import KnowledgeStore
+from .cassandra import CassandraKnowledgeStore
+from .base import KnowledgeStore
 
-__all__ = ["KnowledgeStore"]
+__all__ = ["CassandraKnowledgeStore", "KnowledgeStore"]
Lines changed: 362 additions & 0 deletions
@@ -0,0 +1,362 @@
from __future__ import annotations

from abc import abstractmethod
from typing import (
    Any,
    AsyncIterable,
    ClassVar,
    Collection,
    Iterable,
    Iterator,
    List,
    Optional,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.load import Serializable
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from pydantic import Field


def _has_next(iterator: Iterator) -> bool:
    """Checks if the iterator has more elements.

    Warning: consumes an element from the iterator.
    """
    sentinel = object()
    return next(iterator, sentinel) is not sentinel

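# Illustrative behavior (comment added for this writeup, not in the commit):
# _has_next(iter([])) is False, while _has_next(iter([1, 2])) is True but
# consumes the 1. It is therefore only called after the main loops below,
# where any remaining element already signals a length mismatch.
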
class Node(Serializable):
    """Node in the KnowledgeStore graph."""

    id: Optional[str]
    """Unique ID for the node. Will be generated by the KnowledgeStore if not set."""
    metadata: dict = Field(default_factory=dict)
    """Metadata for the node. May contain information used to link this node
    with other nodes."""


class TextNode(Node):
    text: str
    """Text contained by the node."""

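# Example (illustrative only, not part of the commit): a TextNode carries the
# chunk text plus metadata that a concrete store may use to derive edges.
#
#   node = TextNode(
#       id="chunk-0",
#       text="Cassandra is a wide-column database.",
#       metadata={"source": "intro.md"},
#   )
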

def _texts_to_nodes(
    texts: Iterable[str],
    metadatas: Optional[Iterable[dict]],
    ids: Optional[Iterable[str]],
) -> Iterator[Node]:
    metadatas_it = iter(metadatas) if metadatas else None
    ids_it = iter(ids) if ids else None
    for text in texts:
        try:
            _metadata = next(metadatas_it) if metadatas_it else {}
        except StopIteration:
            raise ValueError("texts iterable longer than metadatas")
        try:
            _id = next(ids_it) if ids_it else None
        except StopIteration:
            raise ValueError("texts iterable longer than ids")
        yield TextNode(
            id=_id,
            metadata=_metadata,
            text=text,
        )
    if ids and _has_next(ids_it):
        raise ValueError("ids iterable longer than texts")
    if metadatas and _has_next(metadatas_it):
        raise ValueError("metadatas iterable longer than texts")

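# Illustrative only: mismatched lengths raise rather than silently truncate.
#
#   list(_texts_to_nodes(["a", "b"], None, ["1"]))
#   # -> ValueError: texts iterable longer than ids
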

def _documents_to_nodes(
    documents: Iterable[Document], ids: Optional[Iterable[str]]
) -> Iterator[Node]:
    ids_it = iter(ids) if ids else None
    for doc in documents:
        try:
            _id = next(ids_it) if ids_it else None
        except StopIteration:
            raise ValueError("documents iterable longer than ids")
        yield TextNode(
            id=_id,
            metadata=doc.metadata,
            text=doc.page_content,
        )
    if ids and _has_next(ids_it):
        raise ValueError("ids iterable longer than documents")


class KnowledgeStore(VectorStore):
    """A hybrid vector-and-graph knowledge store.

    Document chunks support vector-similarity search as well as edges linking
    chunks based on structural and semantic properties.
    """

    @abstractmethod
    def add_nodes(
        self,
        nodes: Iterable[Node],
        **kwargs: Any,
    ) -> List[str]:
        """Add nodes to the knowledge store.

        Args:
            nodes: the nodes to add.
        """

    async def aadd_nodes(
        self,
        nodes: Iterable[Node],
        **kwargs: Any,
    ) -> List[str]:
        """Add nodes to the knowledge store.

        Args:
            nodes: the nodes to add.
        """
        return await run_in_executor(None, self.add_nodes, nodes, **kwargs)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        *,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        nodes = _texts_to_nodes(texts, metadatas, ids)
        return self.add_nodes(nodes, **kwargs)

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        *,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        nodes = _texts_to_nodes(texts, metadatas, ids)
        return await self.aadd_nodes(nodes, **kwargs)

    def add_documents(
        self,
        documents: Iterable[Document],
        *,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        nodes = _documents_to_nodes(documents, ids)
        return self.add_nodes(nodes, **kwargs)

    async def aadd_documents(
        self,
        documents: Iterable[Document],
        *,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        nodes = _documents_to_nodes(documents, ids)
        return await self.aadd_nodes(nodes, **kwargs)

    @abstractmethod
    def traversing_retrieve(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 1,
        **kwargs: Any,
    ) -> Iterable[Document]:
        """Retrieve documents by traversing this knowledge store.

        First, `k` nodes are retrieved using a vector search for the `query`
        string. Then, additional nodes are discovered up to the given `depth`
        from those starting nodes.

        Args:
            query: The query string.
            k: The number of Documents to return from the initial vector
                search. Defaults to 4.
            depth: The maximum depth of edges to traverse. Defaults to 1.

        Returns:
            Retrieved documents.
        """

    async def atraversing_retrieve(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 1,
        **kwargs: Any,
    ) -> AsyncIterable[Document]:
        """Retrieve documents by traversing this knowledge store.

        First, `k` nodes are retrieved using a vector search for the `query`
        string. Then, additional nodes are discovered up to the given `depth`
        from those starting nodes.

        Args:
            query: The query string.
            k: The number of Documents to return from the initial vector
                search. Defaults to 4.
            depth: The maximum depth of edges to traverse. Defaults to 1.

        Returns:
            Retrieved documents.
        """
        for doc in await run_in_executor(
            None, self.traversing_retrieve, query, k=k, depth=depth, **kwargs
        ):
            yield doc

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        return list(self.traversing_retrieve(query, k=k, depth=0))

    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        return [doc async for doc in self.atraversing_retrieve(query, k=k, depth=0)]

    def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
        if search_type == "similarity":
            return self.similarity_search(query, **kwargs)
        elif search_type == "similarity_score_threshold":
            docs_and_similarities = self.similarity_search_with_relevance_scores(
                query, **kwargs
            )
            return [doc for doc, _ in docs_and_similarities]
        elif search_type == "mmr":
            return self.max_marginal_relevance_search(query, **kwargs)
        elif search_type == "traversal":
            return list(self.traversing_retrieve(query, **kwargs))
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity', 'similarity_score_threshold', "
                "'mmr' or 'traversal'."
            )

    async def asearch(
        self, query: str, search_type: str, **kwargs: Any
    ) -> List[Document]:
        if search_type == "similarity":
            return await self.asimilarity_search(query, **kwargs)
        elif search_type == "similarity_score_threshold":
            docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
                query, **kwargs
            )
            return [doc for doc, _ in docs_and_similarities]
        elif search_type == "mmr":
            return await self.amax_marginal_relevance_search(query, **kwargs)
        elif search_type == "traversal":
            return [doc async for doc in self.atraversing_retrieve(query, **kwargs)]
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity', 'similarity_score_threshold', "
                "'mmr' or 'traversal'."
            )

    def as_retriever(self, **kwargs: Any) -> "KnowledgeStoreRetriever":
        """Return KnowledgeStoreRetriever initialized from this KnowledgeStore.

        Args:
            search_type (Optional[str]): Defines the type of search that
                the Retriever should perform.
                Can be "traversal" (default), "similarity", "mmr", or
                "similarity_score_threshold".
            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                search function. Can include things like:
                    k: Number of documents to return (Default: 4)
                    depth: The maximum depth of edges to traverse (Default: 1)
                    score_threshold: Minimum relevance threshold
                        for similarity_score_threshold
                    fetch_k: Number of documents to pass to the MMR algorithm
                        (Default: 20)
                    lambda_mult: Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5)

        Returns:
            Retriever for this KnowledgeStore.

        Examples:

        .. code-block:: python

            # Retrieve documents traversing edges
            docsearch.as_retriever(
                search_type="traversal",
                search_kwargs={'k': 6, 'depth': 3}
            )

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # Above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})
        """
        return KnowledgeStoreRetriever(vectorstore=self, **kwargs)


class KnowledgeStoreRetriever(VectorStoreRetriever):
    """Retriever class for KnowledgeStore."""

    vectorstore: KnowledgeStore
    """KnowledgeStore to use for retrieval."""
    search_type: str = "traversal"
    """Type of search to perform. Defaults to "traversal"."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "traversal",
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "traversal":
            return list(
                self.vectorstore.traversing_retrieve(query, **self.search_kwargs)
            )
        else:
            return super()._get_relevant_documents(query, run_manager=run_manager)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "traversal":
            return [
                doc
                async for doc in self.vectorstore.atraversing_retrieve(
                    query, **self.search_kwargs
                )
            ]
        else:
            return await super()._aget_relevant_documents(
                query, run_manager=run_manager
            )
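
Taken together, the retriever defaults to graph traversal and delegates the other search types to the standard VectorStoreRetriever machinery. A minimal end-to-end sketch, assuming a concrete implementation such as the CassandraKnowledgeStore added in this commit (construction arguments omitted, since they are implementation-specific):

    # Hypothetical usage; `store` construction is implementation-specific.
    retriever = store.as_retriever(
        search_type="traversal",            # the new default
        search_kwargs={"k": 4, "depth": 2},
    )
    docs = retriever.invoke("What is a knowledge graph?")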
