Skip to content

Commit 850003e

Browse files
authored
feat: Use tags exclusively for edge creation (#471)
* feat: Use tags exclusively for edge creation. This enables edge extractors to be implemented as document transformers configuring the metadata, rather than needing to interact with the underlying storage.
1 parent cc222f0 commit 850003e

File tree

11 files changed

+819
-629
lines changed

11 files changed

+819
-629
lines changed

libs/knowledge-store/notebooks/astra_support.ipynb

Lines changed: 484 additions & 151 deletions
Large diffs are not rendered by default.

libs/knowledge-store/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ packages = [{ include = "ragstack_knowledge_store" }]
1313
python = ">=3.10,<3.13"
1414
langchain-core = "^0.2"
1515
cassio = "^0.1.7"
16+
asyncstdlib = "^3.12.3"
1617

1718
[tool.poetry.group.dev.dependencies]
1819
ruff = "*"

libs/knowledge-store/ragstack_knowledge_store/cassandra.py

Lines changed: 144 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from dataclasses import dataclass
33
from typing import (
44
Any,
5+
Dict,
56
Iterable,
67
List,
78
NamedTuple,
89
Optional,
910
Sequence,
11+
Tuple,
1012
Type,
1113
)
1214

@@ -18,7 +20,7 @@
1820
from langchain_core.documents import Document
1921
from langchain_core.embeddings import Embeddings
2022

21-
from ragstack_knowledge_store.edge_extractor import EdgeExtractor
23+
from ragstack_knowledge_store.edge_extractor import get_link_tags
2224

2325
from ._utils import strict_zip
2426
from .base import KnowledgeStore, Node, TextNode
@@ -97,7 +99,6 @@ class CassandraKnowledgeStore(KnowledgeStore):
9799
def __init__(
98100
self,
99101
embedding: Embeddings,
100-
edge_extractors: List[EdgeExtractor],
101102
*,
102103
node_table: str = "knowledge_nodes",
103104
edge_table: str = "knowledge_edges",
@@ -111,16 +112,8 @@ def __init__(
111112
Document chunks support vector-similarity search as well as edges linking
112113
documents based on structural and semantic properties.
113114
114-
Parameters configure the ways that edges should be added between
115-
documents. Many take `Union[bool, Set[str]]`, with `False` disabling
116-
inference, `True` enabling it globally between all documents, and a set
117-
of metadata fields defining a scope in which to enable it. Specifically,
118-
passing a set of metadata fields such as `source` only links documents
119-
with the same `source` metadata value.
120-
121115
Args:
122116
embedding: The embeddings to use for the document content.
123-
edge_extractors: Edge extractors to use for linking knowledge chunks.
124117
concurrency: Maximum number of queries to have concurrently executing.
125118
apply_schema: If true, the schema will be created if necessary. If false,
126119
the schema must have already been applied.
@@ -143,17 +136,13 @@ def __init__(
143136
"Only SYNC and OFF are supported at the moment"
144137
)
145138

146-
# Ensure the edge extractor `kind`s are unique.
147-
assert len(edge_extractors) == len(set([e.kind for e in edge_extractors]))
148-
self._edge_extractors = edge_extractors
149-
150139
# TODO: Metadata
151140
# TODO: Parent ID / source ID / etc.
152141
self._insert_passage = session.prepare(
153142
f"""
154143
INSERT INTO {keyspace}.{node_table} (
155-
content_id, kind, text_content, text_embedding, tags
156-
) VALUES (?, '{Kind.passage}', ?, ?, ?)
144+
content_id, kind, text_content, text_embedding, link_to_tags, link_from_tags
145+
) VALUES (?, '{Kind.passage}', ?, ?, ?, ?)
157146
"""
158147
)
159148

@@ -229,19 +218,27 @@ def __init__(
229218
"""
230219
)
231220

232-
self._query_ids_by_tag = session.prepare(
221+
self._query_ids_by_link_to_tag = session.prepare(
233222
f"""
234223
SELECT content_id
235224
FROM {keyspace}.{node_table}
236-
WHERE tags CONTAINS ?
225+
WHERE link_to_tags CONTAINS ?
237226
"""
238227
)
239228

240-
self._query_ids_and_embedding_by_tag = session.prepare(
229+
self._query_ids_and_embedding_by_link_to_tag = session.prepare(
241230
f"""
242231
SELECT content_id, text_embedding
243232
FROM {keyspace}.{node_table}
244-
WHERE tags CONTAINS ?
233+
WHERE link_to_tags CONTAINS ?
234+
"""
235+
)
236+
237+
self._query_ids_and_embedding_by_link_from_tag = session.prepare(
238+
f"""
239+
SELECT content_id, text_embedding
240+
FROM {keyspace}.{node_table}
241+
WHERE link_from_tags CONTAINS ?
245242
"""
246243
)
247244

@@ -255,7 +252,8 @@ def _apply_schema(self):
255252
text_content TEXT,
256253
text_embedding VECTOR<FLOAT, {embedding_dim}>,
257254
258-
tags SET<TEXT>,
255+
link_to_tags SET<TEXT>,
256+
link_from_tags SET<TEXT>,
259257
260258
PRIMARY KEY (content_id)
261259
)
@@ -289,8 +287,16 @@ def _apply_schema(self):
289287
# Indexes on the link tags (link_from_tags and link_to_tags)
290288
self._session.execute(
291289
f"""
292-
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_tags_index
293-
ON {self._keyspace}.{self._node_table} (tags)
290+
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_link_from_tags_index
291+
ON {self._keyspace}.{self._node_table} (link_from_tags)
292+
USING 'StorageAttachedIndex';
293+
"""
294+
)
295+
296+
self._session.execute(
297+
f"""
298+
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_link_to_tags_index
299+
ON {self._keyspace}.{self._node_table} (link_to_tags)
294300
USING 'StorageAttachedIndex';
295301
"""
296302
)
@@ -319,6 +325,11 @@ def add_nodes(
319325
text_embeddings = self._embedding.embed_documents(texts)
320326

321327
ids = []
328+
329+
tag_to_new_sources: Dict[str, List[Tuple[str, str]]] = {}
330+
tag_to_new_targets: Dict[str, Dict[str, Tuple[str, List[float]]]] = {}
331+
332+
# Step 1: Add the nodes, collecting the tags and new sources / targets.
322333
with self._concurrent_queries() as cq:
323334
tuples = strict_zip(texts, text_embeddings, metadatas)
324335
for text, text_embedding, metadata in tuples:
@@ -327,13 +338,118 @@ def add_nodes(
327338
id = metadata[CONTENT_ID]
328339
ids.append(id)
329340

330-
tags = set()
331-
tags.update(*[e.tags(text, metadata) for e in self._edge_extractors])
341+
link_to_tags = set() # link to these tags
342+
link_from_tags = set() # link from these tags
343+
344+
for tag in get_link_tags(metadata):
345+
tag_str = f"{tag.kind}:{tag.tag}"
346+
if tag.direction == "incoming" or tag.direction == "bidir":
347+
# An incoming link should be linked *from* nodes with the given tag.
348+
link_from_tags.add(tag_str)
349+
tag_to_new_targets.setdefault(tag_str, dict())[id] = (tag.kind, text_embedding)
350+
if tag.direction == "outgoing" or tag.direction == "bidir":
351+
link_to_tags.add(tag_str)
352+
tag_to_new_sources.setdefault(tag_str, list()).append((tag.kind, id))
353+
354+
cq.execute(self._insert_passage, (id, text, text_embedding, link_to_tags, link_from_tags))
355+
356+
# Step 2: Query information about those tags to determine the edges to add.
357+
# Add edges as needed.
358+
id_set = set(ids)
359+
with self._concurrent_queries() as cq:
360+
edges = []
361+
def add_edge(source_id, target_id, kind, target_embedding):
362+
nonlocal added_edges
363+
if source_id == target_id:
364+
# Don't add self-cycles (could happen with bidirectional tags).
365+
return
366+
367+
edges.append((source_id, target_id, kind, target_embedding))
368+
369+
# TODO: Would be good to be able to execute these... but
370+
# may cause problems if we can't execute it right away
371+
# (because of a pending query) and we can't complete
372+
# the pending queries (because we can't finish the callback).
373+
374+
# cq.execute(
375+
# self._insert_edge,
376+
# (source_id, target_id, kind, target_embedding),
377+
# )
378+
379+
def add_edges_for_sources(
380+
source_rows,
381+
target_embeddings: Dict[(str, List[float])],
382+
):
383+
for source_id in source_rows:
384+
if source_id in id_set:
385+
# Source ID is new, and anything in `target_embeddings` is too.
386+
# Don't add here.
387+
continue
388+
389+
for target_id, (kind, target_emb) in target_embeddings.items():
390+
add_edge(source_id.content_id, target_id, kind, target_emb)
391+
392+
def add_edges_for_targets(
393+
sources: Iterable[Tuple[str, str]],
394+
target_rows,
395+
):
396+
for target in target_rows:
397+
if target.content_id in id_set:
398+
# Target ID is new, and anything in `sources` is too.
399+
# Don't add here (will be handled later).
400+
continue
401+
402+
for (kind, source_id) in sources:
403+
add_edge(source_id, target.content_id, kind, target.text_embedding)
404+
405+
for tag, new_target_embs in tag_to_new_targets.items():
406+
# For each new node with a `link_from_tag`, find the source
407+
# nodes with that `link_to_tag` and create the edges.
408+
cq.execute(
409+
self._query_ids_by_link_to_tag,
410+
parameters=(tag, ),
411+
callback=lambda sources, targets=new_target_embs: add_edges_for_sources(
412+
sources, targets)
413+
)
414+
415+
for tag, new_sources in tag_to_new_sources.items():
416+
# For each new node with a `link_to_tag`, find the target
417+
# nodes with that `link_from_tag` tag and create the edges.
418+
cq.execute(
419+
self._query_ids_and_embedding_by_link_from_tag,
420+
parameters=(tag, ),
421+
callback=lambda targets, sources=new_sources: add_edges_for_targets(
422+
sources, targets)
423+
)
424+
425+
# Step 3: Add edges.
426+
# TODO: Combine steps, ideally to a single set of concurrent queries.
427+
# This should be possible, but will require some form of queueing, since
428+
# we need to be able to handle a result set, and that may require us to queue
429+
# more than |max concurrency| edges.
430+
added_edges = 0
431+
with self._concurrent_queries() as cq:
432+
print("Adding edges")
433+
# Add edges from query results (should be one new node and one old node)
434+
for edge in edges:
435+
added_edges += 1
436+
cq.execute(self._insert_edge, edge)
437+
438+
# Add edges for new nodes
439+
for tag, new_sources in tag_to_new_sources.items():
440+
for (kind, source_id) in new_sources:
441+
new_targets = tag_to_new_targets.get(tag, None)
442+
if new_targets is None:
443+
continue
332444

333-
cq.execute(self._insert_passage, (id, text, text_embedding, tags))
445+
for (target_id, (target_kind, target_embedding)) in new_targets.items():
446+
# TODO: Improve the structures so this can be a lookup?
447+
if target_kind == kind and source_id != target_id:
448+
added_edges += 1
449+
cq.execute(self._insert_edge,
450+
(source_id, target_id, kind, target_embedding))
334451

335-
for extractor in self._edge_extractors:
336-
extractor.extract_edges(self, texts, text_embeddings, metadatas)
452+
print(f"Added {added_edges} edges")
337453

338454
return ids
339455

0 commit comments

Comments
 (0)