|
1 |
| -from typing import ( |
2 |
| - Any, |
3 |
| - Iterable, |
4 |
| - List, |
5 |
| - Optional, |
6 |
| - Type, |
| 1 | +from langchain_community.graph_vectorstores import ( |
| 2 | + CassandraGraphVectorStore as CassandraGraphStore, |
7 | 3 | )
|
8 | 4 |
|
9 |
| -from cassandra.cluster import Session |
10 |
| -from langchain_community.utilities.cassandra import SetupMode |
11 |
| -from langchain_core.documents import Document |
12 |
| -from langchain_core.embeddings import Embeddings |
13 |
| -from ragstack_knowledge_store import EmbeddingModel, graph_store |
14 |
| -from typing_extensions import override |
15 |
| - |
16 |
| -from .base import GraphStore, Node, nodes_to_documents |
17 |
| - |
18 |
| - |
19 |
| -class _EmbeddingModelAdapter(EmbeddingModel): |
20 |
| - def __init__(self, embeddings: Embeddings): |
21 |
| - self.embeddings = embeddings |
22 |
| - |
23 |
| - def embed_texts(self, texts: List[str]) -> List[List[float]]: |
24 |
| - return self.embeddings.embed_documents(texts) |
25 |
| - |
26 |
| - def embed_query(self, text: str) -> List[float]: |
27 |
| - return self.embeddings.embed_query(text) |
28 |
| - |
29 |
| - async def aembed_texts(self, texts: List[str]) -> List[List[float]]: |
30 |
| - return await self.embeddings.aembed_documents(texts) |
31 |
| - |
32 |
| - async def aembed_query(self, text: str) -> List[float]: |
33 |
| - return await self.embeddings.aembed_query(text) |
34 |
| - |
35 |
| - |
36 |
| -class CassandraGraphStore(GraphStore): |
37 |
| - def __init__( |
38 |
| - self, |
39 |
| - embedding: Embeddings, |
40 |
| - *, |
41 |
| - node_table: str = "graph_nodes", |
42 |
| - targets_table: str = "graph_targets", |
43 |
| - session: Optional[Session] = None, |
44 |
| - keyspace: Optional[str] = None, |
45 |
| - setup_mode: SetupMode = SetupMode.SYNC, |
46 |
| - ): |
47 |
| - """ |
48 |
| - Create the hybrid graph store. |
49 |
| - Parameters configure the ways that edges should be added between |
50 |
| - documents. Many take `Union[bool, Set[str]]`, with `False` disabling |
51 |
| - inference, `True` enabling it globally between all documents, and a set |
52 |
| - of metadata fields defining a scope in which to enable it. Specifically, |
53 |
| - passing a set of metadata fields such as `source` only links documents |
54 |
| - with the same `source` metadata value. |
55 |
| - Args: |
56 |
| - embedding: The embeddings to use for the document content. |
57 |
| - setup_mode: Mode used to create the Cassandra table (SYNC, |
58 |
| - ASYNC or OFF). |
59 |
| - """ |
60 |
| - self._embedding = embedding |
61 |
| - _setup_mode = getattr(graph_store.SetupMode, setup_mode.name) |
62 |
| - |
63 |
| - self.store = graph_store.GraphStore( |
64 |
| - embedding=_EmbeddingModelAdapter(embedding), |
65 |
| - node_table=node_table, |
66 |
| - targets_table=targets_table, |
67 |
| - session=session, |
68 |
| - keyspace=keyspace, |
69 |
| - setup_mode=_setup_mode, |
70 |
| - ) |
71 |
| - |
72 |
| - @property |
73 |
| - @override |
74 |
| - def embeddings(self) -> Optional[Embeddings]: |
75 |
| - return self._embedding |
76 |
| - |
77 |
| - @override |
78 |
| - def add_nodes( |
79 |
| - self, |
80 |
| - nodes: Iterable[Node], |
81 |
| - **kwargs: Any, |
82 |
| - ) -> Iterable[str]: |
83 |
| - _nodes = [ |
84 |
| - graph_store.Node( |
85 |
| - id=node.id, text=node.text, metadata=node.metadata, links=node.links |
86 |
| - ) |
87 |
| - for node in nodes |
88 |
| - ] |
89 |
| - return self.store.add_nodes(_nodes) |
90 |
| - |
91 |
| - @classmethod |
92 |
| - @override |
93 |
| - def from_texts( |
94 |
| - cls: Type["CassandraGraphStore"], |
95 |
| - texts: Iterable[str], |
96 |
| - embedding: Embeddings, |
97 |
| - metadatas: Optional[List[dict]] = None, |
98 |
| - ids: Optional[Iterable[str]] = None, |
99 |
| - **kwargs: Any, |
100 |
| - ) -> "CassandraGraphStore": |
101 |
| - """Return CassandraGraphStore initialized from texts and embeddings.""" |
102 |
| - store = cls(embedding, **kwargs) |
103 |
| - store.add_texts(texts, metadatas, ids=ids) |
104 |
| - return store |
105 |
| - |
106 |
| - @classmethod |
107 |
| - @override |
108 |
| - def from_documents( |
109 |
| - cls: Type["CassandraGraphStore"], |
110 |
| - documents: Iterable[Document], |
111 |
| - embedding: Embeddings, |
112 |
| - ids: Optional[Iterable[str]] = None, |
113 |
| - **kwargs: Any, |
114 |
| - ) -> "CassandraGraphStore": |
115 |
| - """Return CassandraGraphStore initialized from documents and embeddings.""" |
116 |
| - store = cls(embedding, **kwargs) |
117 |
| - store.add_documents(documents, ids=ids) |
118 |
| - return store |
119 |
| - |
120 |
| - @override |
121 |
| - def similarity_search( |
122 |
| - self, query: str, k: int = 4, **kwargs: Any |
123 |
| - ) -> List[Document]: |
124 |
| - embedding_vector = self._embedding.embed_query(query) |
125 |
| - return self.similarity_search_by_vector(embedding_vector, k=k, **kwargs) |
126 |
| - |
127 |
| - @override |
128 |
| - def similarity_search_by_vector( |
129 |
| - self, embedding: List[float], k: int = 4, **kwargs: Any |
130 |
| - ) -> List[Document]: |
131 |
| - nodes = self.store.similarity_search(embedding, k=k) |
132 |
| - return list(nodes_to_documents(nodes)) |
133 |
| - |
134 |
| - @override |
135 |
| - def traversal_search( |
136 |
| - self, |
137 |
| - query: str, |
138 |
| - *, |
139 |
| - k: int = 4, |
140 |
| - depth: int = 1, |
141 |
| - **kwargs: Any, |
142 |
| - ) -> Iterable[Document]: |
143 |
| - nodes = self.store.traversal_search(query, k=k, depth=depth) |
144 |
| - return nodes_to_documents(nodes) |
145 |
| - |
146 |
| - @override |
147 |
| - def mmr_traversal_search( |
148 |
| - self, |
149 |
| - query: str, |
150 |
| - *, |
151 |
| - k: int = 4, |
152 |
| - depth: int = 2, |
153 |
| - fetch_k: int = 100, |
154 |
| - adjacent_k: int = 10, |
155 |
| - lambda_mult: float = 0.5, |
156 |
| - score_threshold: float = float("-inf"), |
157 |
| - **kwargs: Any, |
158 |
| - ) -> Iterable[Document]: |
159 |
| - nodes = self.store.mmr_traversal_search( |
160 |
| - query, |
161 |
| - k=k, |
162 |
| - depth=depth, |
163 |
| - fetch_k=fetch_k, |
164 |
| - adjacent_k=adjacent_k, |
165 |
| - lambda_mult=lambda_mult, |
166 |
| - score_threshold=score_threshold, |
167 |
| - ) |
168 |
| - return nodes_to_documents(nodes) |
| 5 | +__all__ = [ |
| 6 | + "CassandraGraphStore", |
| 7 | +] |
0 commit comments