diff --git a/libs/e2e-tests/pyproject.llamaindex.toml b/libs/e2e-tests/pyproject.llamaindex.toml index 4a0c9908e..939dd1422 100644 --- a/libs/e2e-tests/pyproject.llamaindex.toml +++ b/libs/e2e-tests/pyproject.llamaindex.toml @@ -42,9 +42,9 @@ llama-index-multi-modal-llms-gemini = { git = "https://github.com/run-llama/llam llama-parse = { git = "https://github.com/run-llama/llama_parse.git", branch = "main" } -langchain = "0.2.5" -langchain-core = "0.2.9" -langchain-community = "0.2.5" +langchain = "0.2.10" +langchain-core = "0.2.22" +langchain-community = "0.2.9" langchain-astradb = "0.3.3" langchain-openai = "0.1.8" langchain-google-genai = { version = "1.0.6" } diff --git a/libs/knowledge-store/notebooks/astra_support.ipynb b/libs/knowledge-store/notebooks/astra_support.ipynb index 2f6d11a33..cd3e1ff0c 100644 --- a/libs/knowledge-store/notebooks/astra_support.ipynb +++ b/libs/knowledge-store/notebooks/astra_support.ipynb @@ -14,7 +14,7 @@ "outputs": [], "source": [ "%pip install -q \\\n", - " ragstack-ai-langchain[knowledge-store] \\\n", + " ragstack-ai-langchain[knowledge-store]==1.3.0 \\\n", " beautifulsoup4 markdownify python-dotenv" ] }, diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index bc35a8ee8..92fa39cc9 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -18,9 +18,9 @@ ragstack-ai-colbert = { version = "1.0.5", optional = true } ragstack-ai-knowledge-store = { version = "0.1.0", optional = true } # langchain -langchain = "0.2.5" -langchain-core = "0.2.9" -langchain-community = "0.2.5" +langchain = "0.2.10" +langchain-core = "0.2.22" +langchain-community = "0.2.9" langchain-astradb = "0.3.3" langchain-openai = "0.1.8" langchain-google-genai = { version = "1.0.6", optional = true } diff --git a/libs/langchain/ragstack_langchain/graph_store/__init__.py b/libs/langchain/ragstack_langchain/graph_store/__init__.py index cbf307c94..8b1378917 100644 --- a/libs/langchain/ragstack_langchain/graph_store/__init__.py +++ b/libs/langchain/ragstack_langchain/graph_store/__init__.py @@ -1,4 +1 @@ -from .base import GraphStore, Node -from .cassandra import CassandraGraphStore -__all__ = ["CassandraGraphStore", "GraphStore", "Node"] diff --git a/libs/langchain/ragstack_langchain/graph_store/base.py b/libs/langchain/ragstack_langchain/graph_store/base.py deleted file mode 100644 index b427812cd..000000000 --- a/libs/langchain/ragstack_langchain/graph_store/base.py +++ /dev/null @@ -1,531 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import ( - TYPE_CHECKING, - Any, - AsyncIterable, - ClassVar, - Collection, - Iterable, - Iterator, - List, - Optional, - Set, -) - -from langchain_core.documents import Document -from langchain_core.load import Serializable -from langchain_core.pydantic_v1 import Field -from langchain_core.runnables import run_in_executor -from langchain_core.vectorstores import VectorStore, VectorStoreRetriever - -from ragstack_langchain.graph_store.links import METADATA_LINKS_KEY, Link - -if TYPE_CHECKING: - from langchain_core.callbacks import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, - ) - - -def _has_next(iterator: Iterator) -> bool: - """Checks if the iterator has more elements. 
- Warning: consumes an element from the iterator""" - sentinel = object() - return next(iterator, sentinel) is not sentinel - - -METADATA_CONTENT_ID_KEY = "content_id" - - -class Node(Serializable): - """Node in the GraphStore.""" - - id: Optional[str] = None - """Unique ID for the node. Will be generated by the GraphStore if not set.""" - text: str - """Text contained by the node.""" - metadata: dict = Field(default_factory=dict) - """Metadata for the node.""" - links: Set[Link] = Field(default_factory=set) - """Links associated with the node.""" - - -def _texts_to_nodes( - texts: Iterable[str], - metadatas: Optional[Iterable[dict]], - ids: Optional[Iterable[str]], -) -> Iterator[Node]: - metadatas_it = iter(metadatas) if metadatas else None - ids_it = iter(ids) if ids else None - for text in texts: - try: - _metadata = next(metadatas_it).copy() if metadatas_it else {} - except StopIteration as e: - raise ValueError("texts iterable longer than metadatas") from e - try: - _id = next(ids_it) if ids_it else None - _id = _id or _metadata.pop(METADATA_CONTENT_ID_KEY, None) - except StopIteration as e: - raise ValueError("texts iterable longer than ids") from e - - links = _metadata.pop(METADATA_LINKS_KEY, set()) - if not isinstance(links, Set): - links = set(links) - yield Node( - id=_id, - metadata=_metadata, - text=text, - links=links, - ) - if ids_it and _has_next(ids_it): - raise ValueError("ids iterable longer than texts") - if metadatas_it and _has_next(metadatas_it): - raise ValueError("metadatas iterable longer than texts") - - -def _documents_to_nodes( - documents: Iterable[Document], ids: Optional[Iterable[str]] -) -> Iterator[Node]: - ids_it = iter(ids) if ids else None - for doc in documents: - try: - _id = next(ids_it) if ids_it else None - _id = _id or doc.metadata.pop(METADATA_CONTENT_ID_KEY, None) - except StopIteration as e: - raise ValueError("documents iterable longer than ids") from e - metadata = doc.metadata.copy() - links = metadata.pop(METADATA_LINKS_KEY, set()) - if not isinstance(links, Set): - links = set(links) - yield Node( - id=_id, - metadata=metadata, - text=doc.page_content, - links=links, - ) - if ids_it and _has_next(ids_it): - raise ValueError("ids iterable longer than documents") - - -def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]: - for node in nodes: - metadata = node.metadata.copy() - metadata[METADATA_CONTENT_ID_KEY] = node.id - metadata[METADATA_LINKS_KEY] = { - # Convert the core `Link` (from the node) back to the local `Link`. - Link(kind=link.kind, direction=link.direction, tag=link.tag) - for link in node.links - } - - yield Document( - page_content=node.text, - metadata=metadata, - ) - - -class GraphStore(VectorStore): - """A hybrid vector-and-graph graph store. - - Document chunks support vector-similarity search as well as edges linking - chunks based on structural and semantic properties. - """ - - @abstractmethod - def add_nodes( - self, - nodes: Iterable[Node], - **kwargs: Any, - ) -> Iterable[str]: - """Add nodes to the graph store. - - Args: - nodes: the nodes to add. - """ - - async def aadd_nodes( - self, - nodes: Iterable[Node], - **kwargs: Any, - ) -> AsyncIterable[str]: - """Add nodes to the graph store. - - Args: - nodes: the nodes to add. 
- """ - iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs)) - done = object() - while True: - doc = await run_in_executor(None, next, iterator, done) - if doc is done: - break - yield doc # type: ignore[misc] - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[Iterable[dict]] = None, - *, - ids: Optional[Iterable[str]] = None, - **kwargs: Any, - ) -> List[str]: - nodes = _texts_to_nodes(texts, metadatas, ids) - return list(self.add_nodes(nodes, **kwargs)) - - async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[Iterable[dict]] = None, - *, - ids: Optional[Iterable[str]] = None, - **kwargs: Any, - ) -> List[str]: - nodes = _texts_to_nodes(texts, metadatas, ids) - return [_id async for _id in self.aadd_nodes(nodes, **kwargs)] - - def add_documents( - self, - documents: Iterable[Document], - *, - ids: Optional[Iterable[str]] = None, - **kwargs: Any, - ) -> List[str]: - nodes = _documents_to_nodes(documents, ids) - return list(self.add_nodes(nodes, **kwargs)) - - async def aadd_documents( - self, - documents: Iterable[Document], - *, - ids: Optional[Iterable[str]] = None, - **kwargs: Any, - ) -> List[str]: - nodes = _documents_to_nodes(documents, ids) - return [_id async for _id in self.aadd_nodes(nodes, **kwargs)] - - @abstractmethod - def traversal_search( - self, - query: str, - *, - k: int = 4, - depth: int = 1, - **kwargs: Any, - ) -> Iterable[Document]: - """Retrieve documents from traversing this graph store. - - First, `k` nodes are retrieved using a search for each `query` string. - Then, additional nodes are discovered up to the given `depth` from those - starting nodes. - - Args: - query: The query string. - k: The number of Documents to return from the initial search. - Defaults to 4. Applies to each of the query strings. - depth: The maximum depth of edges to traverse. Defaults to 1. - Returns: - Retrieved documents. - """ - - async def atraversal_search( - self, - query: str, - *, - k: int = 4, - depth: int = 1, - **kwargs: Any, - ) -> AsyncIterable[Document]: - """Retrieve documents from traversing this graph store. - - First, `k` nodes are retrieved using a search for each `query` string. - Then, additional nodes are discovered up to the given `depth` from those - starting nodes. - - Args: - query: The query string. - k: The number of Documents to return from the initial search. - Defaults to 4. Applies to each of the query strings. - depth: The maximum depth of edges to traverse. Defaults to 1. - Returns: - Retrieved documents. - """ - iterator = iter( - await run_in_executor( - None, self.traversal_search, query, k=k, depth=depth, **kwargs - ) - ) - done = object() - while True: - doc = await run_in_executor(None, next, iterator, done) - if doc is done: - break - yield doc # type: ignore[misc] - - @abstractmethod - def mmr_traversal_search( - self, - query: str, - *, - k: int = 4, - depth: int = 2, - fetch_k: int = 100, - adjacent_k: int = 10, - lambda_mult: float = 0.5, - score_threshold: float = float("-inf"), - **kwargs: Any, - ) -> Iterable[Document]: - """Retrieve documents from this graph store using MMR-traversal. - - This strategy first retrieves the top `fetch_k` results by similarity to - the question. It then selects the top `k` results based on - maximum-marginal relevance using the given `lambda_mult`. - - At each step, it considers the (remaining) documents from `fetch_k` as - well as any documents connected by edges to a selected document - retrieved based on similarity (a "root"). 
- - Args: - query: The query string to search for. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch via similarity. - Defaults to 100. - adjacent_k: Number of adjacent Documents to fetch. - Defaults to 10. - depth: Maximum depth of a node (number of edges) from a node - retrieved via similarity. Defaults to 2. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding to maximum - diversity and 1 to minimum diversity. Defaults to 0.5. - score_threshold: Only documents with a score greater than or equal - this threshold will be chosen. Defaults to negative infinity. - """ - - async def ammr_traversal_search( - self, - query: str, - *, - k: int = 4, - depth: int = 2, - fetch_k: int = 100, - adjacent_k: int = 10, - lambda_mult: float = 0.5, - score_threshold: float = float("-inf"), - **kwargs: Any, - ) -> AsyncIterable[Document]: - """Retrieve documents from this graph store using MMR-traversal. - - This strategy first retrieves the top `fetch_k` results by similarity to - the question. It then selects the top `k` results based on - maximum-marginal relevance using the given `lambda_mult`. - - At each step, it considers the (remaining) documents from `fetch_k` as - well as any documents connected by edges to a selected document - retrieved based on similarity (a "root"). - - Args: - query: The query string to search for. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch via similarity. - Defaults to 100. - adjacent_k: Number of adjacent Documents to fetch. - Defaults to 10. - depth: Maximum depth of a node (number of edges) from a node - retrieved via similarity. Defaults to 2. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding to maximum - diversity and 1 to minimum diversity. Defaults to 0.5. - score_threshold: Only documents with a score greater than or equal - this threshold will be chosen. Defaults to negative infinity. 
- """ - iterator = iter( - await run_in_executor( - None, - self.mmr_traversal_search, - query, - k=k, - fetch_k=fetch_k, - adjacent_k=adjacent_k, - depth=depth, - lambda_mult=lambda_mult, - score_threshold=score_threshold, - **kwargs, - ) - ) - done = object() - while True: - doc = await run_in_executor(None, next, iterator, done) - if doc is done: - break - yield doc # type: ignore[misc] - - def similarity_search( - self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: - kwargs.pop("depth") - return list(self.traversal_search(query, k=k, depth=0, **kwargs)) - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> List[Document]: - kwargs.pop("depth") - return list( - self.mmr_traversal_search( - query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0, **kwargs - ) - ) - - async def asimilarity_search( - self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: - return [ - doc async for doc in self.atraversal_search(query, k=k, depth=0, **kwargs) - ] - - def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]: - if search_type == "similarity": - return self.similarity_search(query, **kwargs) - if search_type == "similarity_score_threshold": - docs_and_similarities = self.similarity_search_with_relevance_scores( - query, **kwargs - ) - return [doc for doc, _ in docs_and_similarities] - if search_type == "mmr": - return self.max_marginal_relevance_search(query, **kwargs) - if search_type == "traversal": - return list(self.traversal_search(query, **kwargs)) - if search_type == "mmr_traversal": - return list(self.mmr_traversal_search(query, **kwargs)) - raise ValueError( - f"search_type of {search_type} not allowed. Expected " - "search_type to be 'similarity', 'similarity_score_threshold', " - "'mmr' or 'traversal'." - ) - - async def asearch( - self, query: str, search_type: str, **kwargs: Any - ) -> List[Document]: - if search_type == "similarity": - return await self.asimilarity_search(query, **kwargs) - if search_type == "similarity_score_threshold": - docs_and_similarities = await self.asimilarity_search_with_relevance_scores( - query, **kwargs - ) - return [doc for doc, _ in docs_and_similarities] - if search_type == "mmr": - return await self.amax_marginal_relevance_search(query, **kwargs) - if search_type == "traversal": - return [doc async for doc in self.atraversal_search(query, **kwargs)] - raise ValueError( - f"search_type of {search_type} not allowed. Expected " - "search_type to be 'similarity', 'similarity_score_threshold', " - "'mmr' or 'traversal'." - ) - - def as_retriever(self, **kwargs: Any) -> GraphStoreRetriever: - """Return GraphStoreRetriever initialized from this GraphStore. - - Args: - search_type (Optional[str]): Defines the type of search that - the Retriever should perform. - Can be "traversal" (default), "similarity", "mmr", or - "similarity_score_threshold". - search_kwargs (Optional[Dict]): Keyword arguments to pass to the - search function. Can include things like: - k: Amount of documents to return (Default: 4) - depth: The maximum depth of edges to traverse (Default: 1) - score_threshold: Minimum relevance threshold - for similarity_score_threshold - fetch_k: Amount of documents to pass to MMR algorithm (Default: 20) - lambda_mult: Diversity of results returned by MMR; - 1 for minimum diversity and 0 for maximum. (Default: 0.5) - Returns: - Retriever for this GraphStore. - - Examples: - - .. 
code-block:: python - - # Retrieve documents traversing edges - docsearch.as_retriever( - search_type="traversal", - search_kwargs={'k': 6, 'depth': 3} - ) - - # Retrieve more documents with higher diversity - # Useful if your dataset has many similar documents - docsearch.as_retriever( - search_type="mmr", - search_kwargs={'k': 6, 'lambda_mult': 0.25} - ) - - # Fetch more documents for the MMR algorithm to consider - # But only return the top 5 - docsearch.as_retriever( - search_type="mmr", - search_kwargs={'k': 5, 'fetch_k': 50} - ) - - # Only retrieve documents that have a relevance score - # Above a certain threshold - docsearch.as_retriever( - search_type="similarity_score_threshold", - search_kwargs={'score_threshold': 0.8} - ) - - # Only get the single most similar document from the dataset - docsearch.as_retriever(search_kwargs={'k': 1}) - - """ - return GraphStoreRetriever(vectorstore=self, **kwargs) - - -class GraphStoreRetriever(VectorStoreRetriever): - """Retriever class for GraphStore.""" - - vectorstore: GraphStore - """GraphStore to use for retrieval.""" - search_type: str = "traversal" - """Type of search to perform. Defaults to "traversal".""" - allowed_search_types: ClassVar[Collection[str]] = ( - "similarity", - "similarity_score_threshold", - "mmr", - "traversal", - "mmr_traversal", - ) - - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: - if self.search_type == "traversal": - return list(self.vectorstore.traversal_search(query, **self.search_kwargs)) - if self.search_type == "mmr_traversal": - return list( - self.vectorstore.mmr_traversal_search(query, **self.search_kwargs) - ) - return super()._get_relevant_documents(query, run_manager=run_manager) - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - if self.search_type == "traversal": - return [ - doc - async for doc in self.vectorstore.atraversal_search( - query, **self.search_kwargs - ) - ] - if self.search_type == "mmr_traversal": - return [ - doc - async for doc in self.vectorstore.ammr_traversal_search( - query, **self.search_kwargs - ) - ] - return await super()._aget_relevant_documents(query, run_manager=run_manager) diff --git a/libs/langchain/ragstack_langchain/graph_store/cassandra.py b/libs/langchain/ragstack_langchain/graph_store/cassandra.py deleted file mode 100644 index 300a46f01..000000000 --- a/libs/langchain/ragstack_langchain/graph_store/cassandra.py +++ /dev/null @@ -1,166 +0,0 @@ -from typing import ( - Any, - Iterable, - List, - Optional, - Type, -) - -from cassandra.cluster import Session -from langchain_community.utilities.cassandra import SetupMode -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from ragstack_knowledge_store import EmbeddingModel, graph_store -from typing_extensions import override - -from .base import GraphStore, Node, nodes_to_documents - - -class _EmbeddingModelAdapter(EmbeddingModel): - def __init__(self, embeddings: Embeddings): - self.embeddings = embeddings - - def embed_texts(self, texts: List[str]) -> List[List[float]]: - return self.embeddings.embed_documents(texts) - - def embed_query(self, text: str) -> List[float]: - return self.embeddings.embed_query(text) - - async def aembed_texts(self, texts: List[str]) -> List[List[float]]: - return await self.embeddings.aembed_documents(texts) - - async def aembed_query(self, text: str) -> List[float]: - return await 
self.embeddings.aembed_query(text) - - -class CassandraGraphStore(GraphStore): - def __init__( - self, - embedding: Embeddings, - *, - node_table: str = "graph_nodes", - session: Optional[Session] = None, - keyspace: Optional[str] = None, - setup_mode: SetupMode = SetupMode.SYNC, - ): - """ - Create the hybrid graph store. - Parameters configure the ways that edges should be added between - documents. Many take `Union[bool, Set[str]]`, with `False` disabling - inference, `True` enabling it globally between all documents, and a set - of metadata fields defining a scope in which to enable it. Specifically, - passing a set of metadata fields such as `source` only links documents - with the same `source` metadata value. - Args: - embedding: The embeddings to use for the document content. - setup_mode: Mode used to create the Cassandra table (SYNC, - ASYNC or OFF). - """ - self._embedding = embedding - _setup_mode = getattr(graph_store.SetupMode, setup_mode.name) - - self.store = graph_store.GraphStore( - embedding=_EmbeddingModelAdapter(embedding), - node_table=node_table, - session=session, - keyspace=keyspace, - setup_mode=_setup_mode, - ) - - @property - @override - def embeddings(self) -> Optional[Embeddings]: - return self._embedding - - @override - def add_nodes( - self, - nodes: Iterable[Node], - **kwargs: Any, - ) -> Iterable[str]: - _nodes = [ - graph_store.Node( - id=node.id, text=node.text, metadata=node.metadata, links=node.links - ) - for node in nodes - ] - return self.store.add_nodes(_nodes) - - @classmethod - @override - def from_texts( - cls: Type["CassandraGraphStore"], - texts: Iterable[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[Iterable[str]] = None, - **kwargs: Any, - ) -> "CassandraGraphStore": - """Return CassandraGraphStore initialized from texts and embeddings.""" - store = cls(embedding, **kwargs) - store.add_texts(texts, metadatas, ids=ids) - return store - - @classmethod - @override - def from_documents( - cls: Type["CassandraGraphStore"], - documents: Iterable[Document], - embedding: Embeddings, - ids: Optional[Iterable[str]] = None, - **kwargs: Any, - ) -> "CassandraGraphStore": - """Return CassandraGraphStore initialized from documents and embeddings.""" - store = cls(embedding, **kwargs) - store.add_documents(documents, ids=ids) - return store - - @override - def similarity_search( - self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: - embedding_vector = self._embedding.embed_query(query) - return self.similarity_search_by_vector(embedding_vector, k=k, **kwargs) - - @override - def similarity_search_by_vector( - self, embedding: List[float], k: int = 4, **kwargs: Any - ) -> List[Document]: - nodes = self.store.similarity_search(embedding, k=k) - return list(nodes_to_documents(nodes)) - - @override - def traversal_search( - self, - query: str, - *, - k: int = 4, - depth: int = 1, - **kwargs: Any, - ) -> Iterable[Document]: - nodes = self.store.traversal_search(query, k=k, depth=depth) - return nodes_to_documents(nodes) - - @override - def mmr_traversal_search( - self, - query: str, - *, - k: int = 4, - depth: int = 2, - fetch_k: int = 100, - adjacent_k: int = 10, - lambda_mult: float = 0.5, - score_threshold: float = float("-inf"), - **kwargs: Any, - ) -> Iterable[Document]: - nodes = self.store.mmr_traversal_search( - query, - k=k, - depth=depth, - fetch_k=fetch_k, - adjacent_k=adjacent_k, - lambda_mult=lambda_mult, - score_threshold=score_threshold, - ) - return nodes_to_documents(nodes) diff --git 
a/libs/langchain/ragstack_langchain/graph_store/extractors/__init__.py b/libs/langchain/ragstack_langchain/graph_store/extractors/__init__.py index e2cdbd3c0..6a059d886 100644 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/__init__.py +++ b/libs/langchain/ragstack_langchain/graph_store/extractors/__init__.py @@ -1,20 +1,5 @@ -from .gliner_link_extractor import GLiNERInput, GLiNERLinkExtractor -from .hierarchy_link_extractor import HierarchyInput, HierarchyLinkExtractor -from .html_link_extractor import HtmlInput, HtmlLinkExtractor -from .keybert_link_extractor import KeybertInput, KeybertLinkExtractor -from .link_extractor_adapter import LinkExtractorAdapter from .link_extractor_transformer import LinkExtractorTransformer __all__ = [ - "LinkExtractor", - "GLiNERInput", - "GLiNERLinkExtractor", - "HierarchyInput", - "HierarchyLinkExtractor", - "HtmlInput", - "HtmlLinkExtractor", - "KeybertInput", - "KeybertLinkExtractor", - "LinkExtractorAdapter", "LinkExtractorTransformer", ] diff --git a/libs/langchain/ragstack_langchain/graph_store/extractors/gliner_link_extractor.py b/libs/langchain/ragstack_langchain/graph_store/extractors/gliner_link_extractor.py deleted file mode 100644 index 98ca9445f..000000000 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/gliner_link_extractor.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Any, Dict, Iterable, List, Optional, Set - -from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor -from ragstack_langchain.graph_store.links import Link - -# TypeAlias is not available in Python 2.9, we can't use that or the newer `type`. -GLiNERInput = str - - -class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]): - def __init__( - self, - labels: List[str], - *, - kind: str = "entity", - model: str = "urchade/gliner_mediumv2.1", - extract_kwargs: Optional[Dict[str, Any]] = None, - ): - """Extract keywords using GLiNER. - - Args: - kind: Kind of links to produce with this extractor. - labels: List of kinds of entities to extract. - model: GLiNER model to use. - extract_kwargs: Keyword arguments to pass to GLiNER. - """ - try: - from gliner import GLiNER - - self._model = GLiNER.from_pretrained(model) - - except ImportError: - raise ImportError( - "gliner is required for GLiNERLinkExtractor. " - "Please install it with `pip install gliner`." 
- ) from None - - self._labels = labels - self._kind = kind - self._extract_kwargs = extract_kwargs or {} - - def extract_one(self, input: GLiNERInput) -> Set[Link]: # noqa: A002 - return next(self.extract_many([input])) - - def extract_many( - self, - inputs: Iterable[GLiNERInput], - ) -> Iterable[Set[Link]]: - strs = [i if isinstance(i, str) else i.page_content for i in inputs] - for entities in self._model.batch_predict_entities( - strs, self._labels, **self._extract_kwargs - ): - yield { - Link.bidir(kind=f"{self._kind}:{e['label']}", tag=e["text"]) - for e in entities - } diff --git a/libs/langchain/ragstack_langchain/graph_store/extractors/hierarchy_link_extractor.py b/libs/langchain/ragstack_langchain/graph_store/extractors/hierarchy_link_extractor.py deleted file mode 100644 index 90b21bfa7..000000000 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/hierarchy_link_extractor.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Callable, List, Set - -from langchain_core.documents import Document - -from ragstack_langchain.graph_store.links import Link - -from .link_extractor import LinkExtractor -from .link_extractor_adapter import LinkExtractorAdapter - -# TypeAlias is not available in Python 2.9, we can't use that or the newer `type`. -HierarchyInput = List[str] - - -class HierarchyLinkExtractor(LinkExtractor[HierarchyInput]): - def __init__( - self, - kind: str = "hierarchy", - up_links: bool = True, - down_links: bool = False, - sibling_links: bool = False, - ): - """Extract links from a document hierarchy. - - Args: - kind: Kind of links to produce with this extractor. - up_links: Link from a section to it's parent. - down_links: Link from a section to it's children. - sibling_links: Link from a section to other sections with the same parent. 
- """ - self._kind = kind - self._up_links = up_links - self._down_links = down_links - self._sibling_links = sibling_links - - def as_document_extractor( - self, hierarchy: Callable[[Document], HierarchyInput] - ) -> LinkExtractor[Document]: - return LinkExtractorAdapter(underlying=self, transform=hierarchy) - - def extract_one( - self, - input: HierarchyInput, # noqa: A002 - ) -> Set[Link]: - this_path = "/".join(input) - parent_path = None - - links = set() - if self._up_links: - links.add(Link.incoming(kind=self._kind, tag=f"up:{this_path}")) - if self._down_links: - links.add(Link.outgoing(kind=self._kind, tag=f"down:{this_path}")) - - if len(input) >= 1: - parent_path = "/".join(input[0:-1]) - if self._up_links and len(input) > 1: - links.add(Link.outgoing(kind=self._kind, tag=f"up:{parent_path}")) - if self._down_links and len(input) > 1: - links.add(Link.incoming(kind=self._kind, tag=f"down:{parent_path}")) - if self._sibling_links: - links.add(Link.bidir(kind=self._kind, tag=f"sib:{parent_path}")) - - return links diff --git a/libs/langchain/ragstack_langchain/graph_store/extractors/html_link_extractor.py b/libs/langchain/ragstack_langchain/graph_store/extractors/html_link_extractor.py deleted file mode 100644 index bca85c1c1..000000000 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/html_link_extractor.py +++ /dev/null @@ -1,120 +0,0 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Set, Union -from urllib.parse import urldefrag, urljoin, urlparse - -from langchain_core.documents import Document - -from ragstack_langchain.graph_store.links import Link - -from .link_extractor import LinkExtractor -from .link_extractor_adapter import LinkExtractorAdapter - -if TYPE_CHECKING: - from bs4 import BeautifulSoup - - -def _parse_url(link, page_url, drop_fragments: bool = True): - href = link.get("href") - if href is None: - return None - url = urlparse(href) - if url.scheme not in ["http", "https", ""]: - return None - - # Join the HREF with the page_url to convert relative paths to absolute. - url = urljoin(page_url, href) - - # Fragments would be useful if we chunked a page based on section. - # Then, each chunk would have a different URL based on the fragment. - # Since we aren't doing that yet, they just "break" links. So, drop - # the fragment. - if drop_fragments: - return urldefrag(url).url - return url - - -def _parse_hrefs( - soup: "BeautifulSoup", url: str, drop_fragments: bool = True -) -> Set[str]: - links = soup.find_all("a") - links = { - _parse_url(link, page_url=url, drop_fragments=drop_fragments) for link in links - } - - # Remove entries for any 'a' tag that failed to parse (didn't have href, - # or invalid domain, etc.) - links.discard(None) - - # Remove self links. - links.discard(url) - - return links - - -@dataclass -class HtmlInput: - content: Union[str, "BeautifulSoup"] - base_url: str - - -class HtmlLinkExtractor(LinkExtractor[HtmlInput]): - def __init__(self, *, kind: str = "hyperlink", drop_fragments: bool = True): - """Extract hyperlinks from HTML content. - - Expects the input to be an HTML string or a `BeautifulSoup` object. - - Args: - kind: The kind of edge to extract. Defaults to "hyperlink". - drop_fragments: Whether fragments in URLs and links shoud be - dropped. Defaults to `True`. - """ - try: - import bs4 # noqa:F401 - except ImportError as e: - raise ImportError( - "BeautifulSoup4 is required for HtmlLinkExtractor. " - "Please install it with `pip install beautifulsoup4`." 
- ) from e - - self._kind = kind - self.drop_fragments = drop_fragments - - def as_document_extractor( - self, url_metadata_key: str = "source" - ) -> LinkExtractor[Document]: - """Return a LinkExtractor that applies to documents. - - NOTE: Since the HtmlLinkExtractor parses HTML, if you use with other similar - link extractors it may be more efficient to call the link extractors directly - on the parsed BeautifulSoup object. - - Args: - url_metadata_key: The name of the filed in document metadata with the URL of - the document. - """ - return LinkExtractorAdapter( - underlying=self, - transform=lambda doc: HtmlInput( - doc.page_content, doc.metadata[url_metadata_key] - ), - ) - - def extract_one( - self, - input: HtmlInput, # noqa: A002 - ) -> Set[Link]: - content = input.content - if isinstance(content, str): - from bs4 import BeautifulSoup - - content = BeautifulSoup(content, "html.parser") - - base_url = input.base_url - if self.drop_fragments: - base_url = urldefrag(base_url).url - - hrefs = _parse_hrefs(content, base_url, self.drop_fragments) - - links = {Link.outgoing(kind=self._kind, tag=url) for url in hrefs} - links.add(Link.incoming(kind=self._kind, tag=base_url)) - return links diff --git a/libs/langchain/ragstack_langchain/graph_store/extractors/keybert_link_extractor.py b/libs/langchain/ragstack_langchain/graph_store/extractors/keybert_link_extractor.py deleted file mode 100644 index 9656a800e..000000000 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/keybert_link_extractor.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Any, Dict, Iterable, Optional, Set, Union - -from langchain_core.documents import Document - -from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor -from ragstack_langchain.graph_store.links import Link - -# TypeAlias is not available in Python 2.9, we can't use that or the newer `type`. -KeybertInput = Union[str, Document] - - -class KeybertLinkExtractor(LinkExtractor[KeybertInput]): - def __init__( - self, - *, - kind: str = "kw", - embedding_model: str = "all-MiniLM-L6-v2", - extract_keywords_kwargs: Optional[Dict[str, Any]] = None, - ): - """Extract keywords using Keybert. - - Args: - kind: Kind of links to produce with this extractor. - embedding_model: Name of the embedding model to use with Keybert. - extract_keywords_kwargs: Keyword arguments to pass to Keybert's - `extract_keywords` method. - """ - try: - import keybert - - self._kw_model = keybert.KeyBERT(model=embedding_model) - except ImportError: - raise ImportError( - "keybert is required for KeybertLinkExtractor. " - "Please install it with `pip install keybert`." - ) from None - - self._kind = kind - self._extract_keywords_kwargs = extract_keywords_kwargs or {} - - def extract_one(self, input: KeybertInput) -> Set[Link]: # noqa: A002 - keywords = self._kw_model.extract_keywords( - input if isinstance(input, str) else input.page_content, - **self._extract_keywords_kwargs, - ) - return {Link.bidir(kind=self._kind, tag=kw[0]) for kw in keywords} - - def extract_many( - self, - inputs: Iterable[KeybertInput], - ) -> Iterable[Set[Link]]: - if len(inputs) == 1: - # Even though we pass a list, if it contains one item, keybert will - # flatten it. This means it's easier to just call the special case - # for one item. 
- yield self.extract_one(inputs[0]) - elif len(inputs) > 1: - strs = [i if isinstance(i, str) else i.page_content for i in inputs] - extracted = self._kw_model.extract_keywords( - strs, **self._extract_keywords_kwargs - ) - for keywords in extracted: - yield {Link.bidir(kind=self._kind, tag=kw[0]) for kw in keywords} diff --git a/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor.py b/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor.py deleted file mode 100644 index 3c7eaf965..000000000 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, Iterable, Set, TypeVar - -if TYPE_CHECKING: - from ragstack_langchain.graph_store.links import Link - -InputT = TypeVar("InputT") - -METADATA_LINKS_KEY = "links" - - -class LinkExtractor(ABC, Generic[InputT]): - """Interface for extracting links (incoming, outgoing, bidirectional).""" - - @abstractmethod - def extract_one(self, input: InputT) -> set[Link]: # noqa: A002 - """Add edges from each `input` to the corresponding documents. - - Args: - input: The input content to extract edges from. - - Returns: - Set of links extracted from the input. - """ - - def extract_many(self, inputs: Iterable[InputT]) -> Iterable[Set[Link]]: - """Add edges from each `input` to the corresponding documents. - - Args: - inputs: The input content to extract edges from. - - Returns: - Iterable over the set of links extracted from the input. - """ - return map(self.extract_one, inputs) diff --git a/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_adapter.py b/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_adapter.py deleted file mode 100644 index b838800b2..000000000 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_adapter.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Callable, Iterable, Set, TypeVar - -from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor -from ragstack_langchain.graph_store.links import Link - -InputT = TypeVar("InputT") -UnderlyingInputT = TypeVar("UnderlyingInputT") - - -class LinkExtractorAdapter(LinkExtractor[InputT]): - def __init__( - self, - underlying: LinkExtractor[UnderlyingInputT], - transform: Callable[[InputT], UnderlyingInputT], - ) -> None: - self._underlying = underlying - self._transform = transform - - def extract_one(self, input: InputT) -> Set[Link]: # noqa: A002 - return self.extract_one(self._transform(input)) - - def extract_many(self, inputs: Iterable[InputT]) -> Iterable[Set[Link]]: - underlying_inputs = map(self._transform, inputs) - return self._underlying.extract_many(underlying_inputs) diff --git a/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_transformer.py b/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_transformer.py index a74ba7122..2fb27e9c7 100644 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_transformer.py +++ b/libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor_transformer.py @@ -1,10 +1,9 @@ from typing import Iterable, Sequence +from langchain_community.graph_vectorstores.extractors import LinkExtractor from langchain_core.documents import Document from langchain_core.documents.transformers import BaseDocumentTransformer - -from ragstack_langchain.graph_store.extractors.link_extractor 
import LinkExtractor -from ragstack_langchain.graph_store.links import add_links +from langchain_core.graph_vectorstores.links import add_links class LinkExtractorTransformer(BaseDocumentTransformer): diff --git a/libs/langchain/ragstack_langchain/graph_store/links.py b/libs/langchain/ragstack_langchain/graph_store/links.py deleted file mode 100644 index 2475bb850..000000000 --- a/libs/langchain/ragstack_langchain/graph_store/links.py +++ /dev/null @@ -1,56 +0,0 @@ -from dataclasses import dataclass -from typing import Iterable, Literal, Set, Union - -from langchain_core.documents import Document - - -@dataclass(frozen=True) -class Link: - kind: str - direction: Literal["in", "out", "bidir"] - tag: str - - @staticmethod - def incoming(kind: str, tag: str) -> "Link": - return Link(kind=kind, direction="in", tag=tag) - - @staticmethod - def outgoing(kind: str, tag: str) -> "Link": - return Link(kind=kind, direction="out", tag=tag) - - @staticmethod - def bidir(kind: str, tag: str) -> "Link": - return Link(kind=kind, direction="bidir", tag=tag) - - -METADATA_LINKS_KEY = "links" - - -def get_links(doc: Document) -> Set[Link]: - """Get the links from a document. - Args: - doc: The document to get the link tags from. - Returns: - The set of link tags from the document. - """ - - links = doc.metadata.setdefault(METADATA_LINKS_KEY, set()) - if not isinstance(links, Set): - # Convert to a set and remember that. - links = set(links) - doc.metadata[METADATA_LINKS_KEY] = links - return links - - -def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None: - """Add links to the given metadata. - Args: - doc: The document to add the links to. - *links: The links to add to the document. - """ - doc_links = get_links(doc) - for link in links: - if isinstance(link, Link): - doc_links.add(link) - else: - doc_links.update(link) diff --git a/libs/langchain/tests/unit_tests/test_gliner_link_extractor.py b/libs/langchain/tests/unit_tests/test_gliner_link_extractor.py index 7789bde99..f3cd52122 100644 --- a/libs/langchain/tests/unit_tests/test_gliner_link_extractor.py +++ b/libs/langchain/tests/unit_tests/test_gliner_link_extractor.py @@ -1,5 +1,5 @@ -from ragstack_langchain.graph_store.extractors import GLiNERLinkExtractor -from ragstack_langchain.graph_store.links import Link +from langchain_community.graph_vectorstores.extractors import GLiNERLinkExtractor +from langchain_core.graph_vectorstores import Link PAGE_1 = """ Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃ'tjɐnu diff --git a/libs/langchain/tests/unit_tests/test_graph_store.py b/libs/langchain/tests/unit_tests/test_graph_store.py deleted file mode 100644 index 8380dec86..000000000 --- a/libs/langchain/tests/unit_tests/test_graph_store.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest -from langchain_core.documents import Document -from ragstack_langchain.graph_store.base import ( - Node, - _documents_to_nodes, - _texts_to_nodes, -) -from ragstack_langchain.graph_store.links import Link - - -def test_texts_to_nodes() -> None: - assert list(_texts_to_nodes(["a", "b"], [{"a": "b"}, {"c": "d"}], ["a", "b"])) == [ - Node(id="a", metadata={"a": "b"}, text="a"), - Node(id="b", metadata={"c": "d"}, text="b"), - ] - assert list(_texts_to_nodes(["a", "b"], None, ["a", "b"])) == [ - Node(id="a", metadata={}, text="a"), - Node(id="b", metadata={}, text="b"), - ] - assert list(_texts_to_nodes(["a", "b"], [{"a": "b"}, {"c": "d"}], None)) == [ - Node(metadata={"a": "b"}, text="a"), - Node(metadata={"c": "d"}, text="b"), - ] - 
assert list( - _texts_to_nodes( - ["a"], - [{"links": {Link.incoming(kind="hyperlink", tag="http://b")}}], - None, - ) - ) == [Node(links={Link.incoming(kind="hyperlink", tag="http://b")}, text="a")] - with pytest.raises(ValueError, match="texts iterable longer than ids"): - list(_texts_to_nodes(["a", "b"], None, ["a"])) - with pytest.raises(ValueError, match="texts iterable longer than metadatas"): - list(_texts_to_nodes(["a", "b"], [{"a": "b"}], None)) - with pytest.raises(ValueError, match="metadatas iterable longer than texts"): - list(_texts_to_nodes(["a"], [{"a": "b"}, {"c": "d"}], None)) - with pytest.raises(ValueError, match="ids iterable longer than texts"): - list(_texts_to_nodes(["a"], None, ["a", "b"])) - - -def test_documents_to_nodes() -> None: - documents = [ - Document( - page_content="a", - metadata={"links": {Link.incoming(kind="hyperlink", tag="http://b")}}, - ), - Document(page_content="b", metadata={"c": "d"}), - ] - assert list(_documents_to_nodes(documents, ["a", "b"])) == [ - Node( - id="a", - metadata={}, - links={Link.incoming(kind="hyperlink", tag="http://b")}, - text="a", - ), - Node(id="b", metadata={"c": "d"}, text="b"), - ] - assert list(_documents_to_nodes(documents, None)) == [ - Node(links={Link.incoming(kind="hyperlink", tag="http://b")}, text="a"), - Node(metadata={"c": "d"}, text="b"), - ] - with pytest.raises(ValueError, match="documents iterable longer than ids"): - list(_documents_to_nodes(documents, ["a"])) - with pytest.raises(ValueError, match="ids iterable longer than documents"): - list(_documents_to_nodes(documents[1:], ["a", "b"])) diff --git a/libs/langchain/tests/unit_tests/test_hierarchy_link_extractor.py b/libs/langchain/tests/unit_tests/test_hierarchy_link_extractor.py deleted file mode 100644 index 2074c4f3a..000000000 --- a/libs/langchain/tests/unit_tests/test_hierarchy_link_extractor.py +++ /dev/null @@ -1,83 +0,0 @@ -from ragstack_langchain.graph_store.extractors import HierarchyLinkExtractor -from ragstack_langchain.graph_store.links import Link - -PATH_1 = ["Root", "H1", "h2"] - -PATH_2 = ["Root", "H1"] - -PATH_3 = ["Root"] - - -def test_up_only(): - extractor = HierarchyLinkExtractor() - - assert extractor.extract_one(PATH_1) == { - # Path1 links up to Root/H1 - Link.outgoing(kind="hierarchy", tag="up:Root/H1"), - # Path1 is linked to by stuff under Root/H1/h2 - Link.incoming(kind="hierarchy", tag="up:Root/H1/h2"), - } - - assert extractor.extract_one(PATH_2) == { - # Path2 links up to Root - Link.outgoing(kind="hierarchy", tag="up:Root"), - # Path2 is linked to by stuff under Root/H1/h2 - Link.incoming(kind="hierarchy", tag="up:Root/H1"), - } - - assert extractor.extract_one(PATH_3) == { - # Path3 is linked to by stuff under Root - Link.incoming(kind="hierarchy", tag="up:Root"), - } - - -def test_up_and_down(): - extractor = HierarchyLinkExtractor(down_links=True) - - assert extractor.extract_one(PATH_1) == { - # Path1 links up to Root/H1 - Link.outgoing(kind="hierarchy", tag="up:Root/H1"), - # Path1 is linked to by stuff under Root/H1/h2 - Link.incoming(kind="hierarchy", tag="up:Root/H1/h2"), - # Path1 links down to things under Root/H1/h2. 
- Link.outgoing(kind="hierarchy", tag="down:Root/H1/h2"), - # Path1 is linked down to by Root/H1 - Link.incoming(kind="hierarchy", tag="down:Root/H1"), - } - - assert extractor.extract_one(PATH_2) == { - # Path2 links up to Root - Link.outgoing(kind="hierarchy", tag="up:Root"), - # Path2 is linked to by stuff under Root/H1/h2 - Link.incoming(kind="hierarchy", tag="up:Root/H1"), - # Path2 links down to things under Root/H1. - Link.outgoing(kind="hierarchy", tag="down:Root/H1"), - # Path2 is linked down to by Root - Link.incoming(kind="hierarchy", tag="down:Root"), - } - - assert extractor.extract_one(PATH_3) == { - # Path3 is linked to by stuff under Root - Link.incoming(kind="hierarchy", tag="up:Root"), - # Path3 links down to things under Root/H1. - Link.outgoing(kind="hierarchy", tag="down:Root"), - } - - -def test_sibling(): - extractor = HierarchyLinkExtractor(sibling_links=True, up_links=False) - - assert extractor.extract_one(PATH_1) == { - # Path1 links with anything else in Root/H1 - Link.bidir(kind="hierarchy", tag="sib:Root/H1"), - } - - assert extractor.extract_one(PATH_2) == { - # Path2 links with anything else in Root - Link.bidir(kind="hierarchy", tag="sib:Root"), - } - - assert extractor.extract_one(PATH_3) == { - # Path3 links with anything else at the top level - Link.bidir(kind="hierarchy", tag="sib:"), - } diff --git a/libs/langchain/tests/unit_tests/test_html_link_extractor.py b/libs/langchain/tests/unit_tests/test_html_link_extractor.py deleted file mode 100644 index 3ac84b4c1..000000000 --- a/libs/langchain/tests/unit_tests/test_html_link_extractor.py +++ /dev/null @@ -1,106 +0,0 @@ -from bs4 import BeautifulSoup -from ragstack_langchain.graph_store.extractors import HtmlInput, HtmlLinkExtractor -from ragstack_langchain.graph_store.links import Link - -PAGE_1 = """ - -
-<html> -<body> -Hello. -<a href="relative">Relative</a> -<a href="/relative-base">Relative base.</a> -<a href="http://cnn.com">Aboslute</a> -<a href="//same.foo">Test</a> -</body> -</html> -""" - -PAGE_2 = """ -<html> -<body> -Hello. -<a href="/bar/">Relative</a> -</body> -</html> -""" - - -def test_one_from_str(): - extractor = HtmlLinkExtractor() - - results = extractor.extract_one(HtmlInput(PAGE_1, base_url="https://foo.com/bar/")) - assert results == { - Link.incoming(kind="hyperlink", tag="https://foo.com/bar/"), - Link.outgoing(kind="hyperlink", tag="https://foo.com/bar/relative"), - Link.outgoing(kind="hyperlink", tag="https://foo.com/relative-base"), - Link.outgoing(kind="hyperlink", tag="http://cnn.com"), - Link.outgoing(kind="hyperlink", tag="https://same.foo"), - } - - results = extractor.extract_one(HtmlInput(PAGE_1, base_url="http://foo.com/bar/")) - assert results == { - Link.incoming(kind="hyperlink", tag="http://foo.com/bar/"), - Link.outgoing(kind="hyperlink", tag="http://foo.com/bar/relative"), - Link.outgoing(kind="hyperlink", tag="http://foo.com/relative-base"), - Link.outgoing(kind="hyperlink", tag="http://cnn.com"), - Link.outgoing(kind="hyperlink", tag="http://same.foo"), - } - - -def test_one_from_beautiful_soup(): - extractor = HtmlLinkExtractor() - soup = BeautifulSoup(PAGE_1, "html.parser") - results = extractor.extract_one(HtmlInput(soup, base_url="https://foo.com/bar/")) - assert results == { - Link.incoming(kind="hyperlink", tag="https://foo.com/bar/"), - Link.outgoing(kind="hyperlink", tag="https://foo.com/bar/relative"), - Link.outgoing(kind="hyperlink", tag="https://foo.com/relative-base"), - Link.outgoing(kind="hyperlink", tag="http://cnn.com"), - Link.outgoing(kind="hyperlink", tag="https://same.foo"), - } - - -def test_drop_fragments(): - extractor = HtmlLinkExtractor(drop_fragments=True) - results = extractor.extract_one( - HtmlInput(PAGE_2, base_url="https://foo.com/baz/#fragment") - ) - - assert results == { - Link.incoming(kind="hyperlink", tag="https://foo.com/baz/"), - Link.outgoing(kind="hyperlink", tag="https://foo.com/bar/"), - } - - -def test_include_fragments(): - extractor = HtmlLinkExtractor(drop_fragments=False) - results = extractor.extract_one( - HtmlInput(PAGE_2, base_url="https://foo.com/baz/#fragment") - ) - - assert results == { - Link.incoming(kind="hyperlink", tag="https://foo.com/baz/#fragment"), - Link.outgoing(kind="hyperlink", tag="https://foo.com/bar/#fragment"), - } - - -def test_batch_from_str(): - extractor = HtmlLinkExtractor() - results = list( - extractor.extract_many( - [ - HtmlInput(PAGE_1, base_url="https://foo.com/bar/"), - HtmlInput(PAGE_2, base_url="https://foo.com/baz/"), - ] - ) - ) - - assert results[0] == { - Link.incoming(kind="hyperlink", tag="https://foo.com/bar/"), - Link.outgoing(kind="hyperlink", tag="https://foo.com/bar/relative"), - Link.outgoing(kind="hyperlink", tag="https://foo.com/relative-base"), - Link.outgoing(kind="hyperlink", tag="http://cnn.com"), - Link.outgoing(kind="hyperlink", tag="https://same.foo"), - } - assert results[1] == { - Link.incoming(kind="hyperlink", tag="https://foo.com/baz/"), - Link.outgoing(kind="hyperlink", tag="https://foo.com/bar/"), - } diff --git a/libs/langchain/tests/unit_tests/test_keybert_link_extractor.py b/libs/langchain/tests/unit_tests/test_keybert_link_extractor.py index fdd3a1d79..947c1f117 100644 --- a/libs/langchain/tests/unit_tests/test_keybert_link_extractor.py +++ b/libs/langchain/tests/unit_tests/test_keybert_link_extractor.py @@ -1,5 +1,5 @@ -from ragstack_langchain.graph_store.extractors import KeybertLinkExtractor -from ragstack_langchain.graph_store.links import Link +from langchain_community.graph_vectorstores.extractors import
KeybertLinkExtractor +from langchain_core.graph_vectorstores import Link PAGE_1 = """ Supervised learning is the machine learning task of learning a function that diff --git a/libs/langchain/tests/unit_tests/test_link_extractor_transformer.py b/libs/langchain/tests/unit_tests/test_link_extractor_transformer.py index d435c35dc..9902c65b7 100644 --- a/libs/langchain/tests/unit_tests/test_link_extractor_transformer.py +++ b/libs/langchain/tests/unit_tests/test_link_extractor_transformer.py @@ -1,22 +1,37 @@ -from langchain_core.documents import Document -from ragstack_langchain.graph_store.extractors import ( - HtmlLinkExtractor, - LinkExtractorTransformer, -) -from ragstack_langchain.graph_store.extractors.gliner_link_extractor import ( +from langchain_community.graph_vectorstores.extractors import ( GLiNERLinkExtractor, -) -from ragstack_langchain.graph_store.extractors.keybert_link_extractor import ( + HtmlLinkExtractor, KeybertLinkExtractor, ) -from ragstack_langchain.graph_store.links import Link, get_links +from langchain_core.documents import Document +from langchain_core.graph_vectorstores.links import Link, get_links +from ragstack_langchain.graph_store.extractors import LinkExtractorTransformer from . import ( test_gliner_link_extractor, - test_html_link_extractor, test_keybert_link_extractor, ) +PAGE_1 = """ +<html> +<body> +Hello. +<a href="relative">Relative</a> +<a href="/relative-base">Relative base.</a> +<a href="http://cnn.com">Aboslute</a> +<a href="//same.foo">Test</a> +</body> +</html> +""" + +PAGE_2 = """ +<html> +<body> +Hello. +<a href="/bar/">Relative</a> +</body> +</html> +""" + def test_html_extractor(): transformer = LinkExtractorTransformer( @@ -25,13 +40,13 @@ def test_html_extractor(): ] ) doc1 = Document( - page_content=test_html_link_extractor.PAGE_1, + page_content=PAGE_1, metadata={ "source": "https://foo.com/bar/", }, ) doc2 = Document( - page_content=test_html_link_extractor.PAGE_2, + page_content=PAGE_2, metadata={ "source": "https://foo.com/baz/", }, @@ -40,7 +55,7 @@ def test_html_extractor(): assert results[0] == doc1 assert results[1] == doc2 - assert get_links(doc1) == { + assert set(get_links(doc1)) == { Link.incoming(kind="hyperlink", tag="https://foo.com/bar/"), Link.outgoing(kind="hyperlink", tag="https://foo.com/bar/relative"), Link.outgoing(kind="hyperlink", tag="https://foo.com/relative-base"), @@ -48,7 +63,7 @@ def test_html_extractor(): Link.outgoing(kind="hyperlink", tag="https://same.foo"), } - assert get_links(doc2) == { + assert set(get_links(doc2)) == { Link.incoming(kind="hyperlink", tag="https://foo.com/baz/"), Link.outgoing(kind="hyperlink", tag="https://foo.com/bar/"), } @@ -71,7 +86,7 @@ def test_multiple_extractors(): assert results[0] == doc1 assert results[1] == doc2 - assert get_links(doc1) == { + assert set(get_links(doc1)) == { Link(kind="kw", direction="bidir", tag="labeled"), Link(kind="kw", direction="bidir", tag="learning"), Link(kind="kw", direction="bidir", tag="training"), Link(kind="kw", direction="bidir", tag="labels"), } - assert get_links(doc2) == { + assert set(get_links(doc2)) == { Link(kind="kw", direction="bidir", tag="cristiano"), Link(kind="kw", direction="bidir", tag="goalscorer"), Link(kind="kw", direction="bidir", tag="footballer"),