Skip to content

Commit 296d4eb

Browse files
authored
feat: Add extractors for keywords, ner & hierarchy (#527)
* feat: Add extractors for keywords, NER & hierarchy. This simplifies the use of KeyBERT for keyword-based links and GLiNER for named-entity-based links. It also makes it easier to create links representing a document and/or page hierarchy.
1 parent c0aa8ea commit 296d4eb

15 files changed

+576
-12
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -272,21 +272,18 @@ def _apply_schema(self):
272272
)
273273

274274
# Index on text_embedding (for similarity search)
275-
self._session.execute(
276-
f"""CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_text_embedding_index
275+
self._session.execute(f"""
276+
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_text_embedding_index
277277
ON {self._keyspace}.{self._node_table}(text_embedding)
278278
USING 'StorageAttachedIndex';
279-
""" # noqa: E501
280-
)
279+
""") # noqa: E501
281280

282281
# Index on target_text_embedding (for similarity search)
283-
self._session.execute(
284-
f"""
282+
self._session.execute(f"""
285283
CREATE CUSTOM INDEX IF NOT EXISTS {self._targets_table}_target_text_embedding_index
286284
ON {self._keyspace}.{self._targets_table}(target_text_embedding)
287285
USING 'StorageAttachedIndex';
288-
""" # noqa: E501
289-
)
286+
""") # noqa: E501
290287

291288
def _concurrent_queries(self) -> ConcurrentQueries:
    """Build a ConcurrentQueries helper bound to this store's session."""
    return ConcurrentQueries(self._session)

libs/langchain/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
3939
ragstack-ai-colbert = { path = "../colbert", develop = true }
4040
ragstack-ai-knowledge-store = { path = "../knowledge-store", develop = true }
4141
pytest-asyncio = "^0.23.6"
42+
keybert = "^0.8.5"
43+
gliner = "^0.2.5"
4244

4345
[tool.poetry.group.dev.dependencies]
4446
setuptools = "^70.0.0"
Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
"""Public API of the graph-store link extractors package."""

from .gliner_link_extractor import GLiNERInput, GLiNERLinkExtractor
from .hierarchy_link_extractor import HierarchyInput, HierarchyLinkExtractor
from .html_link_extractor import HtmlInput, HtmlLinkExtractor
from .keybert_link_extractor import KeybertInput, KeybertLinkExtractor
# BUG FIX: `LinkExtractor` is exported in `__all__` below, but its import was
# dropped in this change, which breaks `from ... import *` and the package's
# documented API. Re-import it here.
from .link_extractor import LinkExtractor
from .link_extractor_adapter import LinkExtractorAdapter
from .link_extractor_transformer import LinkExtractorTransformer

__all__ = [
    "LinkExtractor",
    "GLiNERInput",
    "GLiNERLinkExtractor",
    "HierarchyInput",
    "HierarchyLinkExtractor",
    "HtmlInput",
    "HtmlLinkExtractor",
    "KeybertInput",
    "KeybertLinkExtractor",
    "LinkExtractorAdapter",
    "LinkExtractorTransformer",
]
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import Any, Dict, Iterable, List, Optional, Set

from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
from ragstack_langchain.graph_store.links import Link

# `typing.TypeAlias` only exists on Python 3.10+ and the `type` statement
# requires 3.12, so declare the alias with a plain assignment.
GLiNERInput = str


