13
13
import numpy as np
14
14
from cassandra .cluster import ConsistencyLevel , ResponseFuture , Session
15
15
from cassio .config import check_resolve_keyspace , check_resolve_session
16
+ from langchain_community .utilities .cassandra import SetupMode
16
17
from langchain_community .utils .math import cosine_similarity
17
18
from langchain_core .documents import Document
18
19
from langchain_core .embeddings import Embeddings
@@ -68,7 +69,9 @@ class _Candidate:
68
69
redundancy : float
69
70
"""(1 - Lambda) * max(Similarity to selected items)."""
70
71
71
- def __init__ (self , embedding : List [float ], lambda_mult : float , query_embedding : np .ndarray ):
72
+ def __init__ (
73
+ self , embedding : List [float ], lambda_mult : float , query_embedding : np .ndarray
74
+ ):
72
75
self .embedding = emb_to_ndarray (embedding )
73
76
74
77
# TODO: Refactor to use cosine_similarity_top_k to allow an array of embeddings?
@@ -79,7 +82,9 @@ def __init__(self, embedding: List[float], lambda_mult: float, query_embedding:
79
82
self .score = self .similarity_to_query - self .redundancy
80
83
self .distance = 0
81
84
82
- def update_for_selection (self , lambda_mult : float , selection_embedding : List [float ]):
85
+ def update_for_selection (
86
+ self , lambda_mult : float , selection_embedding : List [float ]
87
+ ):
83
88
selected_r_sim = (1 - lambda_mult ) * cosine_similarity (
84
89
selection_embedding , self .embedding
85
90
)[0 ]
@@ -98,7 +103,7 @@ def __init__(
98
103
edge_table : str = "knowledge_edges" ,
99
104
session : Optional [Session ] = None ,
100
105
keyspace : Optional [str ] = None ,
101
- apply_schema : bool = True ,
106
+ setup_mode : SetupMode = SetupMode . SYNC ,
102
107
concurrency : int = 20 ,
103
108
):
104
109
"""A hybrid vector-and-graph knowledge store backed by Cassandra.
@@ -130,8 +135,13 @@ def __init__(
130
135
self ._session = session
131
136
self ._keyspace = keyspace
132
137
133
- if apply_schema :
138
+ if setup_mode == SetupMode . SYNC :
134
139
self ._apply_schema ()
140
+ elif setup_mode != SetupMode .OFF :
141
+ raise ValueError (
142
+ f"Invalid setup mode { setup_mode .name } . "
143
+ "Only SYNC and OFF are supported at the moment"
144
+ )
135
145
136
146
# Ensure the edge extractor `kind`s are unique.
137
147
assert len (edge_extractors ) == len (set ([e .kind for e in edge_extractors ]))
@@ -199,7 +209,9 @@ def __init__(
199
209
LIMIT ?
200
210
"""
201
211
)
202
- self ._query_ids_and_embedding_by_embedding .consistency_level = ConsistencyLevel .QUORUM
212
+ self ._query_ids_and_embedding_by_embedding .consistency_level = (
213
+ ConsistencyLevel .QUORUM
214
+ )
203
215
204
216
self ._query_linked_ids = session .prepare (
205
217
f"""
@@ -352,7 +364,9 @@ def from_documents(
352
364
store .add_documents (documents , ids = ids )
353
365
return store
354
366
355
- def similarity_search (self , query : str , k : int = 4 , ** kwargs : Any ) -> List [Document ]:
367
+ def similarity_search (
368
+ self , query : str , k : int = 4 , ** kwargs : Any
369
+ ) -> List [Document ]:
356
370
embedding_vector = self ._embedding .embed_query (query )
357
371
return self .similarity_search_by_vector (
358
372
embedding_vector ,
@@ -429,7 +443,9 @@ def mmr_traversal_search(
429
443
selected_ids = []
430
444
selected_set = set ()
431
445
432
- selected_embeddings = [] # selected embeddings. saved to compute redundancy of new nodes.
446
+ selected_embeddings = (
447
+ []
448
+ ) # selected embeddings. saved to compute redundancy of new nodes.
433
449
434
450
query_embedding = self ._embedding .embed_query (query )
435
451
fetched = self ._session .execute (
@@ -471,7 +487,9 @@ def mmr_traversal_search(
471
487
# Add unselected edges if reached nodes are within `depth`:
472
488
next_depth = next_selected .distance + 1
473
489
if next_depth < depth :
474
- adjacents = self ._session .execute (self ._query_edges_by_source , (selected_id ,))
490
+ adjacents = self ._session .execute (
491
+ self ._query_edges_by_source , (selected_id ,)
492
+ )
475
493
for row in adjacents :
476
494
target_id = row .target_content_id
477
495
if target_id in selected_set :
@@ -485,7 +503,9 @@ def mmr_traversal_search(
485
503
unselected [target_id ].distance = next_depth
486
504
continue
487
505
488
- adjacent = _Candidate (row .target_text_embedding , lambda_mult , query_embedding )
506
+ adjacent = _Candidate (
507
+ row .target_text_embedding , lambda_mult , query_embedding
508
+ )
489
509
for selected_embedding in selected_embeddings :
490
510
adjacent .update_for_selection (lambda_mult , selected_embedding )
491
511
@@ -496,7 +516,9 @@ def mmr_traversal_search(
496
516
497
517
return self ._query_by_ids (selected_ids )
498
518
499
- def traversal_search (self , query : str , * , k : int = 4 , depth : int = 1 ) -> Iterable [Document ]:
519
+ def traversal_search (
520
+ self , query : str , * , k : int = 4 , depth : int = 1
521
+ ) -> Iterable [Document ]:
500
522
"""Retrieve documents from this knowledge store.
501
523
502
524
First, `k` nodes are retrieved using a vector search for the `query` string.
0 commit comments