29
29
30
30
# Columns fetched when a full node row is needed.
CONTENT_COLUMNS = "content_id, kind, text_content, attributes_blob, metadata_s, links_blob"

# Skeleton for every search query. Each clause placeholder may be filled with
# an empty string. Placeholder names must not contain whitespace: str.format
# treats "{ limit_clause}" as the key " limit_clause" and raises KeyError.
SELECT_CQL_TEMPLATE = (
    "SELECT {columns} FROM {table_name} {where_clause} {order_clause} {limit_clause};"
)
33
33
34
34
@dataclass
35
35
class Node :
@@ -198,28 +198,6 @@ def __init__(
198
198
""" # noqa: S608
199
199
)
200
200
201
- self ._query_by_embedding = session .prepare (
202
- f"""
203
- SELECT { CONTENT_COLUMNS }
204
- FROM { keyspace } .{ node_table }
205
- ORDER BY text_embedding ANN OF ?
206
- LIMIT ?
207
- """ # noqa: S608
208
- )
209
- self ._query_by_embedding .consistency_level = ConsistencyLevel .ONE
210
-
211
- self ._query_ids_and_link_to_tags_by_embedding = session .prepare (
212
- f"""
213
- SELECT content_id, link_to_tags
214
- FROM { keyspace } .{ node_table }
215
- ORDER BY text_embedding ANN OF ?
216
- LIMIT ?
217
- """ # noqa: S608
218
- )
219
- self ._query_ids_and_link_to_tags_by_embedding .consistency_level = (
220
- ConsistencyLevel .ONE
221
- )
222
-
223
201
self ._query_ids_and_link_to_tags_by_id = session .prepare (
224
202
f"""
225
203
SELECT content_id, link_to_tags
@@ -228,18 +206,6 @@ def __init__(
228
206
""" # noqa: S608
229
207
)
230
208
231
- self ._query_ids_and_embedding_by_embedding = session .prepare (
232
- f"""
233
- SELECT content_id, text_embedding, link_to_tags
234
- FROM { keyspace } .{ node_table }
235
- ORDER BY text_embedding ANN OF ?
236
- LIMIT ?
237
- """ # noqa: S608
238
- )
239
- self ._query_ids_and_embedding_by_embedding .consistency_level = (
240
- ConsistencyLevel .ONE
241
- )
242
-
243
209
self ._query_source_tags_by_id = session .prepare (
244
210
f"""
245
211
SELECT link_to_tags
@@ -270,11 +236,14 @@ def __init__(
270
236
"""
271
237
)
272
238
239
def table_name(self) -> str:
    """Return the fully qualified node table name, ``<keyspace>.<table>``.

    The result is interpolated directly into CQL statements, so it must not
    contain any padding whitespace around the dot or at the end.
    """
    return f"{self._keyspace}.{self._node_table}"
241
+
273
242
def _apply_schema (self ) -> None :
274
243
"""Apply the schema to the database."""
275
244
embedding_dim = len (self ._embedding .embed_query ("Test Query" ))
276
245
self ._session .execute (f"""
277
- CREATE TABLE IF NOT EXISTS { self ._keyspace } . { self . _node_table } (
246
+ CREATE TABLE IF NOT EXISTS { self .table_name () } (
278
247
content_id TEXT,
279
248
kind TEXT,
280
249
text_content TEXT,
@@ -293,19 +262,19 @@ def _apply_schema(self) -> None:
293
262
# Index on text_embedding (for similarity search)
294
263
self ._session .execute (f"""
295
264
CREATE CUSTOM INDEX IF NOT EXISTS { self ._node_table } _text_embedding_index
296
- ON { self ._keyspace } . { self . _node_table } (text_embedding)
265
+ ON { self .table_name () } (text_embedding)
297
266
USING 'StorageAttachedIndex';
298
267
""" )
299
268
300
269
self ._session .execute (f"""
301
270
CREATE CUSTOM INDEX IF NOT EXISTS { self ._node_table } _link_from_tags
302
- ON { self ._keyspace } . { self . _node_table } (link_from_tags)
271
+ ON { self .table_name () } (link_from_tags)
303
272
USING 'StorageAttachedIndex';
304
273
""" )
305
274
306
275
self ._session .execute (f"""
307
276
CREATE CUSTOM INDEX IF NOT EXISTS { self ._node_table } _metadata_index
308
- ON { self ._keyspace } . { self . _node_table } (ENTRIES(metadata_s))
277
+ ON { self .table_name () } (ENTRIES(metadata_s))
309
278
USING 'StorageAttachedIndex';
310
279
""" )
311
280
@@ -425,6 +394,7 @@ def mmr_traversal_search(
425
394
adjacent_k : int = 10 ,
426
395
lambda_mult : float = 0.5 ,
427
396
score_threshold : float = float ("-inf" ),
397
+ metadata : Optional [Dict [str , Any ]] = [],
428
398
) -> Iterable [Node ]:
429
399
"""Retrieve documents from this graph store using MMR-traversal.
430
400
@@ -450,6 +420,7 @@ def mmr_traversal_search(
450
420
diversity and 1 to minimum diversity. Defaults to 0.5.
451
421
score_threshold: Only documents with a score greater than or equal
452
422
this threshold will be chosen. Defaults to -infinity.
423
+ metadata: Optional metadata to filter the results.
453
424
"""
454
425
query_embedding = self ._embedding .embed_query (query )
455
426
helper = MmrHelper (
@@ -465,10 +436,14 @@ def mmr_traversal_search(
465
436
# Fetch the initial candidates and add them to the helper and
466
437
# outgoing_tags.
467
438
def fetch_initial_candidates () -> None :
468
- fetched = self ._session .execute (
469
- self ._query_ids_and_embedding_by_embedding ,
470
- (query_embedding , fetch_k ),
439
+ query , params = self ._get_search_cql (
440
+ limit = fetch_k ,
441
+ columns = "content_id, text_embedding, link_to_tags" ,
442
+ metadata = metadata ,
443
+ embedding = query_embedding
471
444
)
445
+
446
+ fetched = self ._session .execute (query = query , parameters = params )
472
447
candidates = {}
473
448
for row in fetched :
474
449
candidates [row .content_id ] = row .text_embedding
@@ -540,7 +515,7 @@ def fetch_initial_candidates() -> None:
540
515
return self ._nodes_with_ids (helper .selected_ids )
541
516
542
517
def traversal_search (
543
- self , query : str , * , k : int = 4 , depth : int = 1
518
+ self , query : str , * , k : int = 4 , depth : int = 1 , metadata : Optional [ Dict [ str , Any ]] = [],
544
519
) -> Iterable [Node ]:
545
520
"""Retrieve documents from this knowledge store.
546
521
@@ -553,6 +528,7 @@ def traversal_search(
553
528
k: The number of Documents to return from the initial vector search.
554
529
Defaults to 4.
555
530
depth: The maximum depth of edges to traverse. Defaults to 1.
531
+ metadata: Optional metadata to filter the results.
556
532
557
533
Returns:
558
534
Collection of retrieved documents.
@@ -638,9 +614,15 @@ def visit_targets(d: int, targets: Sequence[Any]) -> None:
638
614
)
639
615
640
616
query_embedding = self ._embedding .embed_query (query )
617
+ query , params = self ._get_search_cql (
618
+ limit = k ,
619
+ metadata = metadata ,
620
+ embedding = query_embedding ,
621
+ )
622
+
641
623
cq .execute (
642
- self . _query_ids_and_link_to_tags_by_embedding ,
643
- parameters = ( query_embedding , k ) ,
624
+ query ,
625
+ parameters = params ,
644
626
callback = lambda nodes : visit_nodes (0 , nodes ),
645
627
)
646
628
def similarity_search(
    self,
    embedding: List[float],
    k: int = 4,
    metadata: Optional[Dict[str, Any]] = None,
) -> Iterable[Node]:
    """Retrieve nodes similar to the given embedding, optionally filtered by metadata.

    Args:
        embedding: The query vector used for the ANN ordering.
        k: Maximum number of rows requested from the database. Defaults to 4.
        metadata: Optional metadata to filter the results. ``None`` (the
            default) or an empty mapping applies no filter.

    Yields:
        One ``Node`` per returned row, converted with ``_row_to_node``.
    """
    # The default is None rather than a mutable literal so no object is
    # shared between calls; normalise to an empty dict for the query builder.
    query, params = self._get_search_cql(
        embedding=embedding,
        limit=k,
        metadata=metadata or {},
    )

    for row in self._session.execute(query, params):
        yield _row_to_node(row)
657
642
658
643
def metadata_search(
    self, metadata: Optional[Dict[str, Any]] = None, n: Optional[int] = 5
) -> Iterable[Node]:
    """Retrieve nodes whose metadata matches the given filter.

    No embedding is passed to the query builder, so the query carries no
    ANN ordering clause.

    Args:
        metadata: Metadata to filter the results. ``None`` (the default) or
            an empty mapping applies no filter.
        n: Row limit handed to the query builder. Defaults to 5.

    Yields:
        One ``Node`` per returned row, converted with ``_row_to_node``.
    """
    # The default is None rather than `{}` so no dict is shared between calls.
    query, params = self._get_search_cql(metadata=metadata or {}, limit=n)

    for row in self._session.execute(query, params):
        yield _row_to_node(row)
665
648
666
649
def get_node (self , id : str ) -> Node :
@@ -802,7 +785,7 @@ def _extract_where_clause_blocks(
802
785
self , metadata : Dict [str , Any ]
803
786
) -> Tuple [str , List [Any ]]:
804
787
805
- attributes_blob , metadata_s = self ._parse_metadata (metadata = metadata , is_query = True )
788
+ _ , metadata_s = self ._parse_metadata (metadata = metadata , is_query = True )
806
789
807
790
if len (metadata_s ) == 0 :
808
791
return "" , []
@@ -818,17 +801,27 @@ def _extract_where_clause_blocks(
818
801
return where_clause , vals_list
819
802
820
803
821
def _get_search_cql(
    self,
    limit: Optional[int],
    columns: Optional[str] = CONTENT_COLUMNS,
    metadata: Optional[Dict[str, Any]] = None,
    embedding: Optional[List[float]] = None,
) -> Tuple[Any, Tuple[Any, ...]]:
    """Build and prepare a SELECT over the node table.

    Args:
        limit: Maximum number of rows. When ``None`` the ``LIMIT`` clause is
            omitted instead of binding a null value to it.
        columns: Comma-separated column list to select. Defaults to
            ``CONTENT_COLUMNS``; ``None`` also falls back to it.
        metadata: Optional metadata filter, turned into the WHERE clause by
            ``_extract_where_clause_blocks``. ``None`` or empty means no filter.
        embedding: Optional query vector. When given, rows are ordered with
            ``ORDER BY text_embedding ANN OF ?``.

    Returns:
        A ``(prepared_statement, bound_values)`` pair ready to be passed to
        ``session.execute``. The statement's consistency level is set to
        ``ConsistencyLevel.ONE``.
    """
    # Normalise here so the helper never receives None (or the empty list
    # some callers historically used as a default).
    where_clause, where_vals = self._extract_where_clause_blocks(
        metadata=metadata or {}
    )

    order_clause = ""
    order_vals: List[Any] = []
    if embedding is not None:
        order_clause = "ORDER BY text_embedding ANN OF ?"
        order_vals = [embedding]

    limit_clause = ""
    limit_vals: List[Any] = []
    if limit is not None:
        limit_clause = "LIMIT ?"
        limit_vals = [limit]

    # Bound values must follow the order of the `?` markers in the template:
    # WHERE values, then the ANN vector, then the limit.
    select_vals = tuple(list(where_vals) + order_vals + limit_vals)

    select_cql = SELECT_CQL_TEMPLATE.format(
        columns=columns if columns is not None else CONTENT_COLUMNS,
        table_name=self.table_name(),
        where_clause=where_clause,
        order_clause=order_clause,
        limit_clause=limit_clause,
    )

    # NOTE(review): this prepares a statement on every call; consider caching
    # prepared statements keyed by CQL text if this shows up as overhead.
    prepared_query = self._session.prepare(select_cql)
    prepared_query.consistency_level = ConsistencyLevel.ONE

    return prepared_query, select_vals
0 commit comments