
Commit 9efa3a3

cbornet and bjchambers authored

Add type checking of test-utils and knowledge-store (#577)

* Fix some mypy issues in test-utils and knowledge-store
* Changes following review
* Run mypy in CI
* suggestion
* Fix lint
* Fix issue with lambda callback typing
* Check type of test-utils in CI

Co-authored-by: Ben Chambers <[email protected]>

1 parent 22f9aeb · commit 9efa3a3

File tree

18 files changed: +211 additions, -110 deletions


.github/changes-filter.yaml
Lines changed: 1 addition & 0 deletions

```diff
@@ -17,6 +17,7 @@ notebooks:
   - "./.github/workflows/_run_e2e_tests.yml"
 integration_tests:
   - "libs/colbert/**"
+  - "libs/knowledge-store/**"
   - "libs/llamaindex/**"
   - "libs/langchain/**"
   - "./.github/actions/**"
```
.github/workflows/ci-unit-tests.yml
Lines changed: 5 additions & 0 deletions

```diff
@@ -89,6 +89,11 @@ jobs:
       - name: "Type check (knowledge-graph)"
         run: tox -e type -c libs/knowledge-graph && rm -rf libs/knowledge-graph/.tox
 
+      - name: "Type check (knowledge-store)"
+        run: tox -e type -c libs/knowledge-store && rm -rf libs/knowledge-store/.tox
+
+      - name: "Type check (test-utils)"
+        run: tox -e type -c libs/tests-utils && rm -rf libs/tests-utils/.tox
 
   unit-tests:
     name: Unit Tests (Python ${{ matrix.python-version }})
```
libs/knowledge-store/pyproject.toml
Lines changed: 14 additions & 4 deletions

```diff
@@ -38,11 +38,21 @@ requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.mypy]
-strict = true
-warn_unreachable = true
-pretty = true
-show_column_numbers = true
+disallow_any_generics = true
+disallow_incomplete_defs = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+follow_imports = "normal"
+ignore_missing_imports = true
+no_implicit_reexport = true
+show_error_codes = true
 show_error_context = true
+strict_equality = true
+strict_optional = true
+warn_redundant_casts = true
+warn_return_any = true
+warn_unused_ignores = true
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
```

libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py
Lines changed: 22 additions & 19 deletions

```diff
@@ -2,15 +2,16 @@
 from typing import Dict, Iterable, List, Optional
 
 import numpy as np
+from numpy.typing import NDArray
 
 from ragstack_knowledge_store.math import cosine_similarity
 
 
-def _emb_to_ndarray(embedding: List[float]) -> np.ndarray:
-    embedding = np.array(embedding, dtype=np.float32)
-    if embedding.ndim == 1:
-        embedding = np.expand_dims(embedding, axis=0)
-    return embedding
+def _emb_to_ndarray(embedding: List[float]) -> NDArray[np.float32]:
+    emb_array = np.array(embedding, dtype=np.float32)
+    if emb_array.ndim == 1:
+        emb_array = np.expand_dims(emb_array, axis=0)
+    return emb_array
 
 
 NEG_INF = float("-inf")
@@ -23,10 +24,10 @@ class _Candidate:
     weighted_redundancy: float
     score: float = dataclasses.field(init=False)
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         self.score = self.weighted_similarity - self.weighted_redundancy
 
-    def update_redundancy(self, new_weighted_redundancy: float):
+    def update_redundancy(self, new_weighted_redundancy: float) -> None:
         if new_weighted_redundancy > self.weighted_redundancy:
             self.weighted_redundancy = new_weighted_redundancy
             self.score = self.weighted_similarity - self.weighted_redundancy
@@ -47,7 +48,7 @@ class MmrHelper:
     dimensions: int
     """Dimensions of the embedding."""
 
-    query_embedding: np.ndarray
+    query_embedding: NDArray[np.float32]
     """Embedding of the query as a (1,dim) ndarray."""
 
     lambda_mult: float
@@ -64,7 +65,7 @@ class MmrHelper:
 
     selected_ids: List[str]
     """List of selected IDs (in selection order)."""
-    selected_embeddings: np.ndarray
+    selected_embeddings: NDArray[np.float32]
     """(N, dim) ndarray with a row for each selected node."""
 
     candidate_id_to_index: Dict[str, int]
@@ -74,7 +75,7 @@ class MmrHelper:
 
     Same order as rows in `candidate_embeddings`.
     """
-    candidate_embeddings: np.ndarray
+    candidate_embeddings: NDArray[np.float32]
     """(N, dim) ndarray with a row for each candidate."""
 
     best_score: float
@@ -113,12 +114,12 @@ def candidate_ids(self) -> Iterable[str]:
         """Return the IDs of the candidates."""
         return self.candidate_id_to_index.keys()
 
-    def _already_selected_embeddings(self) -> np.ndarray:
+    def _already_selected_embeddings(self) -> NDArray[np.float32]:
        """Return the selected embeddings sliced to the already assigned values."""
         selected = len(self.selected_ids)
         return np.vsplit(self.selected_embeddings, [selected])[0]
 
-    def _pop_candidate(self, candidate_id: str) -> np.ndarray:
+    def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
         """Pop the candidate with the given ID.
 
         Returns:
@@ -127,7 +128,7 @@ def _pop_candidate(self, candidate_id: str) -> np.ndarray:
         # Get the embedding for the id.
         index = self.candidate_id_to_index.pop(candidate_id)
         assert self.candidates[index].id == candidate_id
-        embedding = self.candidate_embeddings[index].copy()
+        embedding: NDArray[np.float32] = self.candidate_embeddings[index].copy()
 
         # Swap that index with the last index in the candidates and
         # candidate_embeddings.
@@ -186,19 +187,21 @@ def pop_best(self) -> Optional[str]:
 
         return selected_id
 
-    def add_candidates(self, candidates: Dict[str, List[float]]):
+    def add_candidates(self, candidates: Dict[str, List[float]]) -> None:
         """Add candidates to the consideration set."""
         # Determine the keys to actually include.
         # These are the candidates that aren't already selected
        # or under consideration.
-        include_ids = set(candidates.keys())
-        include_ids.difference_update(self.selected_ids)
-        include_ids.difference_update(self.candidate_id_to_index.keys())
-        include_ids = list(include_ids)
+        include_ids_set = set(candidates.keys())
+        include_ids_set.difference_update(self.selected_ids)
+        include_ids_set.difference_update(self.candidate_id_to_index.keys())
+        include_ids = list(include_ids_set)
 
         # Now, build up a matrix of the remaining candidate embeddings.
         # And add them to the
-        new_embeddings = np.ndarray((len(include_ids), self.dimensions))
+        new_embeddings: NDArray[np.float32] = np.ndarray(
+            (len(include_ids), self.dimensions)
+        )
         offset = self.candidate_embeddings.shape[0]
         for index, candidate_id in enumerate(include_ids):
             if candidate_id in include_ids:
```
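The typing changes above swap bare `np.ndarray` annotations for `numpy.typing.NDArray[np.float32]` and stop re-binding the `List[float]` parameter to an array, which mypy flags as an incompatible assignment. A standalone sketch of the same pattern, using an illustrative function name rather than the module's own:

```python
from typing import List

import numpy as np
from numpy.typing import NDArray


def to_row_matrix(embedding: List[float]) -> NDArray[np.float32]:
    """Return the embedding as a (1, dim) float32 array."""
    # Bind a new name instead of re-assigning the List[float] parameter.
    arr = np.array(embedding, dtype=np.float32)
    if arr.ndim == 1:
        arr = np.expand_dims(arr, axis=0)
    return arr


print(to_row_matrix([0.1, 0.2, 0.3]).shape)  # (1, 3)
```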

libs/knowledge-store/ragstack_knowledge_store/_utils.py
Lines changed: 7 additions & 5 deletions

```diff
@@ -1,18 +1,20 @@
+from __future__ import annotations
+
 import sys
 
 try:
     # Try importing the function from itertools (Python 3.12+)
-    from itertools import batched
+    from itertools import batched  # type: ignore[attr-defined]
 except ImportError:
     from itertools import islice
-    from typing import Iterable, Iterator, TypeVar
+    from typing import Any, Iterable, Iterator, TypeVar
 
     # Fallback implementation for older Python versions
 
     T = TypeVar("T")
 
     # This is equivalent to `itertools.batched`, but that is only available in 3.12
-    def batched(iterable: Iterable[T], n: int) -> Iterator[Iterator[T]]:
+    def batched(iterable: Iterable[T], n: int) -> Iterator[tuple[T, ...]]:
         if n < 1:
             raise ValueError("n must be at least one")
         it = iter(iterable)
@@ -24,12 +26,12 @@ def batched(iterable: Iterable[T], n: int) -> Iterator[Iterator[T]]:
 
 if sys.version_info >= (3, 10):
 
-    def strict_zip(*iterables):
+    def strict_zip(*iterables: Iterable[Any]) -> zip[tuple[Any, ...]]:
         return zip(*iterables, strict=True)
 
 else:
 
-    def strict_zip(*iterables):
+    def strict_zip(*iterables: Iterable[T]) -> zip[tuple[T]]:
         # Custom implementation for Python versions older than 3.10
         if not iterables:
             return
```
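The fallback `batched` now declares that it yields tuples, matching the real `itertools.batched` added in Python 3.12, and the new `from __future__ import annotations` import lets annotations such as `tuple[T, ...]` and `zip[...]` be written on older interpreters. A self-contained sketch of that fallback shape, with an illustrative name and assuming Python 3.8+:

```python
from __future__ import annotations

from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def batched_fallback(iterable: Iterable[T], n: int) -> Iterator[tuple[T, ...]]:
    """Yield successive n-sized tuples, mirroring itertools.batched from Python 3.12."""
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    # Keep slicing n items off the iterator until it is exhausted.
    while batch := tuple(islice(it, n)):
        yield batch


print(list(batched_fallback("ABCDEFG", 3)))  # [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
```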

libs/knowledge-store/ragstack_knowledge_store/concurrency.py
Lines changed: 23 additions & 11 deletions

```diff
@@ -2,31 +2,43 @@
 import logging
 import threading
 from types import TracebackType
-from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Type
+from typing import (
+    Any,
+    Callable,
+    Literal,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    Type,
+)
 
 from cassandra.cluster import ResponseFuture, Session
 from cassandra.query import PreparedStatement
 
 logger = logging.getLogger(__name__)
 
 
-class ConcurrentQueries(contextlib.AbstractContextManager):
+class _Callback(Protocol):
+    def __call__(self, rows: Sequence[Any], /) -> None: ...
+
+
+class ConcurrentQueries(contextlib.AbstractContextManager["ConcurrentQueries"]):
     """Context manager for concurrent queries."""
 
     def __init__(self, session: Session) -> None:
         self._session = session
         self._completion = threading.Condition()
-
         self._pending = 0
-
-        self._error = None
+        self._error: Optional[BaseException] = None
 
     def _handle_result(
         self,
         result: Sequence[NamedTuple],
         future: ResponseFuture,
         callback: Optional[Callable[[Sequence[NamedTuple]], Any]],
-    ):
+    ) -> None:
         if callback is not None:
             callback(result)
 
@@ -38,7 +50,7 @@ def _handle_result(
         if self._pending == 0:
             self._completion.notify()
 
-    def _handle_error(self, error, future: ResponseFuture):
+    def _handle_error(self, error: BaseException, future: ResponseFuture) -> None:
         logger.error(
             "Error executing query: %s",
             future.query,
@@ -51,9 +63,9 @@ def _handle_error(self, error, future: ResponseFuture):
     def execute(
         self,
         query: PreparedStatement,
-        parameters: Optional[Tuple] = None,
-        callback: Optional[Callable[[Sequence[NamedTuple]], Any]] = None,
-    ):
+        parameters: Optional[Tuple[Any, ...]] = None,
+        callback: Optional[_Callback] = None,
+    ) -> None:
         """Execute a query concurrently.
 
         Because this is done concurrently, it expects a callback if you need
@@ -93,7 +105,7 @@ def __exit__(
         _exc_type: Optional[Type[BaseException]],
         _exc_inst: Optional[BaseException],
         _exc_traceback: Optional[TracebackType],
-    ) -> bool:
+    ) -> Literal[False]:
         with self._completion:
             while self._error is None and self._pending > 0:
                 self._completion.wait()
```