Skip to content

Commit ac3845d

Browse files
authored
feat: Concurrent traversal (#447)
This should speed up traversals by issuing the queries for nodes, edges, and IDs concurrently.
1 parent 02b21cb commit ac3845d

File tree

5 files changed

+107
-80
lines changed

5 files changed

+107
-80
lines changed

libs/knowledge-store/ragstack_knowledge_store/concurrency.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ def __init__(self, session: Session, *, concurrency: int = 20) -> None:
1919

2020
self._error = None
2121

22-
def _handle_result(self,
23-
result: Sequence[NamedTuple],
24-
future: ResponseFuture,
25-
callback: Optional[Callable[[Sequence[NamedTuple]], Any]]):
22+
def _handle_result(
23+
self,
24+
result: Sequence[NamedTuple],
25+
future: ResponseFuture,
26+
callback: Optional[Callable[[Sequence[NamedTuple]], Any]],
27+
):
2628
if callback is not None:
2729
callback(result)
2830

@@ -40,22 +42,27 @@ def _handle_error(self, error):
4042
self._error = error
4143
self._completion.notify()
4244

43-
def execute(self,
44-
query: PreparedStatement,
45-
parameters: Optional[Tuple] = None,
46-
callback: Optional[str] = None):
45+
def execute(
46+
self,
47+
query: PreparedStatement,
48+
parameters: Optional[Tuple] = None,
49+
callback: Optional[Callable[[Sequence[NamedTuple]], Any]] = None,
50+
):
4751
with self._completion:
4852
self._pending += 1
4953
if self._error is not None:
5054
return
5155

5256
self._semaphore.acquire()
5357
future: ResponseFuture = self._session.execute_async(query, parameters)
54-
future.add_callbacks(self._handle_result, self._handle_error,
55-
callback_kwargs={
56-
"future": future,
57-
"callback": callback,
58-
})
58+
future.add_callbacks(
59+
self._handle_result,
60+
self._handle_error,
61+
callback_kwargs={
62+
"future": future,
63+
"callback": callback,
64+
},
65+
)
5966

6067
def __enter__(self) -> "ConcurrentQueries":
6168
return super().__enter__()

libs/knowledge-store/ragstack_knowledge_store/content.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from enum import Enum
22

3+
34
class Kind(str, Enum):
45
document = "document"
56
"""A root document (PDF, HTML, etc.).
@@ -21,4 +22,4 @@ class Kind(str, Enum):
2122
"""An image within a document."""
2223

2324
table = "table"
24-
"""A table within a document."""
25+
"""A table within a document."""

libs/knowledge-store/ragstack_knowledge_store/knowledge_store.py

Lines changed: 76 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import secrets
2-
from typing import Any, Dict, Iterable, List, Optional, Set, Union
2+
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Union
33

44
from cassandra.cluster import ResponseFuture, Session
55
from cassio.config import check_resolve_keyspace, check_resolve_session
@@ -8,8 +8,8 @@
88
from langchain_core.runnables import Runnable, RunnableLambda
99
from langchain_core.vectorstores import VectorStore
1010

11-
from .content import Kind
1211
from .concurrency import ConcurrentQueries
12+
from .content import Kind
1313

1414
CONTENT_ID = "content_id"
1515
PARENT_CONTENT_ID = "parent_content_id"
@@ -242,6 +242,9 @@ def embeddings(self) -> Optional[Embeddings]:
242242
"""Access the query embedding object if available."""
243243
return self._embedding
244244

245+
def _concurrent_queries(self) -> ConcurrentQueries:
246+
return ConcurrentQueries(self._session, concurrency=self._concurrency)
247+
245248
# TODO: async
246249
def add_texts(
247250
self,
@@ -269,21 +272,24 @@ def add_texts(
269272
keywords_in_texts = {k for md in metadatas for k in md.get(KEYWORDS, {})}
270273
keywords_to_ids = {}
271274
if self._infer_keywords:
272-
with ConcurrentQueries(self._session, concurrency=self._concurrency) as cq:
275+
with self._concurrent_queries() as cq:
276+
273277
def handle_keywords(rows, k):
274278
related = set(_results_to_ids(rows))
275279
keywords_to_ids[k] = related
276280

277281
for k in keywords_in_texts:
278-
cq.execute(self._query_ids_by_keyword,
279-
parameters = (k,),
280-
callback = lambda rows, k1=k: handle_keywords(rows, k1))
282+
cq.execute(
283+
self._query_ids_by_keyword,
284+
parameters=(k,),
285+
callback=lambda rows, k1=k: handle_keywords(rows, k1),
286+
)
281287

282288
new_hrefs_to_ids = {}
283289
new_urls_to_ids = {}
284290

285291
ids = []
286-
with ConcurrentQueries(self._session, concurrency=self._concurrency) as cq:
292+
with self._concurrent_queries() as cq:
287293
tuples = zip(texts, text_embeddings, metadatas, strict=True)
288294
for text, text_embedding, metadata in tuples:
289295
id = metadata.get(CONTENT_ID) or secrets.token_hex(8)
@@ -297,7 +303,9 @@ def handle_keywords(rows, k):
297303
for href in hrefs:
298304
new_hrefs_to_ids.setdefault(href, set()).add(id)
299305

300-
cq.execute(self._insert_passage, (id, text, text_embedding, keywords, urls, hrefs))
306+
cq.execute(
307+
self._insert_passage, (id, text, text_embedding, keywords, urls, hrefs)
308+
)
301309

302310
if (parent_content_id := metadata.get(PARENT_CONTENT_ID)) is not None:
303311
cq.execute(self._insert_edge, (id, str(parent_content_id)))
@@ -319,7 +327,8 @@ def handle_keywords(rows, k):
319327

320328
href_url_pairs = set()
321329

322-
with ConcurrentQueries(self._session, concurrency=self._concurrency) as cq:
330+
with self._concurrent_queries() as cq:
331+
323332
def add_href_url_pairs(href_ids, url_ids):
324333
for href_id in href_ids:
325334
if not isinstance(href_id, str):
@@ -331,19 +340,23 @@ def add_href_url_pairs(href_ids, url_ids):
331340
href_url_pairs.add((href_id, url_id))
332341

333342
for href, href_ids in new_hrefs_to_ids.items():
334-
cq.execute(self._query_ids_by_url,
335-
parameters=(href, ),
336-
# Weird syntax ensures we capture each `href_ids` instead of the final value.
337-
callback=lambda urls, hrefs=href_ids: add_href_url_pairs(hrefs, urls))
343+
cq.execute(
344+
self._query_ids_by_url,
345+
parameters=(href,),
346+
# Weird syntax to capture each `href_ids` instead of the last iteration.
347+
callback=lambda urls, hrefs=href_ids: add_href_url_pairs(hrefs, urls),
348+
)
338349

339350
for url, url_ids in new_urls_to_ids.items():
340-
cq.execute(self._query_ids_by_href,
341-
parameters=(url, ),
342-
# Weird syntax ensures we capture each `url_ids` instead of the final value.
343-
callback=lambda hrefs, urls=url_ids: add_href_url_pairs(hrefs, urls))
344-
345-
with ConcurrentQueries(self._session, concurrency=self._concurrency) as cq:
346-
for (href, url) in href_url_pairs:
351+
cq.execute(
352+
self._query_ids_by_href,
353+
parameters=(url,),
354+
# Weird syntax to capture each `url_ids` instead of the last iteration.
355+
callback=lambda hrefs, urls=url_ids: add_href_url_pairs(hrefs, urls),
356+
)
357+
358+
with self._concurrent_queries() as cq:
359+
for href, url in href_url_pairs:
347360
cq.execute(self._insert_edge, (href, url))
348361
print(f"Added {len(href_url_pairs)} edges based on HREFs/URLs")
349362

@@ -409,27 +422,23 @@ def similarity_search_by_vector(
409422
results = self._session.execute(self._query_by_embedding, (query_vector, k))
410423
return _results_to_documents(results)
411424

412-
def _similarity_search_ids(
413-
self,
414-
query: str,
415-
*,
416-
k: int = 4,
417-
) -> Iterable[str]:
418-
"Return content IDs of documents by similarity to `query`."
419-
query_vector = self._embedding.embed_query(query)
420-
results = self._session.execute(self._query_ids_by_embedding, (query_vector, k))
421-
return _results_to_ids(results)
422-
423425
def _query_by_ids(
424426
self,
425427
ids: Iterable[str],
426428
) -> Iterable[Document]:
427-
# TODO: Concurrency.
428-
return [
429-
_row_to_document(row)
430-
for id in ids
431-
for row in self._session.execute(self._query_by_id, (id,))
432-
]
429+
results = []
430+
with self._concurrent_queries() as cq:
431+
for id in ids:
432+
433+
def add_documents(rows):
434+
results.extend(_results_to_documents(rows))
435+
436+
cq.execute(
437+
self._query_by_id,
438+
parameters=(id,),
439+
callback=lambda rows: add_documents(rows),
440+
)
441+
return results
433442

434443
def _linked_ids(
435444
self,
@@ -456,25 +465,36 @@ def retrieve(
456465
Collection of retrieved documents.
457466
"""
458467
if isinstance(query, str):
459-
query = [query]
460-
461-
start_ids = {
462-
content_id for q in query for content_id in self._similarity_search_ids(q, k=k)
463-
}
464-
465-
result_ids = start_ids
466-
source_ids = start_ids
467-
for _ in range(0, depth):
468-
# TODO: Concurrency
469-
level_ids = {
470-
content_id
471-
for source_id in source_ids
472-
for content_id in self._linked_ids(source_id)
473-
}
474-
result_ids.update(level_ids)
475-
source_ids = level_ids
476-
477-
return self._query_by_ids(result_ids)
468+
query = {query}
469+
else:
470+
query = set(query)
471+
472+
with self._concurrent_queries() as cq:
473+
visited = {}
474+
475+
def visit(d: int, nodes: Sequence[NamedTuple]):
476+
nonlocal visited
477+
for node in nodes:
478+
content_id = node.content_id
479+
if d <= visited.get(content_id, depth):
480+
visited[content_id] = d
481+
# We discovered this for the first time, or at a shorter depth.
482+
if d + 1 <= depth:
483+
cq.execute(
484+
self._query_linked_ids,
485+
parameters=(content_id,),
486+
callback=lambda nodes, d=d: visit(d + 1, nodes),
487+
)
488+
489+
for q in query:
490+
query_embedding = self._embedding.embed_query(q)
491+
cq.execute(
492+
self._query_ids_by_embedding,
493+
parameters=(query_embedding, k),
494+
callback=lambda nodes: visit(0, nodes),
495+
)
496+
497+
return self._query_by_ids(visited.keys())
478498

479499
def as_retriever(
480500
self,

libs/knowledge-store/tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33

44
import pytest
55
from cassandra.cluster import Cluster, Session
6+
from dotenv import load_dotenv
67
from langchain_core.documents import Document
78
from langchain_core.embeddings import Embeddings
89
from testcontainers.core.container import DockerContainer
910
from testcontainers.core.waiting_utils import wait_for_logs
10-
from dotenv import load_dotenv
1111

1212
from ragstack_knowledge_store.knowledge_store import KnowledgeStore
1313

1414
load_dotenv()
1515

16+
1617
@pytest.fixture(scope="session")
1718
def db_keyspace() -> str:
1819
return "default_keyspace"

libs/knowledge-store/tests/test_knowledge_store.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,34 @@
11
from langchain_core.documents import Document
2+
from precisely import assert_that, contains_exactly
3+
24
from .conftest import DataFixture
35

4-
from precisely import assert_that, contains_exactly
56

67
def test_write_retrieve_href_url_pair(fresh_fixture: DataFixture):
78
a = Document(
89
page_content="A",
910
metadata={
1011
"content_id": "a",
1112
"urls": ["http://a"],
12-
}
13+
},
1314
)
1415
b = Document(
1516
page_content="B",
1617
metadata={
1718
"content_id": "b",
1819
"hrefs": ["http://a"],
1920
"urls": ["http://b"],
20-
}
21+
},
2122
)
2223
c = Document(
2324
page_content="C",
2425
metadata={
2526
"content_id": "c",
2627
"hrefs": ["http://a"],
27-
}
28+
},
2829
)
2930
d = Document(
30-
page_content="D",
31-
metadata={
32-
"content_id": "d",
33-
"hrefs": ["http://a", "http://b"]
34-
}
31+
page_content="D", metadata={"content_id": "d", "hrefs": ["http://a", "http://b"]}
3532
)
3633

3734
store = fresh_fixture.store([a, b, c, d])
@@ -41,6 +38,7 @@ def test_write_retrieve_href_url_pair(fresh_fixture: DataFixture):
4138
assert_that(store._linked_ids("c"), contains_exactly("a"))
4239
assert_that(store._linked_ids("d"), contains_exactly("a", "b"))
4340

41+
4442
def test_write_retrieve_keywords(fresh_fixture: DataFixture):
4543
greetings = Document(
4644
page_content="Typical Greetings",

0 commit comments

Comments (0)