
Commit 8366f58

Use langchain GraphVectorStore

1 parent fca5edb · commit 8366f58

File tree

13 files changed: +38, -1100 lines

libs/langchain/ragstack_langchain/graph_store/base.py

Lines changed: 10 additions & 529 deletions
Large diffs are not rendered by default.

Lines changed: 5 additions & 163 deletions

@@ -1,166 +1,8 @@
-from typing import (
-    Any,
-    Iterable,
-    List,
-    Optional,
-    Type,
+from langchain_community.graph_vectorstores import (
+    CassandraGraphVectorStore as CassandraGraphStore,
 )

-from cassandra.cluster import Session
-from langchain_community.utilities.cassandra import SetupMode
-from langchain_core.documents import Document
-from langchain_core.embeddings import Embeddings
-from ragstack_knowledge_store import EmbeddingModel, graph_store
-from typing_extensions import override

-from .base import GraphStore, Node, nodes_to_documents
-
-
-class _EmbeddingModelAdapter(EmbeddingModel):
-    def __init__(self, embeddings: Embeddings):
-        self.embeddings = embeddings
-
-    def embed_texts(self, texts: List[str]) -> List[List[float]]:
-        return self.embeddings.embed_documents(texts)
-
-    def embed_query(self, text: str) -> List[float]:
-        return self.embeddings.embed_query(text)
-
-    async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
-        return await self.embeddings.aembed_documents(texts)
-
-    async def aembed_query(self, text: str) -> List[float]:
-        return await self.embeddings.aembed_query(text)
-
-
-class CassandraGraphStore(GraphStore):
-    def __init__(
-        self,
-        embedding: Embeddings,
-        *,
-        node_table: str = "graph_nodes",
-        session: Optional[Session] = None,
-        keyspace: Optional[str] = None,
-        setup_mode: SetupMode = SetupMode.SYNC,
-    ):
-        """
-        Create the hybrid graph store.
-        Parameters configure the ways that edges should be added between
-        documents. Many take `Union[bool, Set[str]]`, with `False` disabling
-        inference, `True` enabling it globally between all documents, and a set
-        of metadata fields defining a scope in which to enable it. Specifically,
-        passing a set of metadata fields such as `source` only links documents
-        with the same `source` metadata value.
-        Args:
-            embedding: The embeddings to use for the document content.
-            setup_mode: Mode used to create the Cassandra table (SYNC,
-                ASYNC or OFF).
-        """
-        self._embedding = embedding
-        _setup_mode = getattr(graph_store.SetupMode, setup_mode.name)
-
-        self.store = graph_store.GraphStore(
-            embedding=_EmbeddingModelAdapter(embedding),
-            node_table=node_table,
-            session=session,
-            keyspace=keyspace,
-            setup_mode=_setup_mode,
-        )
-
-    @property
-    @override
-    def embeddings(self) -> Optional[Embeddings]:
-        return self._embedding
-
-    @override
-    def add_nodes(
-        self,
-        nodes: Iterable[Node],
-        **kwargs: Any,
-    ) -> Iterable[str]:
-        _nodes = [
-            graph_store.Node(
-                id=node.id, text=node.text, metadata=node.metadata, links=node.links
-            )
-            for node in nodes
-        ]
-        return self.store.add_nodes(_nodes)
-
-    @classmethod
-    @override
-    def from_texts(
-        cls: Type["CassandraGraphStore"],
-        texts: Iterable[str],
-        embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
-        ids: Optional[Iterable[str]] = None,
-        **kwargs: Any,
-    ) -> "CassandraGraphStore":
-        """Return CassandraGraphStore initialized from texts and embeddings."""
-        store = cls(embedding, **kwargs)
-        store.add_texts(texts, metadatas, ids=ids)
-        return store
-
-    @classmethod
-    @override
-    def from_documents(
-        cls: Type["CassandraGraphStore"],
-        documents: Iterable[Document],
-        embedding: Embeddings,
-        ids: Optional[Iterable[str]] = None,
-        **kwargs: Any,
-    ) -> "CassandraGraphStore":
-        """Return CassandraGraphStore initialized from documents and embeddings."""
-        store = cls(embedding, **kwargs)
-        store.add_documents(documents, ids=ids)
-        return store
-
-    @override
-    def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
-        embedding_vector = self._embedding.embed_query(query)
-        return self.similarity_search_by_vector(embedding_vector, k=k, **kwargs)
-
-    @override
-    def similarity_search_by_vector(
-        self, embedding: List[float], k: int = 4, **kwargs: Any
-    ) -> List[Document]:
-        nodes = self.store.similarity_search(embedding, k=k)
-        return list(nodes_to_documents(nodes))
-
-    @override
-    def traversal_search(
-        self,
-        query: str,
-        *,
-        k: int = 4,
-        depth: int = 1,
-        **kwargs: Any,
-    ) -> Iterable[Document]:
-        nodes = self.store.traversal_search(query, k=k, depth=depth)
-        return nodes_to_documents(nodes)
-
-    @override
-    def mmr_traversal_search(
-        self,
-        query: str,
-        *,
-        k: int = 4,
-        depth: int = 2,
-        fetch_k: int = 100,
-        adjacent_k: int = 10,
-        lambda_mult: float = 0.5,
-        score_threshold: float = float("-inf"),
-        **kwargs: Any,
-    ) -> Iterable[Document]:
-        nodes = self.store.mmr_traversal_search(
-            query,
-            k=k,
-            depth=depth,
-            fetch_k=fetch_k,
-            adjacent_k=adjacent_k,
-            lambda_mult=lambda_mult,
-            score_threshold=score_threshold,
-        )
-        return nodes_to_documents(nodes)
+
+__all__ = [
+    "CassandraGraphStore",
+]
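
Note: the net effect of this file is that CassandraGraphStore becomes a thin alias for langchain_community's CassandraGraphVectorStore, so existing call sites can keep their import. A minimal usage sketch under assumptions not shown in this diff (that the alias stays importable from ragstack_langchain.graph_store, that cassio provides the session, and that FakeEmbeddings stands in for a real embedding model):

import cassio
from langchain_community.embeddings import FakeEmbeddings

# Hypothetical import path; the package is assumed to re-export the alias.
from ragstack_langchain.graph_store import CassandraGraphStore

cassio.init(auto=True)  # assumes Cassandra/Astra credentials in the environment

store = CassandraGraphStore(embedding=FakeEmbeddings(size=1536))
store.add_texts(["GraphVectorStore links documents into a graph."])
for doc in store.traversal_search("GraphVectorStore", k=4, depth=1):
    print(doc.page_content)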

libs/langchain/ragstack_langchain/graph_store/extractors/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 from .hierarchy_link_extractor import HierarchyInput, HierarchyLinkExtractor
 from .html_link_extractor import HtmlInput, HtmlLinkExtractor
 from .keybert_link_extractor import KeybertInput, KeybertLinkExtractor
+from .link_extractor import LinkExtractor
 from .link_extractor_adapter import LinkExtractorAdapter
 from .link_extractor_transformer import LinkExtractorTransformer

libs/langchain/ragstack_langchain/graph_store/extractors/gliner_link_extractor.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 from typing import Any, Dict, Iterable, List, Optional, Set

+from langchain_core.graph_vectorstores import Link
+
 from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
-from ragstack_langchain.graph_store.links import Link

 # TypeAlias is not available in Python 2.9, we can't use that or the newer `type`.
 GLiNERInput = str
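
Note: the same one-line swap repeats in the extractor modules below; Link now comes from langchain_core.graph_vectorstores rather than the local ragstack_langchain.graph_store.links module. A short sketch of the relocated class; the kinds and tags below are made up for illustration:

from langchain_core.graph_vectorstores import Link

# Same constructors as before, only the import path changes.
to_url = Link.outgoing(kind="hyperlink", tag="https://example.com")
from_url = Link.incoming(kind="hyperlink", tag="https://example.com")
keyword = Link.bidir(kind="kw", tag="graph-rag")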

libs/langchain/ragstack_langchain/graph_store/extractors/hierarchy_link_extractor.py

Lines changed: 1 addition & 2 deletions
@@ -1,8 +1,7 @@
 from typing import Callable, List, Set

 from langchain_core.documents import Document
-
-from ragstack_langchain.graph_store.links import Link
+from langchain_core.graph_vectorstores import Link

 from .link_extractor import LinkExtractor
 from .link_extractor_adapter import LinkExtractorAdapter

libs/langchain/ragstack_langchain/graph_store/extractors/html_link_extractor.py

Lines changed: 1 addition & 2 deletions
@@ -3,8 +3,7 @@
 from urllib.parse import urldefrag, urljoin, urlparse

 from langchain_core.documents import Document
-
-from ragstack_langchain.graph_store.links import Link
+from langchain_core.graph_vectorstores import Link

 from .link_extractor import LinkExtractor
 from .link_extractor_adapter import LinkExtractorAdapter

libs/langchain/ragstack_langchain/graph_store/extractors/keybert_link_extractor.py

Lines changed: 1 addition & 1 deletion
@@ -1,9 +1,9 @@
 from typing import Any, Dict, Iterable, Optional, Set, Union

 from langchain_core.documents import Document
+from langchain_core.graph_vectorstores import Link

 from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
-from ragstack_langchain.graph_store.links import Link

 # TypeAlias is not available in Python 2.9, we can't use that or the newer `type`.
 KeybertInput = Union[str, Document]

libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from typing import TYPE_CHECKING, Generic, Iterable, Set, TypeVar

 if TYPE_CHECKING:
-    from ragstack_langchain.graph_store.links import Link
+    from langchain_core.graph_vectorstores import Link

 InputT = TypeVar("InputT")

libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_adapter.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 from typing import Callable, Iterable, Set, TypeVar

+from langchain_core.graph_vectorstores import Link
+
 from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
-from ragstack_langchain.graph_store.links import Link

 InputT = TypeVar("InputT")
 UnderlyingInputT = TypeVar("UnderlyingInputT")

libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -2,9 +2,9 @@

 from langchain_core.documents import Document
 from langchain_core.documents.transformers import BaseDocumentTransformer
+from langchain_core.graph_vectorstores.links import add_links

 from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
-from ragstack_langchain.graph_store.links import add_links


 class LinkExtractorTransformer(BaseDocumentTransformer):
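
Note: add_links moves the same way, to langchain_core.graph_vectorstores.links. A minimal sketch of attaching a link to a document with the relocated helpers; the document content and link values are illustrative:

from langchain_core.documents import Document
from langchain_core.graph_vectorstores import Link
from langchain_core.graph_vectorstores.links import add_links, get_links

doc = Document(page_content="See https://example.com for details.")

# add_links records the link in the document's metadata so a graph vector
# store can materialize the edge at indexing time.
add_links(doc, Link.outgoing(kind="hyperlink", tag="https://example.com"))
print(get_links(doc))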
