1
1
import secrets
2
- from typing import Any , Dict , Iterable , List , Optional , Set , Union
2
+ from typing import Any , Dict , Iterable , List , NamedTuple , Optional , Sequence , Union
3
3
4
4
from cassandra .cluster import ResponseFuture , Session
5
5
from cassio .config import check_resolve_keyspace , check_resolve_session
8
8
from langchain_core .runnables import Runnable , RunnableLambda
9
9
from langchain_core .vectorstores import VectorStore
10
10
11
- from .content import Kind
12
11
from .concurrency import ConcurrentQueries
12
+ from .content import Kind
13
13
14
14
CONTENT_ID = "content_id"
15
15
PARENT_CONTENT_ID = "parent_content_id"
@@ -242,6 +242,9 @@ def embeddings(self) -> Optional[Embeddings]:
242
242
"""Access the query embedding object if available."""
243
243
return self ._embedding
244
244
245
def _concurrent_queries(self) -> ConcurrentQueries:
    """Create a `ConcurrentQueries` helper bound to this store's session.

    Centralizes construction so every call site uses the same session and
    configured concurrency limit.
    """
    session = self._session
    concurrency = self._concurrency
    return ConcurrentQueries(session, concurrency=concurrency)
245
248
# TODO: async
246
249
def add_texts (
247
250
self ,
@@ -269,21 +272,24 @@ def add_texts(
269
272
keywords_in_texts = {k for md in metadatas for k in md .get (KEYWORDS , {})}
270
273
keywords_to_ids = {}
271
274
if self ._infer_keywords :
272
- with ConcurrentQueries (self ._session , concurrency = self ._concurrency ) as cq :
275
+ with self ._concurrent_queries () as cq :
276
+
273
277
def handle_keywords (rows , k ):
274
278
related = set (_results_to_ids (rows ))
275
279
keywords_to_ids [k ] = related
276
280
277
281
for k in keywords_in_texts :
278
- cq .execute (self ._query_ids_by_keyword ,
279
- parameters = (k ,),
280
- callback = lambda rows , k1 = k : handle_keywords (rows , k1 ))
282
+ cq .execute (
283
+ self ._query_ids_by_keyword ,
284
+ parameters = (k ,),
285
+ callback = lambda rows , k1 = k : handle_keywords (rows , k1 ),
286
+ )
281
287
282
288
new_hrefs_to_ids = {}
283
289
new_urls_to_ids = {}
284
290
285
291
ids = []
286
- with ConcurrentQueries ( self ._session , concurrency = self . _concurrency ) as cq :
292
+ with self ._concurrent_queries ( ) as cq :
287
293
tuples = zip (texts , text_embeddings , metadatas , strict = True )
288
294
for text , text_embedding , metadata in tuples :
289
295
id = metadata .get (CONTENT_ID ) or secrets .token_hex (8 )
@@ -297,7 +303,9 @@ def handle_keywords(rows, k):
297
303
for href in hrefs :
298
304
new_hrefs_to_ids .setdefault (href , set ()).add (id )
299
305
300
- cq .execute (self ._insert_passage , (id , text , text_embedding , keywords , urls , hrefs ))
306
+ cq .execute (
307
+ self ._insert_passage , (id , text , text_embedding , keywords , urls , hrefs )
308
+ )
301
309
302
310
if (parent_content_id := metadata .get (PARENT_CONTENT_ID )) is not None :
303
311
cq .execute (self ._insert_edge , (id , str (parent_content_id )))
@@ -319,7 +327,8 @@ def handle_keywords(rows, k):
319
327
320
328
href_url_pairs = set ()
321
329
322
- with ConcurrentQueries (self ._session , concurrency = self ._concurrency ) as cq :
330
+ with self ._concurrent_queries () as cq :
331
+
323
332
def add_href_url_pairs (href_ids , url_ids ):
324
333
for href_id in href_ids :
325
334
if not isinstance (href_id , str ):
@@ -331,19 +340,23 @@ def add_href_url_pairs(href_ids, url_ids):
331
340
href_url_pairs .add ((href_id , url_id ))
332
341
333
342
for href , href_ids in new_hrefs_to_ids .items ():
334
- cq .execute (self ._query_ids_by_url ,
335
- parameters = (href , ),
336
- # Weird syntax ensures we capture each `href_ids` instead of the final value.
337
- callback = lambda urls , hrefs = href_ids : add_href_url_pairs (hrefs , urls ))
343
+ cq .execute (
344
+ self ._query_ids_by_url ,
345
+ parameters = (href ,),
346
+ # Weird syntax to capture each `href_ids` instead of the last iteration.
347
+ callback = lambda urls , hrefs = href_ids : add_href_url_pairs (hrefs , urls ),
348
+ )
338
349
339
350
for url , url_ids in new_urls_to_ids .items ():
340
- cq .execute (self ._query_ids_by_href ,
341
- parameters = (url , ),
342
- # Weird syntax ensures we capture each `url_ids` instead of the final value.
343
- callback = lambda hrefs , urls = url_ids : add_href_url_pairs (hrefs , urls ))
344
-
345
- with ConcurrentQueries (self ._session , concurrency = self ._concurrency ) as cq :
346
- for (href , url ) in href_url_pairs :
351
+ cq .execute (
352
+ self ._query_ids_by_href ,
353
+ parameters = (url ,),
354
+ # Weird syntax to capture each `url_ids` instead of the last iteration.
355
+ callback = lambda hrefs , urls = url_ids : add_href_url_pairs (hrefs , urls ),
356
+ )
357
+
358
+ with self ._concurrent_queries () as cq :
359
+ for href , url in href_url_pairs :
347
360
cq .execute (self ._insert_edge , (href , url ))
348
361
print (f"Added { len (href_url_pairs )} edges based on HREFs/URLs" )
349
362
@@ -409,27 +422,23 @@ def similarity_search_by_vector(
409
422
results = self ._session .execute (self ._query_by_embedding , (query_vector , k ))
410
423
return _results_to_documents (results )
411
424
412
- def _similarity_search_ids (
413
- self ,
414
- query : str ,
415
- * ,
416
- k : int = 4 ,
417
- ) -> Iterable [str ]:
418
- "Return content IDs of documents by similarity to `query`."
419
- query_vector = self ._embedding .embed_query (query )
420
- results = self ._session .execute (self ._query_ids_by_embedding , (query_vector , k ))
421
- return _results_to_ids (results )
422
-
423
425
def _query_by_ids (
424
426
self ,
425
427
ids : Iterable [str ],
426
428
) -> Iterable [Document ]:
427
- # TODO: Concurrency.
428
- return [
429
- _row_to_document (row )
430
- for id in ids
431
- for row in self ._session .execute (self ._query_by_id , (id ,))
432
- ]
429
+ results = []
430
+ with self ._concurrent_queries () as cq :
431
+ for id in ids :
432
+
433
+ def add_documents (rows ):
434
+ results .extend (_results_to_documents (rows ))
435
+
436
+ cq .execute (
437
+ self ._query_by_id ,
438
+ parameters = (id ,),
439
+ callback = lambda rows : add_documents (rows ),
440
+ )
441
+ return results
433
442
434
443
def _linked_ids (
435
444
self ,
@@ -456,25 +465,36 @@ def retrieve(
456
465
Collection of retrieved documents.
457
466
"""
458
467
if isinstance (query , str ):
459
- query = [query ]
460
-
461
- start_ids = {
462
- content_id for q in query for content_id in self ._similarity_search_ids (q , k = k )
463
- }
464
-
465
- result_ids = start_ids
466
- source_ids = start_ids
467
- for _ in range (0 , depth ):
468
- # TODO: Concurrency
469
- level_ids = {
470
- content_id
471
- for source_id in source_ids
472
- for content_id in self ._linked_ids (source_id )
473
- }
474
- result_ids .update (level_ids )
475
- source_ids = level_ids
476
-
477
- return self ._query_by_ids (result_ids )
468
+ query = {query }
469
+ else :
470
+ query = set (query )
471
+
472
+ with self ._concurrent_queries () as cq :
473
+ visited = {}
474
+
475
+ def visit (d : int , nodes : Sequence [NamedTuple ]):
476
+ nonlocal visited
477
+ for node in nodes :
478
+ content_id = node .content_id
479
+ if d <= visited .get (content_id , depth ):
480
+ visited [content_id ] = d
481
+ # We discovered this for the first time, or at a shorter depth.
482
+ if d + 1 <= depth :
483
+ cq .execute (
484
+ self ._query_linked_ids ,
485
+ parameters = (content_id ,),
486
+ callback = lambda nodes , d = d : visit (d + 1 , nodes ),
487
+ )
488
+
489
+ for q in query :
490
+ query_embedding = self ._embedding .embed_query (q )
491
+ cq .execute (
492
+ self ._query_ids_by_embedding ,
493
+ parameters = (query_embedding , k ),
494
+ callback = lambda nodes : visit (0 , nodes ),
495
+ )
496
+
497
+ return self ._query_by_ids (visited .keys ())
478
498
479
499
def as_retriever (
480
500
self ,
0 commit comments