Skip to content

Commit ad84239

Browse files
authored
feat: mmr traversal starting in neighborhood (#634)
* feat: mmr traversal starting in neighborhood. This adds a `neighborhood` parameter to the mmr_traversal, and starts the search with the best candidates adjacent to the provided `neighborhood`.
1 parent f10a2cc commit ad84239

File tree

2 files changed

+91
-31
lines changed

2 files changed

+91
-31
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 75 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,6 @@ def __init__(
217217
""" # noqa: S608
218218
)
219219

220-
self._query_source_tags_by_id = session.prepare(
221-
f"""
222-
SELECT link_to_tags
223-
FROM {keyspace}.{node_table}
224-
WHERE content_id = ?
225-
""" # noqa: S608
226-
)
227-
228220
def table_name(self) -> str:
229221
"""Returns the fully qualified table name."""
230222
return f"{self._keyspace}.{self._node_table}"
@@ -364,6 +356,7 @@ def mmr_traversal_search(
364356
self,
365357
query: str,
366358
*,
359+
initial_roots: Sequence[str] = (),
367360
k: int = 4,
368361
depth: int = 2,
369362
fetch_k: int = 100,
@@ -384,8 +377,13 @@ def mmr_traversal_search(
384377
385378
Args:
386379
query: The query string to search for.
380+
initial_roots: Optional list of document IDs to use for initializing search.
381+
The top `adjacent_k` nodes adjacent to each initial root will be
382+
included in the set of initial candidates. To fetch only in the
383+
neighborhood of these nodes, set `fetch_k = 0`.
387384
k: Number of Documents to return. Defaults to 4.
388385
fetch_k: Number of initial Documents to fetch via similarity.
386+
Will be added to the nodes adjacent to `initial_roots`.
389387
Defaults to 100.
390388
adjacent_k: Number of adjacent Documents to fetch.
391389
Defaults to 10.
@@ -406,19 +404,12 @@ def mmr_traversal_search(
406404
score_threshold=score_threshold,
407405
)
408406

409-
# For each unvisited node, stores the outgoing tags.
407+
# For each unselected node, stores the outgoing tags.
410408
outgoing_tags: dict[str, set[tuple[str, str]]] = {}
411409

412410
# Fetch the initial candidates and add them to the helper and
413411
# outgoing_tags.
414412
columns = "content_id, text_embedding, link_to_tags"
415-
initial_candidates_query = self._get_search_cql(
416-
has_limit=True,
417-
columns=columns,
418-
metadata_keys=list(metadata_filter.keys()),
419-
has_embedding=True,
420-
)
421-
422413
adjacent_query = self._get_search_cql(
423414
has_limit=True,
424415
columns=columns,
@@ -427,7 +418,46 @@ def mmr_traversal_search(
427418
has_link_from_tags=True,
428419
)
429420

421+
visited_tags: set[tuple[str, str]] = set()
422+
423+
def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
424+
# Put the neighborhood into the outgoing tags, to avoid adding it
425+
# to the candidate set in the future.
426+
outgoing_tags.update({content_id: set() for content_id in neighborhood})
427+
428+
# Initialize the visited_tags with the set of outgoing from the
429+
# neighborhood. This prevents re-visiting them.
430+
visited_tags = self._get_outgoing_tags(neighborhood)
431+
432+
# Call `self._get_adjacent` to fetch the candidates.
433+
adjacents = self._get_adjacent(
434+
visited_tags,
435+
adjacent_query=adjacent_query,
436+
query_embedding=query_embedding,
437+
k_per_tag=adjacent_k,
438+
metadata_filter=metadata_filter,
439+
)
440+
441+
new_candidates = {}
442+
for adjacent in adjacents:
443+
if adjacent.target_content_id not in outgoing_tags:
444+
outgoing_tags[adjacent.target_content_id] = (
445+
adjacent.target_link_to_tags
446+
)
447+
448+
new_candidates[adjacent.target_content_id] = (
449+
adjacent.target_text_embedding
450+
)
451+
helper.add_candidates(new_candidates)
452+
430453
def fetch_initial_candidates() -> None:
454+
initial_candidates_query = self._get_search_cql(
455+
has_limit=True,
456+
columns=columns,
457+
metadata_keys=list(metadata_filter.keys()),
458+
has_embedding=True,
459+
)
460+
431461
params = self._get_search_params(
432462
limit=fetch_k,
433463
metadata=metadata_filter,
@@ -439,16 +469,20 @@ def fetch_initial_candidates() -> None:
439469
)
440470
candidates = {}
441471
for row in fetched:
442-
candidates[row.content_id] = row.text_embedding
443-
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
472+
if row.content_id not in outgoing_tags:
473+
candidates[row.content_id] = row.text_embedding
474+
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
444475
helper.add_candidates(candidates)
445476

446-
fetch_initial_candidates()
477+
if initial_roots:
478+
fetch_neighborhood(initial_roots)
479+
if fetch_k > 0:
480+
fetch_initial_candidates()
447481

448-
# Select the best item, K times.
482+
# Tracks the depth of each candidate.
449483
depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()}
450-
visited_tags: set[tuple[str, str]] = set()
451484

485+
# Select the best item, K times.
452486
for _ in range(k):
453487
selected_id = helper.pop_best()
454488

@@ -683,12 +717,15 @@ def _get_outgoing_tags(
683717

684718
def add_sources(rows: Iterable[Any]) -> None:
685719
for row in rows:
686-
tags.update(row.link_to_tags)
720+
if row.link_to_tags:
721+
tags.update(row.link_to_tags)
687722

688723
with self._concurrent_queries() as cq:
689724
for source_id in source_ids:
690725
cq.execute(
691-
self._query_source_tags_by_id, (source_id,), callback=add_sources
726+
self._query_ids_and_link_to_tags_by_id,
727+
(source_id,),
728+
callback=add_sources,
692729
)
693730

694731
return tags
@@ -699,7 +736,7 @@ def _get_adjacent(
699736
adjacent_query: PreparedStatement,
700737
query_embedding: list[float],
701738
k_per_tag: int | None = None,
702-
metadata_filter: dict[str, Any] = {}, # noqa: B006
739+
metadata_filter: dict[str, Any] | None = None,
703740
) -> Iterable[_Edge]:
704741
"""Return the target nodes with incoming links from any of the given tags.
705742
@@ -809,11 +846,15 @@ def _coerce_string(value: Any) -> str:
809846

810847
def _extract_where_clause_cql(
811848
self,
812-
metadata_keys: list[str] = [], # noqa: B006
849+
has_id: bool = False,
850+
metadata_keys: Sequence[str] = (),
813851
has_link_from_tags: bool = False,
814852
) -> str:
815853
wc_blocks: list[str] = []
816854

855+
if has_id:
856+
wc_blocks.append("content_id = ?")
857+
817858
if has_link_from_tags:
818859
wc_blocks.append("link_from_tags CONTAINS (?, ?)")
819860

@@ -855,12 +896,15 @@ def _get_search_cql(
855896
self,
856897
has_limit: bool = False,
857898
columns: str | None = CONTENT_COLUMNS,
858-
metadata_keys: list[str] = [], # noqa: B006
899+
metadata_keys: Sequence[str] = (),
900+
has_id: bool = False,
859901
has_embedding: bool = False,
860902
has_link_from_tags: bool = False,
861903
) -> PreparedStatement:
862904
where_clause = self._extract_where_clause_cql(
863-
metadata_keys=metadata_keys, has_link_from_tags=has_link_from_tags
905+
has_id=has_id,
906+
metadata_keys=metadata_keys,
907+
has_link_from_tags=has_link_from_tags,
864908
)
865909
limit_clause = " LIMIT ?" if has_limit else ""
866910
order_clause = " ORDER BY text_embedding ANN OF ?" if has_embedding else ""
@@ -885,12 +929,12 @@ def _get_search_cql(
885929
def _get_search_params(
886930
self,
887931
limit: int | None = None,
888-
metadata: dict[str, Any] = {}, # noqa: B006
932+
metadata: dict[str, Any] | None = None,
889933
embedding: list[float] | None = None,
890934
link_from_tags: tuple[str, str] | None = None,
891935
) -> tuple[PreparedStatement, tuple[Any, ...]]:
892936
where_params = self._extract_where_clause_params(
893-
metadata=metadata, link_from_tags=link_from_tags
937+
metadata=metadata or {}, link_from_tags=link_from_tags
894938
)
895939

896940
limit_params = [limit] if limit is not None else []
@@ -902,14 +946,14 @@ def _get_search_cql_and_params(
902946
self,
903947
limit: int | None = None,
904948
columns: str | None = CONTENT_COLUMNS,
905-
metadata: dict[str, Any] = {}, # noqa: B006
949+
metadata: dict[str, Any] | None = None,
906950
embedding: list[float] | None = None,
907951
link_from_tags: tuple[str, str] | None = None,
908952
) -> tuple[PreparedStatement, tuple[Any, ...]]:
909953
query = self._get_search_cql(
910954
has_limit=limit is not None,
911955
columns=columns,
912-
metadata_keys=list(metadata.keys()),
956+
metadata_keys=list(metadata.keys()) if metadata else (),
913957
has_embedding=embedding is not None,
914958
has_link_from_tags=link_from_tags is not None,
915959
)

libs/knowledge-store/tests/integration_tests/test_graph_store.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,22 @@ def test_mmr_traversal(
195195
results = gs.mmr_traversal_search("0.0", k=4, metadata_filter={"even": True})
196196
assert _result_ids(results) == ["v0", "v2"]
197197

198+
# with initial_roots=[v0], we should start traversal there. this means that
199+
# the initial candidates are `v2`,`v3`. `v1` is unreachable and not
200+
# included.
201+
results = gs.mmr_traversal_search("0.0", fetch_k=0, k=4, initial_roots=["v0"])
202+
assert _result_ids(results) == ["v2", "v3"]
203+
204+
# with initial_roots=[v1], we should start traversal there.
205+
# there are no adjacent nodes, so there are no results.
206+
results = gs.mmr_traversal_search("0.0", fetch_k=0, k=4, initial_roots=["v1"])
207+
assert _result_ids(results) == []
208+
209+
# with initial_roots=[v0] and `fetch_k > 0` we should be able to reach everything.
210+
# but we don't re-fetch `v0`.
211+
results = gs.mmr_traversal_search("0.0", fetch_k=2, k=4, initial_roots=["v0"])
212+
assert _result_ids(results) == ["v1", "v3", "v2"]
213+
198214

199215
def test_write_retrieve_keywords(
200216
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],

0 commit comments

Comments
 (0)