class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
    """Extract named entities from text with GLiNER and emit them as links."""

    def __init__(
        self,
        labels: List[str],
        *,
        kind: str = "entity",
        model: str = "urchade/gliner_mediumv2.1",
        extract_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Extract named entities using GLiNER.

        Args:
            labels: List of kinds of entities to extract.
            kind: Kind of links to produce with this extractor.
            model: GLiNER model to use.
            extract_kwargs: Keyword arguments to pass to GLiNER's
                `batch_predict_entities`.

        Raises:
            ImportError: If the optional `gliner` dependency is not installed.
        """
        # Keep only the import inside the `try`: an ImportError raised from
        # inside `from_pretrained` must not be misreported as a missing
        # dependency, and model-loading failures should propagate unchanged.
        try:
            from gliner import GLiNER
        except ImportError:
            raise ImportError(
                "gliner is required for GLiNERLinkExtractor. "
                "Please install it with `pip install gliner`."
            ) from None

        self._model = GLiNER.from_pretrained(model)
        self._labels = labels
        self._kind = kind
        self._extract_kwargs = extract_kwargs or {}

    def extract_one(self, input: GLiNERInput) -> Set[Link]:
        """Extract entity links from a single input."""
        return next(iter(self.extract_many([input])))

    def extract_many(
        self,
        inputs: Iterable[GLiNERInput],
    ) -> Iterable[Set[Link]]:
        """Yield one set of entity links per input, in order.

        Accepts plain strings or Document-like objects with `page_content`.
        """
        strs = [i if isinstance(i, str) else i.page_content for i in inputs]
        for entities in self._model.batch_predict_entities(
            strs, self._labels, **self._extract_kwargs
        ):
            # One bidirectional link per entity, tagged by entity text and
            # kinded by "<kind>:<label>" so different labels stay distinct.
            yield {
                Link.bidir(kind=f"{self._kind}:{e['label']}", tag=e["text"])
                for e in entities
            }
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from typing import Callable, List, Set

from langchain_core.documents import Document

from ragstack_langchain.graph_store.links import Link

from .link_extractor import LinkExtractor
from .link_extractor_adapter import LinkExtractorAdapter

# `typing.TypeAlias` only exists on Python 3.10+ and the `type` statement
# requires 3.12, so declare the alias with a plain assignment.
HierarchyInput = List[str]


class HierarchyLinkExtractor(LinkExtractor[HierarchyInput]):
    """Extract links encoding a document/section hierarchy.

    The input is the path of section titles from the root to the section,
    e.g. ``["chapter", "section", "subsection"]``.
    """

    def __init__(
        self,
        kind: str = "hierarchy",
        up_links: bool = True,
        down_links: bool = False,
        sibling_links: bool = False,
    ):
        """Extract links from a document hierarchy.

        Args:
            kind: Kind of links to produce with this extractor.
            up_links: Link from a section to its parent.
            down_links: Link from a section to its children.
            sibling_links: Link from a section to other sections with the same parent.
        """
        self._kind = kind
        self._up_links = up_links
        self._down_links = down_links
        self._sibling_links = sibling_links

    def as_document_extractor(
        self, hierarchy: Callable[[Document], HierarchyInput]
    ) -> LinkExtractor[Document]:
        """Lift this extractor to Documents via a hierarchy-accessor function."""
        return LinkExtractorAdapter(underlying=self, transform=hierarchy)

    def extract_one(
        self,
        input: HierarchyInput,
    ) -> Set[Link]:
        # Full path identifies this node; the parent path (computed below)
        # identifies the level above it — empty string for top-level sections.
        this_path = "/".join(input)
        parent_path = None

        links = set()
        if self._up_links:
            # Incoming, so children emitting outgoing "up:<this_path>" links
            # (below) connect to this node.
            links.add(Link.incoming(kind=self._kind, tag=f"up:{this_path}"))
        if self._down_links:
            links.add(Link.outgoing(kind=self._kind, tag=f"down:{this_path}"))

        if len(input) >= 1:
            parent_path = "/".join(input[0:-1])
            if self._up_links and len(input) > 1:
                links.add(Link.outgoing(kind=self._kind, tag=f"up:{parent_path}"))
            if self._down_links and len(input) > 1:
                links.add(Link.incoming(kind=self._kind, tag=f"down:{parent_path}"))
            if self._sibling_links:
                # Bidirectional: all sections sharing a parent share this tag.
                links.add(Link.bidir(kind=self._kind, tag=f"sib:{parent_path}"))

        return links

libs/langchain/ragstack_langchain/graph_store/extractors/html_link_extractor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
from typing import TYPE_CHECKING, Set, Union
33
from urllib.parse import urldefrag, urljoin, urlparse
44

5+
from langchain_core.documents import Document
6+
57
from ragstack_langchain.graph_store.links import Link
68

79
from .link_extractor import LinkExtractor
10+
from .link_extractor_adapter import LinkExtractorAdapter
811

912
if TYPE_CHECKING:
1013
from bs4 import BeautifulSoup
@@ -77,6 +80,26 @@ def __init__(self, *, kind: str = "hyperlink", drop_fragments: bool = True):
7780
self._kind = kind
7881
self.drop_fragments = drop_fragments
7982

83+
def as_document_extractor(
    self, url_metadata_key: str = "source"
) -> LinkExtractor[Document]:
    """Return a LinkExtractor that applies to documents.

    NOTE: Since the HtmlLinkExtractor parses HTML, if you use it with other
    similar link extractors it may be more efficient to call the link
    extractors directly on the parsed BeautifulSoup object.

    Args:
        url_metadata_key: The name of the field in document metadata with
            the URL of the document.
    """
    return LinkExtractorAdapter(
        underlying=self,
        transform=lambda doc: HtmlInput(
            doc.page_content, doc.metadata[url_metadata_key]
        ),
    )
102+
80103
def extract_one(
81104
self,
82105
input: HtmlInput,
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from typing import Any, Dict, Iterable, Optional, Set, Union

from langchain_core.documents import Document

from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
from ragstack_langchain.graph_store.links import Link

# `typing.TypeAlias` only exists on Python 3.10+ and the `type` statement
# requires 3.12, so declare the alias with a plain assignment.
KeybertInput = Union[str, Document]


class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
    """Extract keyword links from strings or Documents using KeyBERT."""

    def __init__(
        self,
        *,
        kind: str = "kw",
        embedding_model: str = "all-MiniLM-L6-v2",
        extract_keywords_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Extract keywords using Keybert.

        Args:
            kind: Kind of links to produce with this extractor.
            embedding_model: Name of the embedding model to use with Keybert.
            extract_keywords_kwargs: Keyword arguments to pass to Keybert's
                `extract_keywords` method.

        Raises:
            ImportError: If the optional `keybert` dependency is not installed.
        """
        # Keep only the import inside the `try`; failures while constructing
        # the model should propagate rather than be reported as a missing
        # dependency.
        try:
            import keybert
        except ImportError:
            raise ImportError(
                "keybert is required for KeybertLinkExtractor. "
                "Please install it with `pip install keybert`."
            ) from None

        self._kw_model = keybert.KeyBERT(model=embedding_model)
        self._kind = kind
        self._extract_keywords_kwargs = extract_keywords_kwargs or {}

    def extract_one(self, input: KeybertInput) -> Set[Link]:
        """Extract keyword links from a single string or Document."""
        keywords = self._kw_model.extract_keywords(
            input if isinstance(input, str) else input.page_content,
            **self._extract_keywords_kwargs,
        )
        # `extract_keywords` returns (keyword, score) pairs; only the keyword
        # itself is used as the link tag.
        return {Link.bidir(kind=self._kind, tag=kw[0]) for kw in keywords}

    def extract_many(
        self,
        inputs: Iterable[KeybertInput],
    ) -> Iterable[Set[Link]]:
        """Yield one set of keyword links per input, in order.

        BUG FIX: the original called `len()` directly on `inputs`, which
        crashes for non-sequence iterables (e.g. generators) even though the
        signature accepts any Iterable. Materialize to a list first.
        """
        inputs = list(inputs)
        if len(inputs) == 1:
            # Even though we pass a list, if it contains one item, keybert will
            # flatten it. This means it's easier to just call the special case
            # for one item.
            yield self.extract_one(inputs[0])
        elif len(inputs) > 1:
            strs = [i if isinstance(i, str) else i.page_content for i in inputs]
            extracted = self._kw_model.extract_keywords(
                strs, **self._extract_keywords_kwargs
            )
            for keywords in extracted:
                yield {Link.bidir(kind=self._kind, tag=kw[0]) for kw in keywords}

libs/langchain/ragstack_langchain/graph_store/extractors/link_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from typing import Generic, Iterable, TypeVar
4+
from typing import Generic, Iterable, Set, TypeVar
55

66
from ragstack_langchain.graph_store.links import Link
77

@@ -24,7 +24,7 @@ def extract_one(self, input: InputT) -> set[Link]:
2424
Set of links extracted from the input.
2525
"""
2626

27-
def extract_many(self, inputs: Iterable[InputT]):
27+
def extract_many(self, inputs: Iterable[InputT]) -> Iterable[Set[Link]]:
2828
"""Add edges from each `input` to the corresponding documents.
2929
3030
Args:
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Callable, Iterable, Set, TypeVar

from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
from ragstack_langchain.graph_store.links import Link

InputT = TypeVar("InputT")
UnderlyingInputT = TypeVar("UnderlyingInputT")


class LinkExtractorAdapter(LinkExtractor[InputT]):
    """Adapt a LinkExtractor to a different input type via a transform."""

    def __init__(
        self,
        underlying: LinkExtractor[UnderlyingInputT],
        transform: Callable[[InputT], UnderlyingInputT],
    ) -> None:
        """Create an adapter around `underlying`.

        Args:
            underlying: Extractor operating on the transformed input type.
            transform: Converts an adapter input into the underlying input.
        """
        self._underlying = underlying
        self._transform = transform

    def extract_one(self, input: InputT) -> Set[Link]:
        # BUG FIX: the original called `self.extract_one(...)`, recursing
        # into itself forever (RecursionError on first use). Delegate to the
        # underlying extractor instead.
        return self._underlying.extract_one(self._transform(input))

    def extract_many(self, inputs: Iterable[InputT]) -> Iterable[Set[Link]]:
        underlying_inputs = [self._transform(input) for input in inputs]
        return self._underlying.extract_many(underlying_inputs)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Iterable, Sequence

from langchain_core.documents import Document
from langchain_core.documents.transformers import BaseDocumentTransformer

from ragstack_langchain.graph_store.extractors.link_extractor import LinkExtractor
from ragstack_langchain.graph_store.links import add_links


class LinkExtractorTransformer(BaseDocumentTransformer):
    """DocumentTransformer that adds links produced by the given extractors."""

    def __init__(self, link_extractors: Iterable[LinkExtractor[Document]]):
        """Create a DocumentTransformer which adds the given links.

        Materializes `link_extractors` so the transformer remains usable for
        repeated `transform_documents` calls even when constructed from a
        single-use iterable such as a generator (the original stored the raw
        iterable, which would be silently exhausted after one use).
        """
        self.link_extractors = list(link_extractors)

    def transform_documents(self, documents: Sequence[Document]) -> Sequence[Document]:
        """Add extracted links to each document's metadata and return them."""
        # One iterable of per-document link sets per extractor; the inner zip
        # pivots them so each document is paired with the links produced for
        # it by every extractor.
        links_per_extractor = [
            extractor.extract_many(documents) for extractor in self.link_extractors
        ]
        for document, links in zip(documents, zip(*links_per_extractor)):
            add_links(document, *links)
        return documents

0 commit comments

Comments
 (0)