@@ -414,11 +414,20 @@ def mmr_traversal_search(
414
414
415
415
# Fetch the initial candidates and add them to the helper and
416
416
# outgoing_tags.
417
+ columns = "content_id, text_embedding, link_to_tags"
417
418
initial_candidates_query = self ._get_search_cql (
418
419
has_limit = True ,
419
- columns = "content_id, text_embedding, link_to_tags" ,
420
+ columns = columns ,
421
+ metadata_keys = list (metadata_filter .keys ()),
422
+ has_embedding = True ,
423
+ )
424
+
425
+ adjacent_query = self ._get_search_cql (
426
+ has_limit = True ,
427
+ columns = columns ,
420
428
metadata_keys = list (metadata_filter .keys ()),
421
429
has_embedding = True ,
430
+ has_link_from_tags = True ,
422
431
)
423
432
424
433
def fetch_initial_candidates () -> None :
@@ -467,6 +476,7 @@ def fetch_initial_candidates() -> None:
467
476
# Find the nodes with incoming links from those tags.
468
477
adjacents = self ._get_adjacent (
469
478
link_to_tags ,
479
+ adjacent_query = adjacent_query ,
470
480
query_embedding = query_embedding ,
471
481
k_per_tag = adjacent_k ,
472
482
metadata_filter = metadata_filter ,
@@ -689,6 +699,7 @@ def add_sources(rows: Iterable[Any]) -> None:
689
699
def _get_adjacent (
690
700
self ,
691
701
tags : Set [Tuple [str , str ]],
702
+ adjacent_query : PreparedStatement ,
692
703
query_embedding : List [float ],
693
704
k_per_tag : Optional [int ] = None ,
694
705
metadata_filter : Dict [str , Any ] = {}, # noqa: B006
@@ -697,6 +708,7 @@ def _get_adjacent(
697
708
698
709
Args:
699
710
tags: The tags to look for links *from*.
711
+ adjacent_query: Prepared query for adjacent nodes.
700
712
query_embedding: The query embedding. Used to rank target nodes.
701
713
k_per_tag: The number of target nodes to fetch for each outgoing tag.
702
714
metadata_filter: Optional metadata to filter the results.
@@ -706,31 +718,17 @@ def _get_adjacent(
706
718
"""
707
719
targets : Dict [str , _Edge ] = {}
708
720
709
- columns = """
710
- content_id AS target_content_id,
711
- text_embedding AS target_text_embedding,
712
- link_to_tags AS target_link_to_tags
713
- """
714
-
715
- adjacent_query = self ._get_search_cql (
716
- has_limit = True ,
717
- columns = columns ,
718
- metadata_keys = list (metadata_filter .keys ()),
719
- has_embedding = True ,
720
- has_link_from_tags = True ,
721
- )
722
-
723
721
def add_targets (rows : Iterable [Any ]) -> None :
724
722
# TODO: Figure out how to use the "kind" on the edge.
725
723
# This is tricky, since we currently issue one query for anything
726
724
# adjacent via any kind, and we don't have enough information to
727
725
# determine which kind(s) a given target was reached from.
728
726
for row in rows :
729
- if row .target_content_id not in targets :
730
- targets [row .target_content_id ] = _Edge (
731
- target_content_id = row .target_content_id ,
732
- target_text_embedding = row .target_text_embedding ,
733
- target_link_to_tags = set (row .target_link_to_tags or []),
727
+ if row .content_id not in targets :
728
+ targets [row .content_id ] = _Edge (
729
+ target_content_id = row .content_id ,
730
+ target_text_embedding = row .text_embedding ,
731
+ target_link_to_tags = set (row .link_to_tags or []),
734
732
)
735
733
736
734
with self ._concurrent_queries () as cq :
0 commit comments