@@ -107,6 +107,7 @@ def _row_to_node(row: Any) -> Node:
107
107
class _Edge :
108
108
target_content_id : str
109
109
target_text_embedding : List [float ]
110
+ target_link_to_tags : Set [Tuple [str , str ]]
110
111
111
112
112
113
class GraphStore :
@@ -170,8 +171,8 @@ def __init__(
170
171
self ._insert_tag = session .prepare (
171
172
f"""
172
173
INSERT INTO { keyspace } .{ targets_table } (
173
- target_content_id, kind, tag, target_text_embedding
174
- ) VALUES (?, ?, ?, ?)
174
+ target_content_id, kind, tag, target_text_embedding, target_link_to_tags
175
+ ) VALUES (?, ?, ?, ?, ? )
175
176
""" # noqa: S608
176
177
)
177
178
@@ -215,7 +216,7 @@ def __init__(
215
216
216
217
self ._query_ids_and_embedding_by_embedding = session .prepare (
217
218
f"""
218
- SELECT content_id, text_embedding
219
+ SELECT content_id, text_embedding, link_to_tags
219
220
FROM { keyspace } .{ node_table }
220
221
ORDER BY text_embedding ANN OF ?
221
222
LIMIT ?
@@ -235,7 +236,7 @@ def __init__(
235
236
236
237
self ._query_targets_embeddings_by_kind_and_tag_and_embedding = session .prepare (
237
238
f"""
238
- SELECT target_content_id, target_text_embedding, tag
239
+ SELECT target_content_id, target_text_embedding, tag, target_link_to_tags
239
240
FROM { keyspace } .{ targets_table }
240
241
WHERE kind = ? AND tag = ?
241
242
ORDER BY target_text_embedding ANN of ?
@@ -279,6 +280,7 @@ def _apply_schema(self) -> None:
279
280
-- text_embedding of target node. allows MMR to be applied without
280
281
-- fetching nodes.
281
282
target_text_embedding VECTOR<FLOAT, { embedding_dim } >,
283
+ target_link_to_tags SET<TUPLE<TEXT, TEXT>>,
282
284
283
285
PRIMARY KEY ((kind, tag), target_content_id)
284
286
)
@@ -354,7 +356,7 @@ def add_nodes(
354
356
for kind , value in link_from_tags :
355
357
cq .execute (
356
358
self ._insert_tag ,
357
- parameters = (node_id , kind , value , text_embedding ),
359
+ parameters = (node_id , kind , value , text_embedding , link_to_tags ),
358
360
)
359
361
360
362
return node_ids
@@ -433,16 +435,28 @@ def mmr_traversal_search(
433
435
score_threshold = score_threshold ,
434
436
)
435
437
436
- # Fetch the initial candidates and add them to the helper.
437
- fetched = self ._session .execute (
438
- self ._query_ids_and_embedding_by_embedding ,
439
- (query_embedding , fetch_k ),
440
- )
441
- helper .add_candidates ({row .content_id : row .text_embedding for row in fetched })
438
+ # For each unvisited node, stores the outgoing tags.
439
+ outgoing_tags : Dict [str , Set [Tuple [str , str ]]] = {}
440
+
441
+ # Fetch the initial candidates and add them to the helper and
442
+ # outgoing_tags.
443
+ def fetch_initial_candidates () -> None :
444
+ fetched = self ._session .execute (
445
+ self ._query_ids_and_embedding_by_embedding ,
446
+ (query_embedding , fetch_k ),
447
+ )
448
+ candidates = {}
449
+ for row in fetched :
450
+ candidates [row .content_id ] = row .text_embedding
451
+ outgoing_tags [row .content_id ] = set (row .link_to_tags or [])
452
+ helper .add_candidates (candidates )
453
+
454
+ fetch_initial_candidates ()
442
455
443
456
# Select the best item, K times.
444
457
depths = {candidate_id : 0 for candidate_id in helper .candidate_ids ()}
445
458
visited_tags : Set [Tuple [str , str ]] = set ()
459
+
446
460
for _ in range (k ):
447
461
selected_id = helper .pop_best ()
448
462
@@ -457,30 +471,46 @@ def mmr_traversal_search(
457
471
# TODO: For a big performance win, we should track which tags we've
458
472
# already incorporated. We don't need to issue adjacent queries for
459
473
# those.
474
+
475
+ # Find the tags linked to from the selected ID.
476
+ link_to_tags = outgoing_tags .pop (selected_id )
477
+
478
+ # Don't re-visit already visited tags.
479
+ link_to_tags .difference_update (visited_tags )
480
+
481
+ # Find the nodes with incoming links from those tags.
460
482
adjacents = self ._get_adjacent (
461
- [selected_id ],
462
- visited_tags = visited_tags ,
483
+ link_to_tags ,
463
484
query_embedding = query_embedding ,
464
485
k_per_tag = adjacent_k ,
465
486
)
466
487
488
+ # Record the link_to_tags as visited.
489
+ visited_tags .update (link_to_tags )
490
+
467
491
new_candidates = {}
468
492
for adjacent in adjacents :
469
- new_candidates [adjacent .target_content_id ] = (
470
- adjacent .target_text_embedding
471
- )
472
- if next_depth < depths .get (adjacent .target_content_id , depth + 1 ):
473
- # If this is a new shortest depth, or there was no
474
- # previous depth, update the depths. This ensures that
475
- # when we discover a node we will have the shortest
476
- # depth available.
477
- #
478
- # NOTE: No effort is made to traverse from nodes that
479
- # were previously selected if they become reachable via
480
- # a shorter path via nodes selected later. This is
481
- # currently "intended", but may be worth experimenting
482
- # with.
483
- depths [adjacent .target_content_id ] = next_depth
493
+ if adjacent .target_content_id not in outgoing_tags :
494
+ outgoing_tags [adjacent .target_content_id ] = (
495
+ adjacent .target_link_to_tags
496
+ )
497
+ new_candidates [adjacent .target_content_id ] = (
498
+ adjacent .target_text_embedding
499
+ )
500
+ if next_depth < depths .get (
501
+ adjacent .target_content_id , depth + 1
502
+ ):
503
+ # If this is a new shortest depth, or there was no
504
+ # previous depth, update the depths. This ensures that
505
+ # when we discover a node we will have the shortest
506
+ # depth available.
507
+ #
508
+ # NOTE: No effort is made to traverse from nodes that
509
+ # were previously selected if they become reachable via
510
+ # a shorter path via nodes selected later. This is
511
+ # currently "intended", but may be worth experimenting
512
+ # with.
513
+ depths [adjacent .target_content_id ] = next_depth
484
514
helper .add_candidates (new_candidates )
485
515
486
516
return self ._nodes_with_ids (helper .selected_ids )
@@ -601,61 +631,73 @@ def similarity_search(
601
631
for row in self ._session .execute (self ._query_by_embedding , (embedding , k )):
602
632
yield _row_to_node (row )
603
633
604
- def _get_adjacent (
634
+ def _get_outgoing_tags (
605
635
self ,
606
636
source_ids : Iterable [str ],
607
- visited_tags : Set [Tuple [str , str ]],
637
+ ) -> Set [Tuple [str , str ]]:
638
+ """Return the set of outgoing tags for the given source ID(s).
639
+
640
+ Args:
641
+ source_ids: The IDs of the source nodes to retrieve outgoing tags for.
642
+ """
643
+ tags = set ()
644
+
645
+ def add_sources (rows : Iterable [Any ]) -> None :
646
+ for row in rows :
647
+ tags .update (row .link_to_tags )
648
+
649
+ with self ._concurrent_queries () as cq :
650
+ for source_id in source_ids :
651
+ cq .execute (
652
+ self ._query_source_tags_by_id , (source_id ,), callback = add_sources
653
+ )
654
+
655
+ return tags
656
+
657
+ def _get_adjacent (
658
+ self ,
659
+ tags : Set [Tuple [str , str ]],
608
660
query_embedding : List [float ],
609
661
k_per_tag : Optional [int ] = None ,
610
662
) -> Iterable [_Edge ]:
611
- """Return the target nodes adjacent to any of the source nodes .
663
+ """Return the target nodes with incoming links from any of the given tags .
612
664
613
665
Args:
614
- source_ids: The source IDs to start from when retrieving adjacent nodes.
615
- visited_tags: Tags we've already visited.
666
+ tags: The tags to look for links *from*.
616
667
query_embedding: The query embedding. Used to rank target nodes.
617
668
k_per_tag: The number of target nodes to fetch for each outgoing tag.
618
669
619
670
Returns:
620
671
List of adjacent edges.
621
672
"""
622
- targets : Dict [str , List [float ]] = {}
623
-
624
- def add_sources (rows : Iterable [Any ]) -> None :
625
- for row in rows :
626
- for new_tag in row .link_to_tags or []:
627
- if new_tag not in visited_tags :
628
- visited_tags .add (new_tag )
629
-
630
- cq .execute (
631
- self ._query_targets_embeddings_by_kind_and_tag_and_embedding ,
632
- parameters = (
633
- new_tag [0 ],
634
- new_tag [1 ],
635
- query_embedding ,
636
- k_per_tag or 10 ,
637
- ),
638
- callback = add_targets ,
639
- )
673
+ targets : Dict [str , _Edge ] = {}
640
674
641
675
def add_targets (rows : Iterable [Any ]) -> None :
642
676
# TODO: Figure out how to use the "kind" on the edge.
643
677
# This is tricky, since we currently issue one query for anything
644
678
# adjacent via any kind, and we don't have enough information to
645
679
# determine which kind(s) a given target was reached from.
646
680
for row in rows :
647
- targets .setdefault (row .target_content_id , row .target_text_embedding )
681
+ if row .target_content_id not in targets :
682
+ targets [row .target_content_id ] = _Edge (
683
+ target_content_id = row .target_content_id ,
684
+ target_text_embedding = row .target_text_embedding ,
685
+ target_link_to_tags = set (row .target_link_to_tags or []),
686
+ )
648
687
649
688
with self ._concurrent_queries () as cq :
650
- # TODO: We could eliminate this query by storing the source tags of the
651
- # target node in the targets table.
652
- for source_id in source_ids :
689
+ for kind , value in tags :
653
690
cq .execute (
654
- self ._query_source_tags_by_id , (source_id ,), callback = add_sources
691
+ self ._query_targets_embeddings_by_kind_and_tag_and_embedding ,
692
+ parameters = (
693
+ kind ,
694
+ value ,
695
+ query_embedding ,
696
+ k_per_tag or 10 ,
697
+ ),
698
+ callback = add_targets ,
655
699
)
656
700
657
- # TODO: Consider a combined limit based on the similarity and/or predicated MMR score? # noqa: E501
658
- return [
659
- _Edge (target_content_id = content_id , target_text_embedding = embedding )
660
- for (content_id , embedding ) in targets .items ()
661
- ]
701
+ # TODO: Consider a combined limit based on the similarity and/or
702
+ # predicated MMR score?
703
+ return targets .values ()
0 commit comments