Skip to content

Commit 0cbbac4

Browse files
authored
Cleanup knowledge store code (#475)
1 parent 259b101 commit 0cbbac4

File tree

8 files changed

+160
-93
lines changed

8 files changed

+160
-93
lines changed

libs/knowledge-store/pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ packages = [{ include = "ragstack_knowledge_store" }]
1313
python = ">=3.10,<3.13"
1414
langchain-core = "^0.2"
1515
cassio = "^0.1.7"
16-
asyncstdlib = "^3.12.3"
1716

1817
[tool.poetry.group.dev.dependencies]
1918
ruff = "*"

libs/knowledge-store/ragstack_knowledge_store/_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ def batched(iterable: Iterable[T], n: int) -> Iterator[Iterator[T]]:
1919
while batch := tuple(islice(it, n)):
2020
yield batch
2121

22+
2223
# TODO: Remove the "polyfill" when we required python is >= 3.10.
2324

2425
if sys.version_info >= (3, 10):
2526

2627
def strict_zip(*iterables):
2728
return zip(*iterables, strict=True)
29+
2830
else:
2931

3032
def strict_zip(*iterables):

libs/knowledge-store/ragstack_knowledge_store/base.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def mmr_traversal_search(
223223
depth: int = 2,
224224
fetch_k: int = 100,
225225
lambda_mult: float = 0.5,
226-
score_threshold: float = float('-inf'),
226+
score_threshold: float = float("-inf"),
227227
**kwargs: Any,
228228
) -> Iterable[Document]:
229229
"""Retrieve documents from this knowledge store using MMR-traversal.
@@ -258,7 +258,7 @@ async def ammr_traversal_search(
258258
depth: int = 2,
259259
fetch_k: int = 100,
260260
lambda_mult: float = 0.5,
261-
score_threshold: float = float('-inf'),
261+
score_threshold: float = float("-inf"),
262262
**kwargs: Any,
263263
) -> AsyncIterable[Document]:
264264
"""Retrieve documents from this knowledge store using MMR-traversal.
@@ -297,17 +297,23 @@ async def ammr_traversal_search(
297297
):
298298
yield doc
299299

300-
def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
300+
def similarity_search(
301+
self, query: str, k: int = 4, **kwargs: Any
302+
) -> List[Document]:
301303
return list(self.traversal_search(query, k=k, depth=0))
302304

303-
async def asimilarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
305+
async def asimilarity_search(
306+
self, query: str, k: int = 4, **kwargs: Any
307+
) -> List[Document]:
304308
return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]
305309

306310
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
307311
if search_type == "similarity":
308312
return self.similarity_search(query, **kwargs)
309313
elif search_type == "similarity_score_threshold":
310-
docs_and_similarities = self.similarity_search_with_relevance_scores(query, **kwargs)
314+
docs_and_similarities = self.similarity_search_with_relevance_scores(
315+
query, **kwargs
316+
)
311317
return [doc for doc, _ in docs_and_similarities]
312318
elif search_type == "mmr":
313319
return self.max_marginal_relevance_search(query, **kwargs)
@@ -322,7 +328,9 @@ def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
322328
"'mmr' or 'traversal'."
323329
)
324330

