2
2
from dataclasses import dataclass
3
3
from typing import (
4
4
Any ,
5
+ Dict ,
5
6
Iterable ,
6
7
List ,
7
8
NamedTuple ,
8
9
Optional ,
9
10
Sequence ,
11
+ Tuple ,
10
12
Type ,
11
13
)
12
14
18
20
from langchain_core .documents import Document
19
21
from langchain_core .embeddings import Embeddings
20
22
21
- from ragstack_knowledge_store .edge_extractor import EdgeExtractor
23
+ from ragstack_knowledge_store .edge_extractor import get_link_tags
22
24
23
25
from ._utils import strict_zip
24
26
from .base import KnowledgeStore , Node , TextNode
@@ -97,7 +99,6 @@ class CassandraKnowledgeStore(KnowledgeStore):
97
99
def __init__ (
98
100
self ,
99
101
embedding : Embeddings ,
100
- edge_extractors : List [EdgeExtractor ],
101
102
* ,
102
103
node_table : str = "knowledge_nodes" ,
103
104
edge_table : str = "knowledge_edges" ,
@@ -111,16 +112,8 @@ def __init__(
111
112
Document chunks support vector-similarity search as well as edges linking
112
113
documents based on structural and semantic properties.
113
114
114
- Parameters configure the ways that edges should be added between
115
- documents. Many take `Union[bool, Set[str]]`, with `False` disabling
116
- inference, `True` enabling it globally between all documents, and a set
117
- of metadata fields defining a scope in which to enable it. Specifically,
118
- passing a set of metadata fields such as `source` only links documents
119
- with the same `source` metadata value.
120
-
121
115
Args:
122
116
embedding: The embeddings to use for the document content.
123
- edge_extractors: Edge extractors to use for linking knowledge chunks.
124
117
concurrency: Maximum number of queries to have concurrently executing.
125
118
apply_schema: If true, the schema will be created if necessary. If false,
126
119
the schema must have already been applied.
@@ -143,17 +136,13 @@ def __init__(
143
136
"Only SYNC and OFF are supported at the moment"
144
137
)
145
138
146
- # Ensure the edge extractor `kind`s are unique.
147
- assert len (edge_extractors ) == len (set ([e .kind for e in edge_extractors ]))
148
- self ._edge_extractors = edge_extractors
149
-
150
139
# TODO: Metadata
151
140
# TODO: Parent ID / source ID / etc.
152
141
self ._insert_passage = session .prepare (
153
142
f"""
154
143
INSERT INTO { keyspace } .{ node_table } (
155
- content_id, kind, text_content, text_embedding, tags
156
- ) VALUES (?, '{ Kind .passage } ', ?, ?, ?)
144
+ content_id, kind, text_content, text_embedding, link_to_tags, link_from_tags
145
+ ) VALUES (?, '{ Kind .passage } ', ?, ?, ?, ? )
157
146
"""
158
147
)
159
148
@@ -229,19 +218,27 @@ def __init__(
229
218
"""
230
219
)
231
220
232
- self ._query_ids_by_tag = session .prepare (
221
+ self ._query_ids_by_link_to_tag = session .prepare (
233
222
f"""
234
223
SELECT content_id
235
224
FROM { keyspace } .{ node_table }
236
- WHERE tags CONTAINS ?
225
+ WHERE link_to_tags CONTAINS ?
237
226
"""
238
227
)
239
228
240
- self ._query_ids_and_embedding_by_tag = session .prepare (
229
+ self ._query_ids_and_embedding_by_link_to_tag = session .prepare (
241
230
f"""
242
231
SELECT content_id, text_embedding
243
232
FROM { keyspace } .{ node_table }
244
- WHERE tags CONTAINS ?
233
+ WHERE link_to_tags CONTAINS ?
234
+ """
235
+ )
236
+
237
+ self ._query_ids_and_embedding_by_link_from_tag = session .prepare (
238
+ f"""
239
+ SELECT content_id, text_embedding
240
+ FROM { keyspace } .{ node_table }
241
+ WHERE link_from_tags CONTAINS ?
245
242
"""
246
243
)
247
244
@@ -255,7 +252,8 @@ def _apply_schema(self):
255
252
text_content TEXT,
256
253
text_embedding VECTOR<FLOAT, { embedding_dim } >,
257
254
258
- tags SET<TEXT>,
255
+ link_to_tags SET<TEXT>,
256
+ link_from_tags SET<TEXT>,
259
257
260
258
PRIMARY KEY (content_id)
261
259
)
@@ -289,8 +287,16 @@ def _apply_schema(self):
289
287
# Index on tags
290
288
self ._session .execute (
291
289
f"""
292
- CREATE CUSTOM INDEX IF NOT EXISTS { self ._node_table } _tags_index
293
- ON { self ._keyspace } .{ self ._node_table } (tags)
290
+ CREATE CUSTOM INDEX IF NOT EXISTS { self ._node_table } _link_from_tags_index
291
+ ON { self ._keyspace } .{ self ._node_table } (link_from_tags)
292
+ USING 'StorageAttachedIndex';
293
+ """
294
+ )
295
+
296
+ self ._session .execute (
297
+ f"""
298
+ CREATE CUSTOM INDEX IF NOT EXISTS { self ._node_table } _link_to_tags_index
299
+ ON { self ._keyspace } .{ self ._node_table } (link_to_tags)
294
300
USING 'StorageAttachedIndex';
295
301
"""
296
302
)
@@ -319,6 +325,11 @@ def add_nodes(
319
325
text_embeddings = self ._embedding .embed_documents (texts )
320
326
321
327
ids = []
328
+
329
+ tag_to_new_sources : Dict [str , List [Tuple [str , str ]]] = {}
330
+ tag_to_new_targets : Dict [str , Dict [str , Tuple [str , List [float ]]]] = {}
331
+
332
+ # Step 1: Add the nodes, collecting the tags and new sources / targets.
322
333
with self ._concurrent_queries () as cq :
323
334
tuples = strict_zip (texts , text_embeddings , metadatas )
324
335
for text , text_embedding , metadata in tuples :
@@ -327,13 +338,118 @@ def add_nodes(
327
338
id = metadata [CONTENT_ID ]
328
339
ids .append (id )
329
340
330
- tags = set ()
331
- tags .update (* [e .tags (text , metadata ) for e in self ._edge_extractors ])
341
+ link_to_tags = set () # link to these tags
342
+ link_from_tags = set () # link from these tags
343
+
344
+ for tag in get_link_tags (metadata ):
345
+ tag_str = f"{ tag .kind } :{ tag .tag } "
346
+ if tag .direction == "incoming" or tag .direction == "bidir" :
347
+ # An incoming link should be linked *from* nodes with the given tag.
348
+ link_from_tags .add (tag_str )
349
+ tag_to_new_targets .setdefault (tag_str , dict ())[id ] = (tag .kind , text_embedding )
350
+ if tag .direction == "outgoing" or tag .direction == "bidir" :
351
+ link_to_tags .add (tag_str )
352
+ tag_to_new_sources .setdefault (tag_str , list ()).append ((tag .kind , id ))
353
+
354
+ cq .execute (self ._insert_passage , (id , text , text_embedding , link_to_tags , link_from_tags ))
355
+
356
+ # Step 2: Query information about those tags to determine the edges to add.
357
+ # Add edges as needed.
358
+ id_set = set (ids )
359
+ with self ._concurrent_queries () as cq :
360
+ edges = []
361
+ def add_edge (source_id , target_id , kind , target_embedding ):
362
+ nonlocal added_edges
363
+ if source_id == target_id :
364
+ # Don't add self-cycles (could happen with bidirectional tags).
365
+ return
366
+
367
+ edges .append ((source_id , target_id , kind , target_embedding ))
368
+
369
+ # TODO: Would be good to be able to execute these... but
370
+ # may cause problems if we can't execute it right away
371
+ # (because of a pending query) and we can't complete
372
+ # the pending queries (because we can't finish the callback).
373
+
374
+ # cq.execute(
375
+ # self._insert_edge,
376
+ # (source_id, target_id, kind, target_embedding),
377
+ # )
378
+
379
+ def add_edges_for_sources (
380
+ source_rows ,
381
+ target_embeddings : Dict [(str , List [float ])],
382
+ ):
383
+ for source_id in source_rows :
384
+ if source_id in id_set :
385
+ # Source ID is new, and anything in `target_embeddings` is too.
386
+ # Don't add here.
387
+ continue
388
+
389
+ for target_id , (kind , target_emb ) in target_embeddings .items ():
390
+ add_edge (source_id .content_id , target_id , kind , target_emb )
391
+
392
+ def add_edges_for_targets (
393
+ sources : Iterable [Tuple [str , str ]],
394
+ target_rows ,
395
+ ):
396
+ for target in target_rows :
397
+ if target .content_id in id_set :
398
+ # Target ID is new, and anything in `sources` is too.
399
+ # Don't add here (will be handled later).
400
+ continue
401
+
402
+ for (kind , source_id ) in sources :
403
+ add_edge (source_id , target .content_id , kind , target .text_embedding )
404
+
405
+ for tag , new_target_embs in tag_to_new_targets .items ():
406
+ # For each new node with a `link_from_tag`, find the source
407
+ # nodes with that `link_to_tag` and create the edges.
408
+ cq .execute (
409
+ self ._query_ids_by_link_to_tag ,
410
+ parameters = (tag , ),
411
+ callback = lambda sources , targets = new_target_embs : add_edges_for_sources (
412
+ sources , targets )
413
+ )
414
+
415
+ for tag , new_sources in tag_to_new_sources .items ():
416
+ # For each new node with a `link_to_tag`, find the target
417
+ # nodes with that `link_from_tag` tag and create the edges.
418
+ cq .execute (
419
+ self ._query_ids_and_embedding_by_link_from_tag ,
420
+ parameters = (tag , ),
421
+ callback = lambda targets , sources = new_sources : add_edges_for_targets (
422
+ sources , targets )
423
+ )
424
+
425
+ # Step 3: Add edges.
426
+ # TODO: Combine steps, ideally to a single set of concurrent queries.
427
+ # This should be possible, but will require some form of queueing, since
428
+ # we need to be able to handle a result set, and that may require us to queue
429
+ # more than |max concurrency| edges.
430
+ added_edges = 0
431
+ with self ._concurrent_queries () as cq :
432
+ print ("Adding edges" )
433
+ # Add edges from query results (should be one new node and one old node)
434
+ for edge in edges :
435
+ added_edges += 1
436
+ cq .execute (self ._insert_edge , edge )
437
+
438
+ # Add edges for new nodes
439
+ for tag , new_sources in tag_to_new_sources .items ():
440
+ for (kind , source_id ) in new_sources :
441
+ new_targets = tag_to_new_targets .get (tag , None )
442
+ if new_targets is None :
443
+ continue
332
444
333
- cq .execute (self ._insert_passage , (id , text , text_embedding , tags ))
445
+ for (target_id , (target_kind , target_embedding )) in new_targets .items ():
446
+ # TODO: Improve the structures so this can be a lookup?
447
+ if target_kind == kind and source_id != target_id :
448
+ added_edges += 1
449
+ cq .execute (self ._insert_edge ,
450
+ (source_id , target_id , kind , target_embedding ))
334
451
335
- for extractor in self ._edge_extractors :
336
- extractor .extract_edges (self , texts , text_embeddings , metadatas )
452
+ print (f"Added { added_edges } edges" )
337
453
338
454
return ids
339
455
0 commit comments