Skip to content

Commit fa98828

Browse files
authored
ref: Separate out edge extractors (#450)
* ref: Separate out edge extractors This is a step towards making the edge extraction extensible. Rather than specifying which extractions to apply with a bunch of hard-coded booleans, this takes a list of extractors. This has some limitations: - Concurrent execution is synchronized between extractors. That is, all tasks from one extractor complete before the next begins. This is likely better than requiring some coordination between extractors, especially if we wish to make this a more generalizable implementation. This can be revisited if it becomes a problem, although `async` versions are likely the best solution. Additionally, future improvements to be made include: - Allowing extractors to (configurably) annotate the edges they add - Moving queries into the extractors and/or generally cleaning up how the extractors interact with the store.
1 parent 6b4c7a5 commit fa98828

File tree

7 files changed

+401
-140
lines changed

7 files changed

+401
-140
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Dict, Iterable, Set
4+
5+
from ragstack_knowledge_store.edge_extractor import EdgeExtractor
6+
from ragstack_knowledge_store.knowledge_store import CONTENT_ID, KnowledgeStore
7+
8+
9+
class DirectedEdgeExtractor(EdgeExtractor):
    def __init__(self, sources_field: str, targets_field: str, kind: str) -> None:
        """Extract directed edges between uses and definitions.

        While `UndirectedEdgeExtractor` links nodes in both directions if they share
        a keyword, this only creates links from nodes with a "source" to nodes with
        a matching "target". For example, uses may be the `href` of `a` tags in the
        chunk and definitions may be the URLs that the chunk is accessible at.

        This may also be used for other forms of references, such as Wikipedia
        article IDs, etc.

        Args:
            sources_field: The metadata field to read sources from.
            targets_field: The metadata field to read targets from.
            kind: The kind label to apply to created edges. Must be unique.
        """
        # TODO: Assert the kind matches some reasonable regex?

        # TODO: Allow specifying how properties should be added to the edge.
        # For instance, `links_to`.
        self._sources_field = sources_field
        self._targets_field = targets_field
        self._kind = kind

    @property
    def kind(self) -> str:
        """The kind label applied to edges created by this extractor."""
        return self._kind

    @staticmethod
    def for_hrefs_to_urls() -> DirectedEdgeExtractor:
        """Create an extractor linking `hrefs` metadata to `urls` metadata."""
        return DirectedEdgeExtractor(sources_field="hrefs", targets_field="urls", kind="link")

    @staticmethod
    def _field_values(metadata: Dict[str, Any], field: str) -> Set[str]:
        """Return the values of `field` in `metadata` as a set of strings.

        A missing or falsy value yields the empty set. A bare string yields a
        singleton set (rather than the set of its characters). Any other
        iterable is converted to a set.
        """
        values = metadata.get(field)
        if not values:
            return set()
        if isinstance(values, str):
            return {values}
        return set(values)

    def _sources(self, metadata: Dict[str, Any]) -> Set[str]:
        """Source values declared by `metadata`."""
        return self._field_values(metadata, self._sources_field)

    def _targets(self, metadata: Dict[str, Any]) -> Set[str]:
        """Target values declared by `metadata`."""
        return self._field_values(metadata, self._targets_field)

    def tags(self, text: str, metadata: Dict[str, Any]) -> Set[str]:
        """Return the source (`<kind>_s:`) and target (`<kind>_t:`) tags for a node."""
        results = set()
        for source in self._sources(metadata):
            results.add(f"{self._kind}_s:{source}")
        for target in self._targets(metadata):
            results.add(f"{self._kind}_t:{target}")
        return results

    def extract_edges(
        self, store: KnowledgeStore, texts: Iterable[str], metadatas: Iterable[Dict[str, Any]]
    ) -> int:
        """Create directed edges from sources to matching targets.

        The given nodes have already been persisted; this queries the store by
        tag to find persisted counterparts for each new source/target value and
        inserts the resulting edges.

        Args:
            store: KnowledgeStore edges are being extracted for.
            texts: The texts of the nodes to be processed.
            metadatas: The metadatas of the nodes to be processed.

        Returns:
            Number of edges created involving the given nodes.
        """
        # First, iterate over the new nodes, collecting the sources/targets that
        # are referenced and which IDs contain those.
        new_sources_to_ids: Dict[str, Set[str]] = {}
        new_targets_to_ids: Dict[str, Set[str]] = {}
        for md in metadatas:
            content_id = md[CONTENT_ID]

            for source in self._sources(md):
                new_sources_to_ids.setdefault(source, set()).add(content_id)
            for target in self._targets(md):
                new_targets_to_ids.setdefault(target, set()).add(content_id)

        # Then, retrieve the set of persisted items for each of those
        # sources/targets and link them to the new items as needed.
        # Remember that the *new* nodes will have been added already.
        source_target_pairs = set()
        with store._concurrent_queries() as cq:

            def add_source_target_pairs(source_ids, target_ids):
                # Query results may be rows exposing `content_id` rather than
                # plain strings; normalize both sides before pairing.
                for source_id in source_ids:
                    if not isinstance(source_id, str):
                        source_id = source_id.content_id

                    for target_id in target_ids:
                        if not isinstance(target_id, str):
                            target_id = target_id.content_id
                        source_target_pairs.add((source_id, target_id))

            for source, source_ids in new_sources_to_ids.items():
                cq.execute(
                    store._query_ids_by_tag,
                    parameters=(f"{self._kind}_t:{source}",),
                    # Default argument captures each `source_ids` instead of the last iteration.
                    callback=lambda targets, sources=source_ids: add_source_target_pairs(
                        sources, targets
                    ),
                )

            for target, target_ids in new_targets_to_ids.items():
                cq.execute(
                    store._query_ids_by_tag,
                    parameters=(f"{self._kind}_s:{target}",),
                    # Default argument captures each `target_ids` instead of the last iteration.
                    callback=lambda sources, targets=target_ids: add_source_target_pairs(
                        sources, targets
                    ),
                )

        # TODO: we should allow passing in the concurrent queries, and figure out
        # how to start sending these before everything previously finished.
        with store._concurrent_queries() as cq:
            for source, target in source_target_pairs:
                cq.execute(store._insert_edge, (source, target))

        return len(source_target_pairs)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Dict, Iterable, Set
3+
4+
from knowledge_store import KnowledgeStore
5+
from langchain_core.runnables import run_in_executor
6+
7+
8+
class EdgeExtractor(ABC):
    """Extension point defining how edges should be created."""

    @property
    @abstractmethod
    def kind(self) -> str:
        """Return the kind of edge extracted by this."""

    def tags(self, text: str, metadata: Dict[str, Any]) -> Set[str]:
        """Return the set of tags to add for this extraction.

        Default implementation adds no tags; extractors that need tag-based
        lookups override this.
        """
        return set()

    @abstractmethod
    def extract_edges(
        self, store: KnowledgeStore, texts: Iterable[str], metadatas: Iterable[Dict[str, Any]]
    ) -> int:
        """Add edges for the given nodes.

        The nodes have already been persisted.

        Args:
            store: KnowledgeStore edges are being extracted for.
            texts: The texts of the nodes to be processed.
            metadatas: The metadatas of the nodes to be processed.

        Returns:
            Number of edges extracted involving the given nodes.
        """

    async def aextract_edges(
        self, store: KnowledgeStore, texts: Iterable[str], metadatas: Iterable[Dict[str, Any]]
    ) -> int:
        """Add edges for the given nodes (async counterpart of `extract_edges`).

        The nodes have already been persisted.

        Args:
            store: KnowledgeStore edges are being extracted for.
            texts: The texts of the nodes to be processed.
            metadatas: The metadatas of the nodes to be processed.

        Returns:
            Number of edges extracted involving the given nodes.
        """
        # Bug fix: the synchronous method is `extract_edges`; the previous
        # `self._extract_edges` did not exist and raised AttributeError.
        return await run_in_executor(
            None,
            self.extract_edges,
            store,
            texts,
            metadatas,
        )
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import Any, Dict, Iterable
2+
3+
from ragstack_knowledge_store.edge_extractor import EdgeExtractor
4+
from ragstack_knowledge_store.knowledge_store import CONTENT_ID, KnowledgeStore
5+
6+
7+
class ExplicitEdgeExtractor(EdgeExtractor):
    def __init__(self, edges_field: str, kind: str, bidir: bool = False) -> None:
        """Extract edges from explicit IDs in the metadata.

        This extraction is faster than using a `DirectedEdgeExtractor` when the IDs
        are available since it doesn't need to look-up the nodes associated with a
        given tag.

        Note: This extractor does not check whether the target ID exists. Edges
        will be created even if the target does not exist. This means traversals
        over graphs using this extractor may discover nodes that do not exist.
        Such "phantom IDs" will be filtered out when loading content for the
        nodes.

        Args:
            edges_field: The metadata field containing the IDs of nodes to link to.
            kind: The `kind` to apply to edges created by this extractor.
            bidir: If true, creates edges in both directions.
        """
        self._edges_field = edges_field
        self._kind = kind
        self._bidir = bidir

    @property
    def kind(self) -> str:
        """The kind label applied to edges created by this extractor."""
        return self._kind

    def extract_edges(
        self, store: KnowledgeStore, texts: Iterable[str], metadatas: Iterable[Dict[str, Any]]
    ) -> int:
        """Insert an edge from each node to every ID listed in its edges field.

        Args:
            store: KnowledgeStore edges are being extracted for.
            texts: The texts of the nodes to be processed.
            metadatas: The metadatas of the nodes to be processed.

        Returns:
            Number of edges created (bidirectional links count as two).
        """
        num_edges = 0
        with store._concurrent_queries() as cq:
            for md in metadatas:
                if (edges := md.get(self._edges_field, None)) is not None:
                    content_id = md[CONTENT_ID]
                    # Deduplicate so a repeated target ID creates only one edge.
                    for target in set(edges):
                        cq.execute(store._insert_edge, (content_id, target))
                        num_edges += 1
                        if self._bidir:
                            cq.execute(store._insert_edge, (target, content_id))
                            num_edges += 1
        return num_edges

0 commit comments

Comments
 (0)