325-
async def asearch(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
331+
async def asearch(
332+
self, query: str, search_type: str, **kwargs: Any
333+
) -> List[Document]:
326334
if search_type == "similarity":
327335
return await self.asimilarity_search(query, **kwargs)
328336
elif search_type == "similarity_score_threshold":
@@ -420,7 +428,9 @@ def _get_relevant_documents(
420428
if self.search_type == "traversal":
421429
return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
422430
elif self.search_type == "mmr_traversal":
423-
return list(self.vectorstore.mmr_traversal_search(query, **self.search_kwargs))
431+
return list(
432+
self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
433+
)
424434
else:
425435
return super()._get_relevant_documents(query, run_manager=run_manager)
426436

@@ -430,12 +440,18 @@ async def _aget_relevant_documents(
430440
if self.search_type == "traversal":
431441
return [
432442
doc
433-
async for doc in self.vectorstore.atraversal_search(query, **self.search_kwargs)
443+
async for doc in self.vectorstore.atraversal_search(
444+
query, **self.search_kwargs
445+
)
434446
]
435447
elif self.search_type == "mmr_traversal":
436448
return [
437449
doc
438-
async for doc in self.vectorstore.ammr_traversal_search(query, **self.search_kwargs)
450+
async for doc in self.vectorstore.ammr_traversal_search(
451+
query, **self.search_kwargs
452+
)
439453
]
440454
else:
441-
return await super()._aget_relevant_documents(query, run_manager=run_manager)
455+
return await super()._aget_relevant_documents(
456+
query, run_manager=run_manager
457+
)

libs/knowledge-store/ragstack_knowledge_store/cassandra.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -338,26 +338,35 @@ def add_nodes(
338338
id = metadata[CONTENT_ID]
339339
ids.append(id)
340340

341-
link_to_tags = set() # link to these tags
342-
link_from_tags = set() # link from these tags
341+
link_to_tags = set() # link to these tags
342+
link_from_tags = set() # link from these tags
343343

344344
for tag in get_link_tags(metadata):
345345
tag_str = f"{tag.kind}:{tag.tag}"
346346
if tag.direction == "incoming" or tag.direction == "bidir":
347347
# An incom`ing link should be linked *from* nodes with the given tag.
348348
link_from_tags.add(tag_str)
349-
tag_to_new_targets.setdefault(tag_str, dict())[id] = (tag.kind, text_embedding)
349+
tag_to_new_targets.setdefault(tag_str, dict())[id] = (
350+
tag.kind,
351+
text_embedding,
352+
)
350353
if tag.direction == "outgoing" or tag.direction == "bidir":
351354
link_to_tags.add(tag_str)
352-
tag_to_new_sources.setdefault(tag_str, list()).append((tag.kind, id))
355+
tag_to_new_sources.setdefault(tag_str, list()).append(
356+
(tag.kind, id)
357+
)
353358

354-
cq.execute(self._insert_passage, (id, text, text_embedding, link_to_tags, link_from_tags))
359+
cq.execute(
360+
self._insert_passage,
361+
(id, text, text_embedding, link_to_tags, link_from_tags),
362+
)
355363

356364
# Step 2: Query information about those tags to determine the edges to add.
357365
# Add edges as needed.
358366
id_set = set(ids)
359367
with self._concurrent_queries() as cq:
360368
edges = []
369+
361370
def add_edge(source_id, target_id, kind, target_embedding):
362371
nonlocal added_edges
363372
if source_id == target_id:
@@ -399,27 +408,31 @@ def add_edges_for_targets(
399408
# Don't add here (will be handled later).
400409
continue
401410

402-
for (kind, source_id) in sources:
403-
add_edge(source_id, target.content_id, kind, target.text_embedding)
411+
for kind, source_id in sources:
412+
add_edge(
413+
source_id, target.content_id, kind, target.text_embedding
414+
)
404415

405416
for tag, new_target_embs in tag_to_new_targets.items():
406417
# For each new node with a `link_from_tag`, find the source
407418
# nodes with that `link_to_tag`` and create the edges.
408419
cq.execute(
409420
self._query_ids_by_link_to_tag,
410-
parameters=(tag, ),
421+
parameters=(tag,),
411422
callback=lambda sources, targets=new_target_embs: add_edges_for_sources(
412-
sources, targets)
423+
sources, targets
424+
),
413425
)
414426

415427
for tag, new_sources in tag_to_new_sources.items():
416428
# For each new node with a `link_to_tag`, find the target
417429
# nodes with that `link_from_tag` tag and create the edges.
418430
cq.execute(
419431
self._query_ids_and_embedding_by_link_from_tag,
420-
parameters=(tag, ),
432+
parameters=(tag,),
421433
callback=lambda targets, sources=new_sources: add_edges_for_targets(
422-
sources, targets)
434+
sources, targets
435+
),
423436
)
424437

425438
# Step 3: Add edges.
@@ -429,27 +442,29 @@ def add_edges_for_targets(
429442
# more than |max concurency| edges.
430443
added_edges = 0
431444
with self._concurrent_queries() as cq:
432-
print("Adding edges")
433445
# Add edges from query results (should be one new node and one old node)
434446
for edge in edges:
435447
added_edges += 1
436448
cq.execute(self._insert_edge, edge)
437449

438450
# Add edges for new nodes
439451
for tag, new_sources in tag_to_new_sources.items():
440-
for (kind, source_id) in new_sources:
452+
for kind, source_id in new_sources:
441453
new_targets = tag_to_new_targets.get(tag, None)
442454
if new_targets is None:
443455
continue
444456

445-
for (target_id, (target_kind, target_embedding)) in new_targets.items():
457+
for target_id, (
458+
target_kind,
459+
target_embedding,
460+
) in new_targets.items():
446461
# TODO: Improve the structures so this can be a lookup?
447462
if target_kind == kind and source_id != target_id:
448463
added_edges += 1
449-
cq.execute(self._insert_edge,
450-
(source_id, target_id, kind, target_embedding))
451-
452-
print(f"Added {added_edges} edges")
464+
cq.execute(
465+
self._insert_edge,
466+
(source_id, target_id, kind, target_embedding),
467+
)
453468

454469
return ids
455470

@@ -530,7 +545,7 @@ def mmr_traversal_search(
530545
depth: int = 2,
531546
fetch_k: int = 100,
532547
lambda_mult: float = 0.5,
533-
score_threshold: float = float('-inf'),
548+
score_threshold: float = float("-inf"),
534549
) -> Iterable[Document]:
535550
"""Retrieve documents from this knowledge store using MMR-traversal.
536551
@@ -588,7 +603,7 @@ def mmr_traversal_search(
588603
selected_embedding = next_selected.embedding
589604
selected_embeddings.append(selected_embedding)
590605

591-
best_score = float('-inf')
606+
best_score = float("-inf")
592607
next_id = None
593608

594609
# Update unselected scores.

libs/knowledge-store/ragstack_knowledge_store/edge_extractor.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
from __future__ import annotations
22

33
import abc
4-
import dataclasses
54
from abc import ABC, abstractmethod
6-
from typing import Any, AsyncIterator, Dict, Generic, Iterable, Iterator, Literal, Set, Sequence, TypeVar, Union
7-
8-
import asyncstdlib
9-
from langchain_core.runnables import run_in_executor
10-
from langchain_core.documents import Document, BaseDocumentTransformer
5+
from typing import (
6+
Any,
7+
Dict,
8+
Generic,
9+
Iterable,
10+
Iterator,
11+
Literal,
12+
Set,
13+
TypeVar,
14+
Union,
15+
)
16+
17+
from langchain_core.documents import Document
1118
from pydantic import BaseModel
1219
from ._utils import strict_zip
1320

@@ -20,17 +27,22 @@ class LinkTag(BaseModel, abc.ABC):
2027
def __hash__(self):
2128
return hash((type(self),) + tuple(self.__dict__.values()))
2229

30+
2331
class OutgoingLinkTag(LinkTag):
2432
direction: Literal["outgoing"] = "outgoing"
2533

34+
2635
class IncomingLinkTag(LinkTag):
2736
direction: Literal["incoming"] = "incoming"
2837

38+
2939
class BidirLinkTag(LinkTag):
3040
direction: Literal["bidir"] = "bidir"
3141

42+
3243
LINK_TAGS = "link_tags"
3344

45+
3446
def get_link_tags(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[LinkTag]:
3547
"""Get the link-tag set from a document or metadata.
3648
@@ -49,7 +61,10 @@ def get_link_tags(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[LinkTag]:
4961
doc_or_md[LINK_TAGS] = link_tags
5062
return link_tags
5163

64+
5265
InputT = TypeVar("InputT")
66+
67+
5368
class EdgeExtractor(ABC, Generic[InputT]):
5469
@abstractmethod
5570
def extract_one(self, document: Document, input: InputT):
@@ -60,12 +75,14 @@ def extract_one(self, document: Document, input: InputT):
6075
inputs: The input content to extract edges from.
6176
"""
6277

63-
def extract(self, documents: Iterable[Document], inputs: Iterable[InputT]) -> Iterator[Set[LinkTag]]:
78+
def extract(
79+
self, documents: Iterable[Document], inputs: Iterable[InputT]
80+
) -> Iterator[Set[LinkTag]]:
6481
"""Add edges from each `input` to the corresponding documents.
6582
6683
Args:
6784
documents: The documents to add the link tags to.
6885
inputs: The input content to extract edges from.
6986
"""
70-
for (document, input) in strict_zip(documents, inputs):
71-
self.extract_one(document, input)
87+
for document, input in strict_zip(documents, inputs):
88+
self.extract_one(document, input)

0 commit comments

Comments
 (0)