Skip to content

Commit 507d1c0

Browse files
authored
feat: Reduce sequential queries for MMR (#586)
* feat: Reduce sequential queries for MMR. This denormalizes the outgoing tags into the `links` table, so that traversal doesn't need to fetch the set of outgoing tags in an initial query. This means that traversing is one batch of concurrent queries, rather than two steps in sequence. This lowers average retrieval time in 6 experiments from 1.3364s to 1.0920s (the improvement would increase as more nodes are retrieved). Writing the additional copies of the denormalized tags adds a little to indexing time.
1 parent 4b940f0 commit 507d1c0

File tree

1 file changed

+103
-61
lines changed

1 file changed

+103
-61
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 103 additions & 61 deletions
Original file line number · Diff line number · Diff line change
@@ -107,6 +107,7 @@ def _row_to_node(row: Any) -> Node:
107107
class _Edge:
108108
target_content_id: str
109109
target_text_embedding: List[float]
110+
target_link_to_tags: Set[Tuple[str, str]]
110111

111112

112113
class GraphStore:
@@ -170,8 +171,8 @@ def __init__(
170171
self._insert_tag = session.prepare(
171172
f"""
172173
INSERT INTO {keyspace}.{targets_table} (
173-
target_content_id, kind, tag, target_text_embedding
174-
) VALUES (?, ?, ?, ?)
174+
target_content_id, kind, tag, target_text_embedding, target_link_to_tags
175+
) VALUES (?, ?, ?, ?, ?)
175176
""" # noqa: S608
176177
)
177178

@@ -215,7 +216,7 @@ def __init__(
215216

216217
self._query_ids_and_embedding_by_embedding = session.prepare(
217218
f"""
218-
SELECT content_id, text_embedding
219+
SELECT content_id, text_embedding, link_to_tags
219220
FROM {keyspace}.{node_table}
220221
ORDER BY text_embedding ANN OF ?
221222
LIMIT ?
@@ -235,7 +236,7 @@ def __init__(
235236

236237
self._query_targets_embeddings_by_kind_and_tag_and_embedding = session.prepare(
237238
f"""
238-
SELECT target_content_id, target_text_embedding, tag
239+
SELECT target_content_id, target_text_embedding, tag, target_link_to_tags
239240
FROM {keyspace}.{targets_table}
240241
WHERE kind = ? AND tag = ?
241242
ORDER BY target_text_embedding ANN of ?
@@ -279,6 +280,7 @@ def _apply_schema(self) -> None:
279280
-- text_embedding of target node. allows MMR to be applied without
280281
-- fetching nodes.
281282
target_text_embedding VECTOR<FLOAT, {embedding_dim}>,
283+
target_link_to_tags SET<TUPLE<TEXT, TEXT>>,
282284
283285
PRIMARY KEY ((kind, tag), target_content_id)
284286
)
@@ -354,7 +356,7 @@ def add_nodes(
354356
for kind, value in link_from_tags:
355357
cq.execute(
356358
self._insert_tag,
357-
parameters=(node_id, kind, value, text_embedding),
359+
parameters=(node_id, kind, value, text_embedding, link_to_tags),
358360
)
359361

360362
return node_ids
@@ -433,16 +435,28 @@ def mmr_traversal_search(
433435
score_threshold=score_threshold,
434436
)
435437

436-
# Fetch the initial candidates and add them to the helper.
437-
fetched = self._session.execute(
438-
self._query_ids_and_embedding_by_embedding,
439-
(query_embedding, fetch_k),
440-
)
441-
helper.add_candidates({row.content_id: row.text_embedding for row in fetched})
438+
# For each unvisited node, stores the outgoing tags.
439+
outgoing_tags: Dict[str, Set[Tuple[str, str]]] = {}
440+
441+
# Fetch the initial candidates and add them to the helper and
442+
# outgoing_tags.
443+
def fetch_initial_candidates() -> None:
444+
fetched = self._session.execute(
445+
self._query_ids_and_embedding_by_embedding,
446+
(query_embedding, fetch_k),
447+
)
448+
candidates = {}
449+
for row in fetched:
450+
candidates[row.content_id] = row.text_embedding
451+
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
452+
helper.add_candidates(candidates)
453+
454+
fetch_initial_candidates()
442455

443456
# Select the best item, K times.
444457
depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()}
445458
visited_tags: Set[Tuple[str, str]] = set()
459+
446460
for _ in range(k):
447461
selected_id = helper.pop_best()
448462

@@ -457,30 +471,46 @@ def mmr_traversal_search(
457471
# TODO: For a big performance win, we should track which tags we've
458472
# already incorporated. We don't need to issue adjacent queries for
459473
# those.
474+
475+
# Find the tags linked to from the selected ID.
476+
link_to_tags = outgoing_tags.pop(selected_id)
477+
478+
# Don't re-visit already visited tags.
479+
link_to_tags.difference_update(visited_tags)
480+
481+
# Find the nodes with incoming links from those tags.
460482
adjacents = self._get_adjacent(
461-
[selected_id],
462-
visited_tags=visited_tags,
483+
link_to_tags,
463484
query_embedding=query_embedding,
464485
k_per_tag=adjacent_k,
465486
)
466487

488+
# Record the link_to_tags as visited.
489+
visited_tags.update(link_to_tags)
490+
467491
new_candidates = {}
468492
for adjacent in adjacents:
469-
new_candidates[adjacent.target_content_id] = (
470-
adjacent.target_text_embedding
471-
)
472-
if next_depth < depths.get(adjacent.target_content_id, depth + 1):
473-
# If this is a new shortest depth, or there was no
474-
# previous depth, update the depths. This ensures that
475-
# when we discover a node we will have the shortest
476-
# depth available.
477-
#
478-
# NOTE: No effort is made to traverse from nodes that
479-
# were previously selected if they become reachable via
480-
# a shorter path via nodes selected later. This is
481-
# currently "intended", but may be worth experimenting
482-
# with.
483-
depths[adjacent.target_content_id] = next_depth
493+
if adjacent.target_content_id not in outgoing_tags:
494+
outgoing_tags[adjacent.target_content_id] = (
495+
adjacent.target_link_to_tags
496+
)
497+
new_candidates[adjacent.target_content_id] = (
498+
adjacent.target_text_embedding
499+
)
500+
if next_depth < depths.get(
501+
adjacent.target_content_id, depth + 1
502+
):
503+
# If this is a new shortest depth, or there was no
504+
# previous depth, update the depths. This ensures that
505+
# when we discover a node we will have the shortest
506+
# depth available.
507+
#
508+
# NOTE: No effort is made to traverse from nodes that
509+
# were previously selected if they become reachable via
510+
# a shorter path via nodes selected later. This is
511+
# currently "intended", but may be worth experimenting
512+
# with.
513+
depths[adjacent.target_content_id] = next_depth
484514
helper.add_candidates(new_candidates)
485515

486516
return self._nodes_with_ids(helper.selected_ids)
@@ -601,61 +631,73 @@ def similarity_search(
601631
for row in self._session.execute(self._query_by_embedding, (embedding, k)):
602632
yield _row_to_node(row)
603633

604-
def _get_adjacent(
634+
def _get_outgoing_tags(
605635
self,
606636
source_ids: Iterable[str],
607-
visited_tags: Set[Tuple[str, str]],
637+
) -> Set[Tuple[str, str]]:
638+
"""Return the set of outgoing tags for the given source ID(s).
639+
640+
Args:
641+
source_ids: The IDs of the source nodes to retrieve outgoing tags for.
642+
"""
643+
tags = set()
644+
645+
def add_sources(rows: Iterable[Any]) -> None:
646+
for row in rows:
647+
tags.update(row.link_to_tags)
648+
649+
with self._concurrent_queries() as cq:
650+
for source_id in source_ids:
651+
cq.execute(
652+
self._query_source_tags_by_id, (source_id,), callback=add_sources
653+
)
654+
655+
return tags
656+
657+
def _get_adjacent(
658+
self,
659+
tags: Set[Tuple[str, str]],
608660
query_embedding: List[float],
609661
k_per_tag: Optional[int] = None,
610662
) -> Iterable[_Edge]:
611-
"""Return the target nodes adjacent to any of the source nodes.
663+
"""Return the target nodes with incoming links from any of the given tags.
612664
613665
Args:
614-
source_ids: The source IDs to start from when retrieving adjacent nodes.
615-
visited_tags: Tags we've already visited.
666+
tags: The tags to look for links *from*.
616667
query_embedding: The query embedding. Used to rank target nodes.
617668
k_per_tag: The number of target nodes to fetch for each outgoing tag.
618669
619670
Returns:
620671
List of adjacent edges.
621672
"""
622-
targets: Dict[str, List[float]] = {}
623-
624-
def add_sources(rows: Iterable[Any]) -> None:
625-
for row in rows:
626-
for new_tag in row.link_to_tags or []:
627-
if new_tag not in visited_tags:
628-
visited_tags.add(new_tag)
629-
630-
cq.execute(
631-
self._query_targets_embeddings_by_kind_and_tag_and_embedding,
632-
parameters=(
633-
new_tag[0],
634-
new_tag[1],
635-
query_embedding,
636-
k_per_tag or 10,
637-
),
638-
callback=add_targets,
639-
)
673+
targets: Dict[str, _Edge] = {}
640674

641675
def add_targets(rows: Iterable[Any]) -> None:
642676
# TODO: Figure out how to use the "kind" on the edge.
643677
# This is tricky, since we currently issue one query for anything
644678
# adjacent via any kind, and we don't have enough information to
645679
# determine which kind(s) a given target was reached from.
646680
for row in rows:
647-
targets.setdefault(row.target_content_id, row.target_text_embedding)
681+
if row.target_content_id not in targets:
682+
targets[row.target_content_id] = _Edge(
683+
target_content_id=row.target_content_id,
684+
target_text_embedding=row.target_text_embedding,
685+
target_link_to_tags=set(row.target_link_to_tags or []),
686+
)
648687

649688
with self._concurrent_queries() as cq:
650-
# TODO: We could eliminate this query by storing the source tags of the
651-
# target node in the targets table.
652-
for source_id in source_ids:
689+
for kind, value in tags:
653690
cq.execute(
654-
self._query_source_tags_by_id, (source_id,), callback=add_sources
691+
self._query_targets_embeddings_by_kind_and_tag_and_embedding,
692+
parameters=(
693+
kind,
694+
value,
695+
query_embedding,
696+
k_per_tag or 10,
697+
),
698+
callback=add_targets,
655699
)
656700

657-
# TODO: Consider a combined limit based on the similarity and/or predicated MMR score? # noqa: E501
658-
return [
659-
_Edge(target_content_id=content_id, target_text_embedding=embedding)
660-
for (content_id, embedding) in targets.items()
661-
]
701+
# TODO: Consider a combined limit based on the similarity and/or
702+
# predicated MMR score?
703+
return targets.values()

0 commit comments

Comments (0)