Skip to content

Commit cd858b5

Browse files
authored
ref: Only prepare once during MMR traversal (#613)
* ref: Only prepare once during MMR traversal
1 parent 9c47dc5 commit cd858b5

File tree

1 file changed

+18 -20 lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 18 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -414,11 +414,20 @@ def mmr_traversal_search(
414414

415415
# Fetch the initial candidates and add them to the helper and
416416
# outgoing_tags.
417+
columns = "content_id, text_embedding, link_to_tags"
417418
initial_candidates_query = self._get_search_cql(
418419
has_limit=True,
419-
columns="content_id, text_embedding, link_to_tags",
420+
columns=columns,
421+
metadata_keys=list(metadata_filter.keys()),
422+
has_embedding=True,
423+
)
424+
425+
adjacent_query = self._get_search_cql(
426+
has_limit=True,
427+
columns=columns,
420428
metadata_keys=list(metadata_filter.keys()),
421429
has_embedding=True,
430+
has_link_from_tags=True,
422431
)
423432

424433
def fetch_initial_candidates() -> None:
@@ -467,6 +476,7 @@ def fetch_initial_candidates() -> None:
467476
# Find the nodes with incoming links from those tags.
468477
adjacents = self._get_adjacent(
469478
link_to_tags,
479+
adjacent_query=adjacent_query,
470480
query_embedding=query_embedding,
471481
k_per_tag=adjacent_k,
472482
metadata_filter=metadata_filter,
@@ -689,6 +699,7 @@ def add_sources(rows: Iterable[Any]) -> None:
689699
def _get_adjacent(
690700
self,
691701
tags: Set[Tuple[str, str]],
702+
adjacent_query: PreparedStatement,
692703
query_embedding: List[float],
693704
k_per_tag: Optional[int] = None,
694705
metadata_filter: Dict[str, Any] = {}, # noqa: B006
@@ -697,6 +708,7 @@ def _get_adjacent(
697708
698709
Args:
699710
tags: The tags to look for links *from*.
711+
adjacent_query: Prepared query for adjacent nodes.
700712
query_embedding: The query embedding. Used to rank target nodes.
701713
k_per_tag: The number of target nodes to fetch for each outgoing tag.
702714
metadata_filter: Optional metadata to filter the results.
@@ -706,31 +718,17 @@ def _get_adjacent(
706718
"""
707719
targets: Dict[str, _Edge] = {}
708720

709-
columns = """
710-
content_id AS target_content_id,
711-
text_embedding AS target_text_embedding,
712-
link_to_tags AS target_link_to_tags
713-
"""
714-
715-
adjacent_query = self._get_search_cql(
716-
has_limit=True,
717-
columns=columns,
718-
metadata_keys=list(metadata_filter.keys()),
719-
has_embedding=True,
720-
has_link_from_tags=True,
721-
)
722-
723721
def add_targets(rows: Iterable[Any]) -> None:
724722
# TODO: Figure out how to use the "kind" on the edge.
725723
# This is tricky, since we currently issue one query for anything
726724
# adjacent via any kind, and we don't have enough information to
727725
# determine which kind(s) a given target was reached from.
728726
for row in rows:
729-
if row.target_content_id not in targets:
730-
targets[row.target_content_id] = _Edge(
731-
target_content_id=row.target_content_id,
732-
target_text_embedding=row.target_text_embedding,
733-
target_link_to_tags=set(row.target_link_to_tags or []),
727+
if row.content_id not in targets:
728+
targets[row.content_id] = _Edge(
729+
target_content_id=row.content_id,
730+
target_text_embedding=row.text_embedding,
731+
target_link_to_tags=set(row.link_to_tags or []),
734732
)
735733

736734
with self._concurrent_queries() as cq:

0 commit comments

Comments
 (0)