Skip to content

Commit 2cfedc4

Browse files
authored
[knowledge-store] feat: Dynamic, MMR-based traversal (#460)
This required denormalizing the `text_embedding` of target nodes into the edge, making it easier to guide traversal based on distance to a query. Also fixed a bug in keyword linking which didn't create links to old nodes with the given keyword.
1 parent d5b2178 commit 2cfedc4

File tree

11 files changed

+651
-181
lines changed

11 files changed

+651
-181
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .cassandra import CassandraKnowledgeStore
21
from .base import KnowledgeStore
2+
from .cassandra import CassandraKnowledgeStore
33

44
__all__ = ["CassandraKnowledgeStore", "KnowledgeStore"]

libs/knowledge-store/ragstack_knowledge_store/_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import sys
2+
13
try:
24
# Try importing the function from itertools (Python 3.12+)
35
from itertools import batched
@@ -16,3 +18,39 @@ def batched(iterable: Iterable[T], n: int) -> Iterator[Iterator[T]]:
1618
it = iter(iterable)
1719
while batch := tuple(islice(it, n)):
1820
yield batch
21+
22+
# TODO: Remove this polyfill once the minimum required Python version is >= 3.10.

if sys.version_info >= (3, 10):

    def strict_zip(*iterables):
        """Zip ``iterables`` together, requiring them all to have the same length.

        Thin wrapper that delegates to the built-in ``zip(..., strict=True)``.
        """
        return zip(*iterables, strict=True)

else:

    def strict_zip(*iterables):
        """Zip ``iterables`` together, requiring them all to have the same length.

        Generator fallback used when ``zip(..., strict=True)`` is unavailable.
        Yields one tuple per position. Once any input is exhausted, raises
        ``ValueError`` if the inputs did not all end at the same position.

        Raises:
            ValueError: If one argument is shorter or longer than the
                arguments that precede it.
        """
        # Custom implementation for Python versions older than 3.10
        if not iterables:
            return

        iterators = tuple(iter(iterable) for iterable in iterables)
        try:
            while True:
                # `items` is filled one iterator at a time on purpose: when
                # StopIteration fires, its length tells us how many of the
                # earlier iterators still produced a value in this round.
                items = []
                for iterator in iterators:
                    items.append(next(iterator))
                yield tuple(items)
        except StopIteration:
            pass

        # A non-empty `items` means iterator number len(items) + 1 ran out
        # after the ones before it had already yielded a value.
        if items:
            i = len(items)
            plural = " " if i == 1 else "s 1-"
            msg = f"strict_zip() argument {i+1} is shorter than argument{plural}{i}"
            raise ValueError(msg)

        # Otherwise the first iterator ran out first; any later iterator that
        # still has a value is too long.
        sentinel = object()
        for i, iterator in enumerate(iterators[1:], 1):
            if next(iterator, sentinel) is not sentinel:
                plural = " " if i == 1 else "s 1-"
                msg = f"strict_zip() argument {i+1} is longer than argument{plural}{i}"
                raise ValueError(msg)

libs/knowledge-store/ragstack_knowledge_store/base.py

Lines changed: 110 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class Node(Serializable):
3636
id: Optional[str]
3737
"""Unique ID for the node. Shall be generated by the KnowledgeStore if not set"""
3838
metadata: dict = Field(default_factory=dict)
39-
"""Metadata for the node. May contain information used to link this node
39+
"""Metadata for the node. May contain information used to link this node
4040
with other nodes."""
4141

4242

@@ -164,7 +164,7 @@ async def aadd_documents(
164164
return await self.aadd_nodes(nodes, **kwargs)
165165

166166
@abstractmethod
167-
def traversing_retrieve(
167+
def traversal_search(
168168
self,
169169
query: str,
170170
*,
@@ -187,7 +187,7 @@ def traversing_retrieve(
187187
Retrieved documents.
188188
"""
189189

190-
async def atraversing_retrieve(
190+
async def atraversal_search(
191191
self,
192192
query: str,
193193
*,
@@ -210,42 +210,121 @@ async def atraversing_retrieve(
210210
Retrieved documents.
211211
"""
212212
for doc in await run_in_executor(
213-
None, self.traversing_retrieve, query, k=k, depth=depth, **kwargs
213+
None, self.traversal_search, query, k=k, depth=depth, **kwargs
214214
):
215215
yield doc
216216

217-
def similarity_search(
218-
self, query: str, k: int = 4, **kwargs: Any
219-
) -> List[Document]:
220-
return list(self.traversing_retrieve(query, k=k, depth=0))
217+
@abstractmethod
def mmr_traversal_search(
    self,
    query: str,
    *,
    k: int = 4,
    depth: int = 2,
    fetch_k: int = 100,
    lambda_mult: float = 0.5,
    score_threshold: float = 0.0,
    **kwargs: Any,
) -> Iterable[Document]:
    """Retrieve documents from this knowledge store using MMR-traversal.

    This strategy first retrieves the top `fetch_k` results by similarity to
    the question. It then selects the top `k` results based on
    maximum-marginal relevance using the given `lambda_mult`.

    At each step, it considers the (remaining) documents from `fetch_k` as
    well as any documents connected by edges to a selected document
    retrieved based on similarity (a "root").

    Args:
        query: The query string to search for.
        k: Number of Documents to return. Defaults to 4.
        depth: Maximum depth of a node (number of edges) from a node
            retrieved via similarity. Defaults to 2.
        fetch_k: Number of Documents to fetch via similarity.
            Defaults to 100.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding to maximum
            diversity and 1 to minimum diversity. Defaults to 0.5.
        score_threshold: Only documents with a score greater than or equal
            to this threshold will be chosen. Defaults to 0.0 so all scores
            are taken.
        **kwargs: Additional implementation-specific keyword arguments.

    Returns:
        The retrieved documents.
    """
253+
254+
async def ammr_traversal_search(
    self,
    query: str,
    *,
    k: int = 4,
    depth: int = 2,
    fetch_k: int = 100,
    lambda_mult: float = 0.5,
    score_threshold: float = 0.0,
    **kwargs: Any,
) -> AsyncIterable[Document]:
    """Retrieve documents from this knowledge store using MMR-traversal.

    Async counterpart of `mmr_traversal_search`: the synchronous method is
    invoked through `run_in_executor` and its results are re-yielded.

    This strategy first retrieves the top `fetch_k` results by similarity to
    the question. It then selects the top `k` results based on
    maximum-marginal relevance using the given `lambda_mult`.

    At each step, it considers the (remaining) documents from `fetch_k` as
    well as any documents connected by edges to a selected document
    retrieved based on similarity (a "root").

    Args:
        query: The query string to search for.
        k: Number of Documents to return. Defaults to 4.
        depth: Maximum depth of a node (number of edges) from a node
            retrieved via similarity. Defaults to 2.
        fetch_k: Number of Documents to fetch via similarity.
            Defaults to 100.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding to maximum
            diversity and 1 to minimum diversity. Defaults to 0.5.
        score_threshold: Only documents with a score greater than or equal
            to this threshold will be chosen. Defaults to 0.0 so all scores
            are taken.
        **kwargs: Additional keyword arguments forwarded unchanged to
            `mmr_traversal_search`.

    Yields:
        The retrieved documents.
    """
    # Delegate to the MMR variant. The plain `traversal_search` is a
    # different algorithm and is not declared to take `fetch_k`,
    # `lambda_mult` or `score_threshold`.
    for doc in await run_in_executor(
        None,
        self.mmr_traversal_search,
        query,
        k=k,
        fetch_k=fetch_k,
        depth=depth,
        lambda_mult=lambda_mult,
        score_threshold=score_threshold,
        **kwargs,
    ):
        yield doc
301+
302+
def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
    """Run a depth-0 traversal search for ``query`` and return the results as a list.

    ``k`` is forwarded to ``traversal_search``; extra keyword arguments are
    accepted but not forwarded.
    """
    matches = self.traversal_search(query, k=k, depth=0)
    return [*matches]
304+
305+
async def asimilarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
    """Collect the results of a depth-0 async traversal search for ``query``.

    ``k`` is forwarded to ``atraversal_search``; extra keyword arguments are
    accepted but not forwarded.
    """
    collected = []
    async for doc in self.atraversal_search(query, k=k, depth=0):
        collected.append(doc)
    return collected
226307

227308
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
    """Dispatch ``query`` to the search method selected by ``search_type``.

    Args:
        query: The query string to search for.
        search_type: One of ``"similarity"``, ``"similarity_score_threshold"``,
            ``"mmr"``, ``"traversal"`` or ``"mmr_traversal"``.
        **kwargs: Forwarded unchanged to the selected search method.

    Returns:
        The matching documents as a list.

    Raises:
        ValueError: If ``search_type`` is not one of the supported values.
    """
    if search_type == "similarity":
        return self.similarity_search(query, **kwargs)
    elif search_type == "similarity_score_threshold":
        docs_and_similarities = self.similarity_search_with_relevance_scores(query, **kwargs)
        # Callers of `search` only get the documents; scores are dropped.
        return [doc for doc, _ in docs_and_similarities]
    elif search_type == "mmr":
        return self.max_marginal_relevance_search(query, **kwargs)
    elif search_type == "traversal":
        return list(self.traversal_search(query, **kwargs))
    elif search_type == "mmr_traversal":
        return list(self.mmr_traversal_search(query, **kwargs))
    else:
        # Keep this list in sync with the branches above.
        raise ValueError(
            f"search_type of {search_type} not allowed. Expected "
            "search_type to be 'similarity', 'similarity_score_threshold', "
            "'mmr', 'traversal' or 'mmr_traversal'."
        )
245326

246-
async def asearch(
247-
self, query: str, search_type: str, **kwargs: Any
248-
) -> List[Document]:
327+
async def asearch(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
249328
if search_type == "similarity":
250329
return await self.asimilarity_search(query, **kwargs)
251330
elif search_type == "similarity_score_threshold":
@@ -256,7 +335,7 @@ async def asearch(
256335
elif search_type == "mmr":
257336
return await self.amax_marginal_relevance_search(query, **kwargs)
258337
elif search_type == "traversal":
259-
return [doc async for doc in self.atraversing_retrieve(query, **kwargs)]
338+
return [doc async for doc in self.atraversal_search(query, **kwargs)]
260339
else:
261340
raise ValueError(
262341
f"search_type of {search_type} not allowed. Expected "
@@ -334,15 +413,16 @@ class KnowledgeStoreRetriever(VectorStoreRetriever):
334413
"similarity_score_threshold",
335414
"mmr",
336415
"traversal",
416+
"mmr_traversal",
337417
)
338418

339419
def _get_relevant_documents(
    self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
    """Return the documents relevant to ``query`` for the configured search type.

    The ``"traversal"`` and ``"mmr_traversal"`` search types are handled
    here by calling the matching method on ``self.vectorstore`` with
    ``self.search_kwargs``; every other search type is delegated to the
    parent retriever.

    Args:
        query: The query string to search for.
        run_manager: Callback manager passed through to the parent
            implementation for the search types it handles.

    Returns:
        The retrieved documents as a list.
    """
    if self.search_type == "traversal":
        return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
    elif self.search_type == "mmr_traversal":
        # Must call the MMR variant (as the async sibling does); calling
        # `traversal_search` here silently ran the wrong algorithm.
        return list(self.vectorstore.mmr_traversal_search(query, **self.search_kwargs))
    else:
        return super()._get_relevant_documents(query, run_manager=run_manager)
348428

@@ -352,11 +432,12 @@ async def _aget_relevant_documents(
352432
if self.search_type == "traversal":
353433
return [
354434
doc
355-
async for doc in self.vectorstore.atraversing_retrieve(
356-
query, **self.search_kwargs
357-
)
435+
async for doc in self.vectorstore.atraversal_search(query, **self.search_kwargs)
436+
]
437+
elif self.search_type == "mmr_traversal":
438+
return [
439+
doc
440+
async for doc in self.vectorstore.ammr_traversal_search(query, **self.search_kwargs)
358441
]
359442
else:
360-
return await super()._aget_relevant_documents(
361-
query, run_manager=run_manager
362-
)
443+
return await super()._aget_relevant_documents(query, run_manager=run_manager)

0 commit comments

Comments
 (0)