from __future__ import annotations

from typing import Any, Dict, Iterable, Set

from ragstack_knowledge_store.edge_extractor import EdgeExtractor
from ragstack_knowledge_store.knowledge_store import CONTENT_ID, KnowledgeStore


class DirectedEdgeExtractor(EdgeExtractor):
    def __init__(self, sources_field: str, targets_field: str, kind: str) -> None:
| 11 | + """Extract directed edges between uses and definitions. |
| 12 | + While `UndirectedEdgeExtractor` links nodes in both directions if they share |
| 13 | + a keyword, this only creates links from nodes with a "source" to nodes with |
| 14 | + a matching "target". For example, uses may be the `href` of `a` tags in the |
| 15 | + chunk and definitions may be the URLs that the chunk is accessible at. |
| 16 | +
|
| 17 | + This may also be used for other forms of references, such as Wikipedia |
| 18 | + article IDs, etc. |
| 19 | +
|
| 20 | + Args: |
| 21 | + sources_field: The metadata field to read sources from. |
| 22 | + targets_field: The metadata field to read targets from. |
| 23 | + kind: The kind label to apply to created edges. Must be unique. |
| 24 | + """ |

        # TODO: Assert the kind matches some reasonable regex?

        # TODO: Allow specifying how properties should be added to the edge.
        # For instance, `links_to`.
        self._sources_field = sources_field
        self._targets_field = targets_field
        self._kind = kind

    @property
    def kind(self) -> str:
        return self._kind

    @staticmethod
    def for_hrefs_to_urls() -> DirectedEdgeExtractor:
        return DirectedEdgeExtractor(sources_field="hrefs", targets_field="urls", kind="link")
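
    # Illustrative example (not part of the original change; values are hypothetical):
    # with the "hrefs"/"urls" configuration above, a chunk whose metadata contains
    # {"hrefs": ["https://example.com/b"]} gets a directed "link" edge to every chunk
    # whose "urls" metadata contains "https://example.com/b", but not the reverse.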

    def _sources(self, metadata: Dict[str, Any]) -> Set[str]:
        sources = metadata.get(self._sources_field)
        if not sources:
            return set()
        elif isinstance(sources, str):
            return {sources}
        else:
            return set(sources)

    def _targets(self, metadata: Dict[str, Any]) -> Set[str]:
        targets = metadata.get(self._targets_field)
        if not targets:
            return set()
        elif isinstance(targets, str):
            return {targets}
        else:
            return set(targets)

    def tags(self, text: str, metadata: Dict[str, Any]) -> Set[str]:
        results = set()
        for source in self._sources(metadata):
            results.add(f"{self._kind}_s:{source}")
        for target in self._targets(metadata):
            results.add(f"{self._kind}_t:{target}")
        return results
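
    # For example (hypothetical values), with the "hrefs"/"urls" configuration and
    # kind="link", the metadata {"hrefs": ["https://example.com/b"], "urls": ["https://example.com/a"]}
    # yields the tags {"link_s:https://example.com/b", "link_t:https://example.com/a"},
    # so persisted rows can later be looked up by either side of the reference.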

    def extract_edges(
        self, store: KnowledgeStore, texts: Iterable[str], metadatas: Iterable[Dict[str, Any]]
    ) -> int:
        # First, iterate over the new nodes, collecting the sources/targets that
        # are referenced and the IDs of the nodes that contain them.
        new_ids = set()
        new_sources_to_ids = {}
        new_targets_to_ids = {}
        for md in metadatas:
            id = md[CONTENT_ID]

            new_ids.add(id)
            for source in self._sources(md):
                new_sources_to_ids.setdefault(source, set()).add(id)
            for target in self._targets(md):
                new_targets_to_ids.setdefault(target, set()).add(id)
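
        # For instance (hypothetical IDs): a new chunk "a" whose sources contain "u"
        # and a new chunk "b" whose targets contain "u" produce
        # new_sources_to_ids == {"u": {"a"}} and new_targets_to_ids == {"u": {"b"}},
        # so the pair ("a", "b") ultimately becomes a directed edge.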

        # Then, retrieve the set of persisted items for each of those
        # sources/targets and link them to the new items as needed,
        # remembering that the *new* nodes will already have been added.
        source_target_pairs = set()
        with store._concurrent_queries() as cq:

            def add_source_target_pairs(source_ids, target_ids):
                for source_id in source_ids:
                    if not isinstance(source_id, str):
                        source_id = source_id.content_id

                    for target_id in target_ids:
                        if not isinstance(target_id, str):
                            target_id = target_id.content_id
                        source_target_pairs.add((source_id, target_id))

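            # Query in both directions so insertion order does not matter: new nodes
            # holding a source are matched against persisted nodes tagged with that
            # value as a target, and new nodes holding a target are matched against
            # persisted nodes tagged with that value as a source.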
            for source, source_ids in new_sources_to_ids.items():
                cq.execute(
                    store._query_ids_by_tag,
                    parameters=(f"{self._kind}_t:{source}",),
                    # Bind `source_ids` as a default argument so the callback captures
                    # this iteration's value rather than the loop's last value.
                    callback=lambda targets, sources=source_ids: add_source_target_pairs(
                        sources, targets
                    ),
                )

            for target, target_ids in new_targets_to_ids.items():
                cq.execute(
                    store._query_ids_by_tag,
                    parameters=(f"{self._kind}_s:{target}",),
                    # Bind `target_ids` as a default argument so the callback captures
                    # this iteration's value rather than the loop's last value.
                    callback=lambda sources, targets=target_ids: add_source_target_pairs(
                        sources, targets
                    ),
                )

        # TODO: We should allow passing in the concurrent queries, and figure out
        # how to start sending these before everything previous has finished.
        with store._concurrent_queries() as cq:
            for source, target in source_target_pairs:
                cq.execute(store._insert_edge, (source, target))

        return len(source_target_pairs)
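

# A minimal usage sketch (illustrative only, not part of this change). It assumes a
# `KnowledgeStore` that accepts edge extractors at construction time and chunk metadata
# carrying "content_id", "hrefs", and "urls" fields; the exact store API may differ.
#
#     extractor = DirectedEdgeExtractor.for_hrefs_to_urls()
#     store = KnowledgeStore(..., edge_extractors=[extractor])
#     store.add_texts(
#         texts=["chunk body ..."],
#         metadatas=[{
#             "content_id": "doc-1#chunk-0",
#             "urls": ["https://example.com/doc-1"],
#             "hrefs": ["https://example.com/doc-2"],
#         }],
#     )
#     # Afterwards, extract_edges links this chunk to any persisted chunk whose
#     # "urls" metadata contains "https://example.com/doc-2".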