Skip to content

Commit c180566

Browse files
Jiaqi-Lvpre-commit-ci[bot]shaneahmed
authored
Add mypy type check to annotation/ (#806)
- [x] dsl.py - [x] storage.py - Fix mypy checks under `utils/` - Add new `tests`. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Shan E Ahmed Raza <[email protected]>
1 parent 973378b commit c180566

File tree

13 files changed

+493
-208
lines changed

13 files changed

+493
-208
lines changed

.github/workflows/mypy-type-check.yml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ jobs:
2828
- name: Checkout repository
2929
uses: actions/checkout@v3
3030

31-
- name: Setup mypy
31+
- name: Install dependencies
3232
run: |
33-
pip install mypy
33+
sudo apt update
34+
sudo apt-get install -y libopenslide-dev openslide-tools libopenjp2-7 libopenjp2-tools
35+
python -m pip install --upgrade pip
36+
pip install -r requirements/requirements_dev.txt
3437
3538
- name: Perform type checking
3639
run: |
@@ -39,5 +42,7 @@ jobs:
3942
tiatoolbox/__main__.py \
4043
tiatoolbox/typing.py \
4144
tiatoolbox/tiatoolbox.py \
42-
tiatoolbox/utils/ \
43-
tiatoolbox/tools/
45+
tiatoolbox/utils \
46+
tiatoolbox/tools \
47+
tiatoolbox/data \
48+
tiatoolbox/annotation

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ ENV/
109109

110110
# IDE settings
111111
.vscode/
112+
.idea/
112113

113114
# Mac generated
114115
.DS_Store

tests/test_annotation_stores.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,21 @@ def test_sqlite_store_compile_options_exception(monkeypatch: object) -> None:
504504
SQLiteStore()
505505

506506

507+
def test_sqlite_store_compile_options_exception_json_rtree(monkeypatch: object) -> None:
508+
"""Test SQLiteStore compile options for exceptions."""
509+
monkeypatch.setattr(sqlite3, "sqlite_version_info", (3, 38, 0))
510+
monkeypatch.setattr(
511+
SQLiteStore,
512+
"compile_options",
513+
lambda _x: ["ENABLE_RTREE"],
514+
raising=True,
515+
)
516+
SQLiteStore()
517+
monkeypatch.setattr(SQLiteStore, "compile_options", lambda _x: [], raising=True)
518+
with pytest.raises(EnvironmentError, match="RTREE sqlite3"):
519+
SQLiteStore()
520+
521+
507522
def test_sqlite_store_compile_options_exception_v3_38(monkeypatch: object) -> None:
508523
"""Test SQLiteStore compile options for exceptions."""
509524
monkeypatch.setattr(sqlite3, "sqlite_version_info", (3, 38, 0))
@@ -1326,11 +1341,16 @@ def test_from_geojson_path_transform(
13261341
_, store = fill_store(store_cls, tmp_path / "polygon.db")
13271342
com = annotations_center_of_mass(list(store.values()))
13281343
store.to_geojson(tmp_path / "polygon.json")
1344+
13291345
# load the store translated so that origin is (100,100) and scaled by 2
1346+
def dummy_transform(annotation: Annotation) -> Annotation:
1347+
return annotation
1348+
13301349
store2 = store_cls.from_geojson(
13311350
tmp_path / "polygon.json",
13321351
scale_factor=(2, 2),
13331352
origin=(100, 100),
1353+
transform=dummy_transform,
13341354
)
13351355
assert len(store) == len(store2)
13361356
com2 = annotations_center_of_mass(list(store2.values()))

tests/test_stainnorm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,16 @@ def test_get_normalizer_assertion() -> None:
108108
_ = get_normalizer("ruifrok", stain_matrix)
109109

110110

111+
def test_get_custom_normalizer_assertion() -> None:
112+
"""Test get custom normalizer assertion error."""
113+
stain_matrix = None
114+
with pytest.raises(
115+
ValueError,
116+
match=r"`stain_matrix` is None when using `method_name`=\"custom\".",
117+
):
118+
_ = get_normalizer("custom", stain_matrix)
119+
120+
111121
def test_ruifrok_normalize(source_image: Path, norm_ruifrok: Path) -> None:
112122
"""Test for stain normalization with stain matrix from Ruifrok and Johnston."""
113123
source_img = imread(Path(source_image))

tests/test_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from tests.test_annotation_stores import cell_polygon
2121
from tiatoolbox import utils
22-
from tiatoolbox.annotation.storage import SQLiteStore
22+
from tiatoolbox.annotation.storage import DictionaryStore, SQLiteStore
2323
from tiatoolbox.models.architecture import fetch_pretrained_weights
2424
from tiatoolbox.utils import misc
2525
from tiatoolbox.utils.exceptions import FileNotSupportedError
@@ -1524,7 +1524,15 @@ def test_from_dat(tmp_path: Path) -> None:
15241524
"""Test generating an annotation store from a .dat file."""
15251525
data = make_simple_dat()
15261526
joblib.dump(data, tmp_path / "test.dat")
1527-
store = utils.misc.store_from_dat(tmp_path / "test.dat")
1527+
store = utils.misc.store_from_dat(tmp_path / "test.dat", cls=SQLiteStore)
1528+
assert len(store) == 2
1529+
1530+
1531+
def test_dict_store_from_dat(tmp_path: Path) -> None:
1532+
"""Test generating a DictionaryStore from a .dat file."""
1533+
data = make_simple_dat()
1534+
joblib.dump(data, tmp_path / "test.dat")
1535+
store = utils.misc.store_from_dat(tmp_path / "test.dat", cls=DictionaryStore)
15281536
assert len(store) == 2
15291537

15301538

tiatoolbox/annotation/dsl.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
from numbers import Number
6666
from typing import Callable
6767

68+
from typing_extensions import TypedDict
69+
6870

6971
@dataclass
7072
class SQLNone:
@@ -82,7 +84,9 @@ def __repr__(self: SQLNone) -> str:
8284
class SQLExpression:
8385
"""SQL expression base class."""
8486

85-
__hash__ = None
87+
def __hash__(self: SQLExpression) -> int:
88+
"""Return hash of the object (Not used)."""
89+
return hash(str(self)) # pragma: no cover
8690

8791
def __repr__(self: SQLExpression) -> str:
8892
"""Return a string representation of the object."""
@@ -156,19 +160,23 @@ def __abs__(self: SQLExpression) -> SQLTriplet:
156160
"""Return the absolute value of the object."""
157161
return SQLTriplet(self, operator.abs)
158162

159-
def __eq__(self: SQLExpression, other: SQLExpression) -> SQLTriplet:
163+
def __eq__( # type: ignore[override]
164+
self: SQLExpression, other: object
165+
) -> SQLTriplet:
160166
"""Define how the object is compared for equality."""
161167
return SQLTriplet(self, operator.eq, other)
162168

163-
def __ne__(self: SQLExpression, other: object) -> SQLTriplet:
169+
def __ne__( # type: ignore[override]
170+
self: SQLExpression, other: object
171+
) -> SQLTriplet:
164172
"""Define how the object is compared for equality (not equal to)."""
165173
return SQLTriplet(self, operator.ne, other)
166174

167175
def __neg__(self: SQLExpression) -> SQLTriplet:
168176
"""Define how the object is compared for negation (not equal to)."""
169177
return SQLTriplet(self, operator.neg)
170178

171-
def __contains__(self: SQLExpression, other: object) -> bool:
179+
def __contains__(self: SQLExpression, other: object) -> SQLTriplet:
172180
"""Test whether the object contains the specified object or not."""
173181
return SQLTriplet(self, "contains", other)
174182

@@ -209,9 +217,9 @@ class SQLTriplet(SQLExpression):
209217

210218
def __init__(
211219
self: SQLExpression,
212-
lhs: SQLTriplet | str,
220+
lhs: SQLTriplet | str | SQLExpression | Number | bool | object,
213221
op: Callable | str | None = None,
214-
rhs: SQLTriplet | str | None = None,
222+
rhs: SQLTriplet | str | SQLExpression | Number | SQLNone | object | None = None,
215223
) -> None:
216224
"""Initialize :class:`SQLTriplet`."""
217225
self.lhs = lhs
@@ -244,7 +252,7 @@ def __init__(
244252
"bool": lambda x, _: f"({x} != 0)",
245253
}
246254

247-
def __str__(self: SQLExpression) -> str:
255+
def __str__(self: SQLTriplet) -> str:
248256
"""Return a human-readable, or informal, string representation of an object."""
249257
lhs = self.lhs
250258
rhs = self.rhs
@@ -268,7 +276,7 @@ def __str__(self: SQLJSONDictionary) -> str:
268276
"""Return a human-readable, or informal, string representation of an object."""
269277
return f"json_extract(properties, '$.{self.acc}')"
270278

271-
def __getitem__(self: SQLJSONDictionary, key: str) -> SQLJSONDictionary:
279+
def __getitem__(self: SQLJSONDictionary, key: str | int) -> SQLJSONDictionary:
272280
"""Get an item from the dataset."""
273281
key_str = f"[{key}]" if isinstance(key, (int,)) else f'"{key}"'
274282

@@ -424,11 +432,14 @@ def sql_has_key(dictionary: SQLJSONDictionary, key: str | int) -> SQLTriplet:
424432

425433
# Constants defining the global variables for use in eval() when
426434
# evaluating expressions.
427-
428-
_COMMON_GLOBALS = {
435+
COMMON_GLOBALS_Type = TypedDict(
436+
"COMMON_GLOBALS_Type", {"__builtins__": dict[str, Callable], "re": object}
437+
)
438+
_COMMON_GLOBALS: COMMON_GLOBALS_Type = {
429439
"__builtins__": {"abs": abs},
430440
"re": re.RegexFlag,
431441
}
442+
432443
SQL_GLOBALS = {
433444
"__builtins__": {**_COMMON_GLOBALS["__builtins__"], "sum": sql_list_sum},
434445
"props": SQLJSONDictionary(),

0 commit comments

Comments
 (0)