
Commit a3c2f17

Use langchain GraphVectorStore

1 parent e69eb45

File tree: 13 files changed (+38, -1108 lines)

libs/langchain/ragstack_langchain/graph_store/base.py

Lines changed: 10 additions & 529 deletions
Large diffs are not rendered by default.
Lines changed: 5 additions & 166 deletions

@@ -1,168 +1,7 @@
-from typing import (
-    Any,
-    Iterable,
-    List,
-    Optional,
-    Type,
+from langchain_community.graph_vectorstores import (
+    CassandraGraphVectorStore as CassandraGraphStore,
 )
 
-from cassandra.cluster import Session
-from langchain_community.utilities.cassandra import SetupMode
-from langchain_core.documents import Document
-from langchain_core.embeddings import Embeddings
-from ragstack_knowledge_store import EmbeddingModel, graph_store
-from typing_extensions import override
-
-from .base import GraphStore, Node, nodes_to_documents
-
-
-class _EmbeddingModelAdapter(EmbeddingModel):
-    def __init__(self, embeddings: Embeddings):
-        self.embeddings = embeddings
-
-    def embed_texts(self, texts: List[str]) -> List[List[float]]:
-        return self.embeddings.embed_documents(texts)
-
-    def embed_query(self, text: str) -> List[float]:
-        return self.embeddings.embed_query(text)
-
-    async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
-        return await self.embeddings.aembed_documents(texts)
-
-    async def aembed_query(self, text: str) -> List[float]:
-        return await self.embeddings.aembed_query(text)
-
-
-class CassandraGraphStore(GraphStore):
-    def __init__(
-        self,
-        embedding: Embeddings,
-        *,
-        node_table: str = "graph_nodes",
-        targets_table: str = "graph_targets",
-        session: Optional[Session] = None,
-        keyspace: Optional[str] = None,
-        setup_mode: SetupMode = SetupMode.SYNC,
-    ):
-        """
-        Create the hybrid graph store.
-        Parameters configure the ways that edges should be added between
-        documents. Many take `Union[bool, Set[str]]`, with `False` disabling
-        inference, `True` enabling it globally between all documents, and a set
-        of metadata fields defining a scope in which to enable it. Specifically,
-        passing a set of metadata fields such as `source` only links documents
-        with the same `source` metadata value.
-        Args:
-            embedding: The embeddings to use for the document content.
-            setup_mode: Mode used to create the Cassandra table (SYNC,
-                ASYNC or OFF).
-        """
-        self._embedding = embedding
-        _setup_mode = getattr(graph_store.SetupMode, setup_mode.name)
-
-        self.store = graph_store.GraphStore(
-            embedding=_EmbeddingModelAdapter(embedding),
-            node_table=node_table,
-            targets_table=targets_table,
-            session=session,
-            keyspace=keyspace,
-            setup_mode=_setup_mode,
-        )
-
-    @property
-    @override
-    def embeddings(self) -> Optional[Embeddings]:
-        return self._embedding
-
-    @override
-    def add_nodes(
-        self,
-        nodes: Iterable[Node],
-        **kwargs: Any,
-    ) -> Iterable[str]:
-        _nodes = [
-            graph_store.Node(
-                id=node.id, text=node.text, metadata=node.metadata, links=node.links
-            )
-            for node in nodes
-        ]
-        return self.store.add_nodes(_nodes)
-
-    @classmethod
-    @override
-    def from_texts(
-        cls: Type["CassandraGraphStore"],
-        texts: Iterable[str],
-        embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
-        ids: Optional[Iterable[str]] = None,
-        **kwargs: Any,
-    ) -> "CassandraGraphStore":
-        """Return CassandraGraphStore initialized from texts and embeddings."""
-        store = cls(embedding, **kwargs)
-        store.add_texts(texts, metadatas, ids=ids)
-        return store
-
-    @classmethod
-    @override
-    def from_documents(
-        cls: Type["CassandraGraphStore"],
-        documents: Iterable[Document],
-        embedding: Embeddings,
-        ids: Optional[Iterable[str]] = None,
-        **kwargs: Any,
-    ) -> "CassandraGraphStore":
-        """Return CassandraGraphStore initialized from documents and embeddings."""
-        store = cls(embedding, **kwargs)
-        store.add_documents(documents, ids=ids)
-        return store
-
-    @override
-    def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
-        embedding_vector = self._embedding.embed_query(query)
-        return self.similarity_search_by_vector(embedding_vector, k=k, **kwargs)
-
-    @override
-    def similarity_search_by_vector(
-        self, embedding: List[float], k: int = 4, **kwargs: Any
-    ) -> List[Document]:
-        nodes = self.store.similarity_search(embedding, k=k)
-        return list(nodes_to_documents(nodes))
-
-    @override
-    def traversal_search(
-        self,
-        query: str,
-        *,
-        k: int = 4,
-        depth: int = 1,
-        **kwargs: Any,
-    ) -> Iterable[Document]:
-        nodes = self.store.traversal_search(query, k=k, depth=depth)
-        return nodes_to_documents(nodes)
-
-    @override
-    def mmr_traversal_search(
-        self,
-        query: str,
-        *,
-        k: int = 4,
-        depth: int = 2,
-        fetch_k: int = 100,
-        adjacent_k: int = 10,
-        lambda_mult: float = 0.5,
-        score_threshold: float = float("-inf"),
-        **kwargs: Any,
-    ) -> Iterable[Document]:
-        nodes = self.store.mmr_traversal_search(
-            query,
-            k=k,
-            depth=depth,
-            fetch_k=fetch_k,
-            adjacent_k=adjacent_k,
-            lambda_mult=lambda_mult,
-            score_threshold=score_threshold,
-        )
-        return nodes_to_documents(nodes)
+
+__all__ = [
+    "CassandraGraphStore",
+]
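Net effect of the hunk above: the hand-rolled Cassandra wrapper is deleted and the module now simply re-exports the langchain-community implementation under the old name. A minimal sketch of what downstream code sees (hypothetical usage; assumes a langchain-community release that ships graph_vectorstores and that the ragstack_langchain.graph_store package still exposes CassandraGraphStore):

    from langchain_community.graph_vectorstores import CassandraGraphVectorStore
    from ragstack_langchain.graph_store import CassandraGraphStore

    # After this commit the ragstack name is only an alias, so both names refer
    # to the same class and existing imports keep working unchanged.
    assert CassandraGraphStore is CassandraGraphVectorStore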

libs/langchain/ragstack_langchain/graph_store/extractors/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 from .hierarchy_link_extractor import HierarchyInput, HierarchyLinkExtractor
 from .html_link_extractor import HtmlInput, HtmlLinkExtractor
 from .keybert_link_extractor import KeybertInput, KeybertLinkExtractor
+from .link_extractor import LinkExtractor
 from .link_extractor_adapter import LinkExtractorAdapter
 from .link_extractor_transformer import LinkExtractorTransformer
 
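With LinkExtractor added to the package imports, the abstract interface can be pulled from ragstack_langchain.graph_store.extractors alongside the concrete extractors. A small sketch (assumes the optional extractor dependencies are installed, since the package __init__ imports every extractor module):

    from ragstack_langchain.graph_store.extractors import HtmlLinkExtractor, LinkExtractor

    # LinkExtractor is the abstract interface; HtmlLinkExtractor is one of the
    # concrete extractors that implement it.
    assert issubclass(HtmlLinkExtractor, LinkExtractor)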

libs/langchain/ragstack_langchain/graph_store/extractors/gliner_link_extractor.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 from typing import Any, Dict, Iterable, List, Optional, Set
 
+from langchain_core.graph_vectorstores import Link
+
 from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
-from ragstack_langchain.graph_store.links import Link
 
 # TypeAlias is not available in Python 2.9, we can't use that or the newer `type`.
 GLiNERInput = str

libs/langchain/ragstack_langchain/graph_store/extractors/hierarchy_link_extractor.py

Lines changed: 1 addition & 2 deletions
@@ -1,8 +1,7 @@
 from typing import Callable, List, Set
 
 from langchain_core.documents import Document
-
-from ragstack_langchain.graph_store.links import Link
+from langchain_core.graph_vectorstores import Link
 
 from .link_extractor import LinkExtractor
 from .link_extractor_adapter import LinkExtractorAdapter

libs/langchain/ragstack_langchain/graph_store/extractors/html_link_extractor.py

Lines changed: 1 addition & 2 deletions
@@ -3,8 +3,7 @@
 from urllib.parse import urldefrag, urljoin, urlparse
 
 from langchain_core.documents import Document
-
-from ragstack_langchain.graph_store.links import Link
+from langchain_core.graph_vectorstores import Link
 
 from .link_extractor import LinkExtractor
 from .link_extractor_adapter import LinkExtractorAdapter

libs/langchain/ragstack_langchain/graph_store/extractors/keybert_link_extractor.py

Lines changed: 1 addition & 1 deletion
@@ -1,9 +1,9 @@
 from typing import Any, Dict, Iterable, Optional, Set, Union
 
 from langchain_core.documents import Document
+from langchain_core.graph_vectorstores import Link
 
 from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
-from ragstack_langchain.graph_store.links import Link
 
 # TypeAlias is not available in Python 2.9, we can't use that or the newer `type`.
 KeybertInput = Union[str, Document]

libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from typing import TYPE_CHECKING, Generic, Iterable, Set, TypeVar
 
 if TYPE_CHECKING:
-    from ragstack_langchain.graph_store.links import Link
+    from langchain_core.graph_vectorstores import Link
 
 InputT = TypeVar("InputT")
 

libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_adapter.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 from typing import Callable, Iterable, Set, TypeVar
 
+from langchain_core.graph_vectorstores import Link
+
 from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
-from ragstack_langchain.graph_store.links import Link
 
 InputT = TypeVar("InputT")
 UnderlyingInputT = TypeVar("UnderlyingInputT")

libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -2,9 +2,9 @@
 
 from langchain_core.documents import Document
 from langchain_core.documents.transformers import BaseDocumentTransformer
+from langchain_core.graph_vectorstores.links import add_links
 
 from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
-from ragstack_langchain.graph_store.links import add_links
 
 
 class LinkExtractorTransformer(BaseDocumentTransformer):
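All of these import swaps point at the same upstream types: Link now comes from langchain_core.graph_vectorstores and add_links from langchain_core.graph_vectorstores.links. A short sketch of attaching extracted links to a document with those APIs (the kind and tag values are placeholders, not from this repo):

    from langchain_core.documents import Document
    from langchain_core.graph_vectorstores import Link
    from langchain_core.graph_vectorstores.links import add_links

    doc = Document(page_content="GraphVectorStore docs")

    # Attach one outgoing and one incoming link to the document's metadata;
    # this is what the extractors produce and what LinkExtractorTransformer
    # applies in bulk across a batch of documents.
    add_links(
        doc,
        Link.outgoing(kind="hyperlink", tag="https://example.com/next"),
        Link.incoming(kind="hyperlink", tag="https://example.com/this-page"),
    )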
