Skip to content

Commit ade17cc

Browse files
authored
Use setup_mode instead of apply_schema for consistency with other components (#465)
1 parent 39468d1 commit ade17cc

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

libs/knowledge-store/ragstack_knowledge_store/cassandra.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
from cassandra.cluster import ConsistencyLevel, ResponseFuture, Session
1515
from cassio.config import check_resolve_keyspace, check_resolve_session
16+
from langchain_community.utilities.cassandra import SetupMode
1617
from langchain_community.utils.math import cosine_similarity
1718
from langchain_core.documents import Document
1819
from langchain_core.embeddings import Embeddings
@@ -68,7 +69,9 @@ class _Candidate:
6869
redundancy: float
6970
"""(1 - Lambda) * max(Similarity to selected items)."""
7071

71-
def __init__(self, embedding: List[float], lambda_mult: float, query_embedding: np.ndarray):
72+
def __init__(
73+
self, embedding: List[float], lambda_mult: float, query_embedding: np.ndarray
74+
):
7275
self.embedding = emb_to_ndarray(embedding)
7376

7477
# TODO: Refactor to use cosine_similarity_top_k to allow an array of embeddings?
@@ -79,7 +82,9 @@ def __init__(self, embedding: List[float], lambda_mult: float, query_embedding:
7982
self.score = self.similarity_to_query - self.redundancy
8083
self.distance = 0
8184

82-
def update_for_selection(self, lambda_mult: float, selection_embedding: List[float]):
85+
def update_for_selection(
86+
self, lambda_mult: float, selection_embedding: List[float]
87+
):
8388
selected_r_sim = (1 - lambda_mult) * cosine_similarity(
8489
selection_embedding, self.embedding
8590
)[0]
@@ -98,7 +103,7 @@ def __init__(
98103
edge_table: str = "knowledge_edges",
99104
session: Optional[Session] = None,
100105
keyspace: Optional[str] = None,
101-
apply_schema: bool = True,
106+
setup_mode: SetupMode = SetupMode.SYNC,
102107
concurrency: int = 20,
103108
):
104109
"""A hybrid vector-and-graph knowledge store backed by Cassandra.
@@ -130,8 +135,13 @@ def __init__(
130135
self._session = session
131136
self._keyspace = keyspace
132137

133-
if apply_schema:
138+
if setup_mode == SetupMode.SYNC:
134139
self._apply_schema()
140+
elif setup_mode != SetupMode.OFF:
141+
raise ValueError(
142+
f"Invalid setup mode {setup_mode.name}. "
143+
"Only SYNC and OFF are supported at the moment"
144+
)
135145

136146
# Ensure the edge extractor `kind`s are unique.
137147
assert len(edge_extractors) == len(set([e.kind for e in edge_extractors]))
@@ -199,7 +209,9 @@ def __init__(
199209
LIMIT ?
200210
"""
201211
)
202-
self._query_ids_and_embedding_by_embedding.consistency_level = ConsistencyLevel.QUORUM
212+
self._query_ids_and_embedding_by_embedding.consistency_level = (
213+
ConsistencyLevel.QUORUM
214+
)
203215

204216
self._query_linked_ids = session.prepare(
205217
f"""
@@ -352,7 +364,9 @@ def from_documents(
352364
store.add_documents(documents, ids=ids)
353365
return store
354366

355-
def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
367+
def similarity_search(
368+
self, query: str, k: int = 4, **kwargs: Any
369+
) -> List[Document]:
356370
embedding_vector = self._embedding.embed_query(query)
357371
return self.similarity_search_by_vector(
358372
embedding_vector,
@@ -429,7 +443,9 @@ def mmr_traversal_search(
429443
selected_ids = []
430444
selected_set = set()
431445

432-
selected_embeddings = [] # selected embeddings. saved to compute redundancy of new nodes.
446+
selected_embeddings = (
447+
[]
448+
) # selected embeddings. saved to compute redundancy of new nodes.
433449

434450
query_embedding = self._embedding.embed_query(query)
435451
fetched = self._session.execute(
@@ -471,7 +487,9 @@ def mmr_traversal_search(
471487
# Add unselected edges if reached nodes are within `depth`:
472488
next_depth = next_selected.distance + 1
473489
if next_depth < depth:
474-
adjacents = self._session.execute(self._query_edges_by_source, (selected_id,))
490+
adjacents = self._session.execute(
491+
self._query_edges_by_source, (selected_id,)
492+
)
475493
for row in adjacents:
476494
target_id = row.target_content_id
477495
if target_id in selected_set:
@@ -485,7 +503,9 @@ def mmr_traversal_search(
485503
unselected[target_id].distance = next_depth
486504
continue
487505

488-
adjacent = _Candidate(row.target_text_embedding, lambda_mult, query_embedding)
506+
adjacent = _Candidate(
507+
row.target_text_embedding, lambda_mult, query_embedding
508+
)
489509
for selected_embedding in selected_embeddings:
490510
adjacent.update_for_selection(lambda_mult, selected_embedding)
491511

@@ -496,7 +516,9 @@ def mmr_traversal_search(
496516

497517
return self._query_by_ids(selected_ids)
498518

499-
def traversal_search(self, query: str, *, k: int = 4, depth: int = 1) -> Iterable[Document]:
519+
def traversal_search(
520+
self, query: str, *, k: int = 4, depth: int = 1
521+
) -> Iterable[Document]:
500522
"""Retrieve documents from this knowledge store.
501523
502524
First, `k` nodes are retrieved using a vector search for the `query` string.

0 commit comments

Comments
 (0)