@@ -217,14 +217,6 @@ def __init__(
217
217
""" # noqa: S608
218
218
)
219
219
220
- self ._query_source_tags_by_id = session .prepare (
221
- f"""
222
- SELECT link_to_tags
223
- FROM { keyspace } .{ node_table }
224
- WHERE content_id = ?
225
- """ # noqa: S608
226
- )
227
-
228
220
def table_name (self ) -> str :
229
221
"""Return the fully qualified table name, built from the keyspace and the node table joined by a dot."""
230
222
return f"{ self ._keyspace } .{ self ._node_table } "
@@ -364,6 +356,7 @@ def mmr_traversal_search(
364
356
self ,
365
357
query : str ,
366
358
* ,
359
+ initial_roots : Sequence [str ] = (),
367
360
k : int = 4 ,
368
361
depth : int = 2 ,
369
362
fetch_k : int = 100 ,
@@ -384,8 +377,13 @@ def mmr_traversal_search(
384
377
385
378
Args:
386
379
query: The query string to search for.
380
+ initial_roots: Optional list of document IDs to use for initializing search.
381
+ The top `adjacent_k` nodes adjacent to each initial root will be
382
+ included in the set of initial candidates. To fetch only in the
383
+ neighborhood of these nodes, set `fetch_k = 0`.
387
384
k: Number of Documents to return. Defaults to 4.
388
385
fetch_k: Number of initial Documents to fetch via similarity.
386
+ Will be added to the nodes adjacent to `initial_roots`.
389
387
Defaults to 100.
390
388
adjacent_k: Number of adjacent Documents to fetch.
391
389
Defaults to 10.
@@ -406,19 +404,12 @@ def mmr_traversal_search(
406
404
score_threshold = score_threshold ,
407
405
)
408
406
409
- # For each unvisited node, stores the outgoing tags.
407
+ # For each unselected node, stores the outgoing tags.
410
408
outgoing_tags : dict [str , set [tuple [str , str ]]] = {}
411
409
412
410
# Fetch the initial candidates and add them to the helper and
413
411
# outgoing_tags.
414
412
columns = "content_id, text_embedding, link_to_tags"
415
- initial_candidates_query = self ._get_search_cql (
416
- has_limit = True ,
417
- columns = columns ,
418
- metadata_keys = list (metadata_filter .keys ()),
419
- has_embedding = True ,
420
- )
421
-
422
413
adjacent_query = self ._get_search_cql (
423
414
has_limit = True ,
424
415
columns = columns ,
@@ -427,7 +418,46 @@ def mmr_traversal_search(
427
418
has_link_from_tags = True ,
428
419
)
429
420
421
+ visited_tags : set [tuple [str , str ]] = set ()
422
+
423
def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
    """Seed the MMR helper with the nodes adjacent to ``neighborhood``.

    The neighborhood nodes themselves are registered in ``outgoing_tags``
    with an empty tag set so they are never added as candidates, and the
    tags leaving them are recorded in the enclosing ``visited_tags`` set.

    Args:
        neighborhood: Content IDs of the root nodes whose neighbors should
            become initial candidates.
    """
    # Put the neighborhood into the outgoing tags, to avoid adding it
    # to the candidate set in the future.
    outgoing_tags.update({content_id: set() for content_id in neighborhood})

    # Record the tags leaving the neighborhood in the enclosing
    # `visited_tags` set so they are not followed again later. The set must
    # be mutated in place: assigning to the name here would only bind a new
    # local variable and leave the outer set empty.
    neighborhood_tags = self._get_outgoing_tags(neighborhood)
    visited_tags.update(neighborhood_tags)

    # Call `self._get_adjacent` to fetch the candidates reachable through
    # the neighborhood's outgoing tags.
    adjacents = self._get_adjacent(
        neighborhood_tags,
        adjacent_query=adjacent_query,
        query_embedding=query_embedding,
        k_per_tag=adjacent_k,
        metadata_filter=metadata_filter,
    )

    new_candidates = {}
    for adjacent in adjacents:
        target_id = adjacent.target_content_id
        # Skip nodes that are already known (including the roots).
        if target_id not in outgoing_tags:
            outgoing_tags[target_id] = adjacent.target_link_to_tags
            new_candidates[target_id] = adjacent.target_text_embedding
    helper.add_candidates(new_candidates)
452
+
430
453
def fetch_initial_candidates () -> None :
454
+ initial_candidates_query = self ._get_search_cql (
455
+ has_limit = True ,
456
+ columns = columns ,
457
+ metadata_keys = list (metadata_filter .keys ()),
458
+ has_embedding = True ,
459
+ )
460
+
431
461
params = self ._get_search_params (
432
462
limit = fetch_k ,
433
463
metadata = metadata_filter ,
@@ -439,16 +469,20 @@ def fetch_initial_candidates() -> None:
439
469
)
440
470
candidates = {}
441
471
for row in fetched :
442
- candidates [row .content_id ] = row .text_embedding
443
- outgoing_tags [row .content_id ] = set (row .link_to_tags or [])
472
+ if row .content_id not in outgoing_tags :
473
+ candidates [row .content_id ] = row .text_embedding
474
+ outgoing_tags [row .content_id ] = set (row .link_to_tags or [])
444
475
helper .add_candidates (candidates )
445
476
446
- fetch_initial_candidates ()
477
+ if initial_roots :
478
+ fetch_neighborhood (initial_roots )
479
+ if fetch_k > 0 :
480
+ fetch_initial_candidates ()
447
481
448
- # Select the best item, K times .
482
+ # Tracks the depth of each candidate .
449
483
depths = {candidate_id : 0 for candidate_id in helper .candidate_ids ()}
450
- visited_tags : set [tuple [str , str ]] = set ()
451
484
485
+ # Select the best item, K times.
452
486
for _ in range (k ):
453
487
selected_id = helper .pop_best ()
454
488
@@ -683,12 +717,15 @@ def _get_outgoing_tags(
683
717
684
718
def add_sources(rows: Iterable[Any]) -> None:
    """Merge the ``link_to_tags`` of each fetched row into ``tags``.

    Rows whose ``link_to_tags`` is empty or ``None`` contribute nothing.
    """
    for row in rows:
        link_tags = row.link_to_tags
        if not link_tags:
            continue
        tags.update(link_tags)
687
722
688
723
with self ._concurrent_queries () as cq :
689
724
for source_id in source_ids :
690
725
cq .execute (
691
- self ._query_source_tags_by_id , (source_id ,), callback = add_sources
726
+ self ._query_ids_and_link_to_tags_by_id ,
727
+ (source_id ,),
728
+ callback = add_sources ,
692
729
)
693
730
694
731
return tags
@@ -699,7 +736,7 @@ def _get_adjacent(
699
736
adjacent_query : PreparedStatement ,
700
737
query_embedding : list [float ],
701
738
k_per_tag : int | None = None ,
702
- metadata_filter : dict [str , Any ] = {}, # noqa: B006
739
+ metadata_filter : dict [str , Any ] | None = None ,
703
740
) -> Iterable [_Edge ]:
704
741
"""Return the target nodes with incoming links from any of the given tags.
705
742
@@ -809,11 +846,15 @@ def _coerce_string(value: Any) -> str:
809
846
810
847
def _extract_where_clause_cql (
811
848
self ,
812
- metadata_keys : list [str ] = [], # noqa: B006
849
+ has_id : bool = False ,
850
+ metadata_keys : Sequence [str ] = (),
813
851
has_link_from_tags : bool = False ,
814
852
) -> str :
815
853
wc_blocks : list [str ] = []
816
854
855
+ if has_id :
856
+ wc_blocks .append ("content_id == ?" )
857
+
817
858
if has_link_from_tags :
818
859
wc_blocks .append ("link_from_tags CONTAINS (?, ?)" )
819
860
@@ -855,12 +896,15 @@ def _get_search_cql(
855
896
self ,
856
897
has_limit : bool = False ,
857
898
columns : str | None = CONTENT_COLUMNS ,
858
- metadata_keys : list [str ] = [], # noqa: B006
899
+ metadata_keys : Sequence [str ] = (),
900
+ has_id : bool = False ,
859
901
has_embedding : bool = False ,
860
902
has_link_from_tags : bool = False ,
861
903
) -> PreparedStatement :
862
904
where_clause = self ._extract_where_clause_cql (
863
- metadata_keys = metadata_keys , has_link_from_tags = has_link_from_tags
905
+ has_id = has_id ,
906
+ metadata_keys = metadata_keys ,
907
+ has_link_from_tags = has_link_from_tags ,
864
908
)
865
909
limit_clause = " LIMIT ?" if has_limit else ""
866
910
order_clause = " ORDER BY text_embedding ANN OF ?" if has_embedding else ""
@@ -885,12 +929,12 @@ def _get_search_cql(
885
929
def _get_search_params (
886
930
self ,
887
931
limit : int | None = None ,
888
- metadata : dict [str , Any ] = {}, # noqa: B006
932
+ metadata : dict [str , Any ] | None = None ,
889
933
embedding : list [float ] | None = None ,
890
934
link_from_tags : tuple [str , str ] | None = None ,
891
935
) -> tuple [PreparedStatement , tuple [Any , ...]]:
892
936
where_params = self ._extract_where_clause_params (
893
- metadata = metadata , link_from_tags = link_from_tags
937
+ metadata = metadata or {} , link_from_tags = link_from_tags
894
938
)
895
939
896
940
limit_params = [limit ] if limit is not None else []
@@ -902,14 +946,14 @@ def _get_search_cql_and_params(
902
946
self ,
903
947
limit : int | None = None ,
904
948
columns : str | None = CONTENT_COLUMNS ,
905
- metadata : dict [str , Any ] = {}, # noqa: B006
949
+ metadata : dict [str , Any ] | None = None ,
906
950
embedding : list [float ] | None = None ,
907
951
link_from_tags : tuple [str , str ] | None = None ,
908
952
) -> tuple [PreparedStatement , tuple [Any , ...]]:
909
953
query = self ._get_search_cql (
910
954
has_limit = limit is not None ,
911
955
columns = columns ,
912
- metadata_keys = list (metadata .keys ()),
956
+ metadata_keys = list (metadata .keys ()) if metadata else () ,
913
957
has_embedding = embedding is not None ,
914
958
has_link_from_tags = link_from_tags is not None ,
915
959
)
0 commit comments