Skip to content

Commit 5c8d1ba

Browse files
committed
feat: typing cleanup
1 parent 2e8b5b0 commit 5c8d1ba

File tree

2 files changed

+37
-39
lines changed

2 files changed

+37
-39
lines changed

src/brainweb_dl/_brainweb.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,10 @@
1818

1919
from __future__ import annotations
2020

21-
from typing import TYPE_CHECKING, Any
22-
23-
if TYPE_CHECKING:
24-
from _typeshed import AnyPath
25-
2621
import csv
27-
import re
2822
import logging
2923
import os
24+
import re
3025
import sys
3126

3227
if sys.version_info > (3, 9):
@@ -48,6 +43,8 @@
4843

4944
logger = logging.getLogger("brainweb_dl")
5045

46+
GenericPath = os.PathLike[str] | str
47+
5148

5249
class ContainsEnumMeta(EnumMeta):
5350
"""Metaclass for case insensitive Enum."""
@@ -70,7 +67,7 @@ class MyEnum(str, Enum, metaclass=ContainsEnumMeta):
7067
"""Enum with case insensitive comparison."""
7168

7269
@classmethod
73-
def _missing_(cls, value: Any) -> None | MyEnum:
70+
def _missing_(cls, value) -> None | MyEnum: # noqa
7471
if isinstance(value, str):
7572
value = value.upper()
7673
value = value.replace("*", "s")
@@ -179,12 +176,12 @@ def get_brainweb_dir(brainweb_dir: BrainWebDirType = None) -> Path:
179176
180177
Parameters
181178
----------
182-
brainweb_dir : AnyPath
179+
brainweb_dir : GenericPath
183180
brainweb_directory to download the data.
184181
185182
Returns
186183
-------
187-
AnyPath
184+
GenericPath
188185
Path to brainweb_dir
189186
190187
Notes
@@ -210,7 +207,7 @@ def get_brainweb20_multiple(
210207
brainweb_dir: BrainWebDirType = None,
211208
force: bool = False,
212209
segmentation: Segmentation = Segmentation.CRISP,
213-
) -> AnyPath: ...
210+
) -> GenericPath: ...
214211

215212

216213
@overload
@@ -219,30 +216,30 @@ def get_brainweb20_multiple(
219216
brainweb_dir: BrainWebDirType = None,
220217
force: bool = False,
221218
segmentation: Segmentation = Segmentation.CRISP,
222-
) -> list[AnyPath]: ...
219+
) -> list[GenericPath]: ...
223220

224221

225222
def get_brainweb20_multiple(
226223
subject: int | str | Literal["all"] | list[int | str],
227224
brainweb_dir: BrainWebDirType = None,
228225
force: bool = False,
229226
segmentation: Segmentation = Segmentation.CRISP,
230-
) -> list[AnyPath] | os.PathLike:
227+
) -> list[GenericPath] | GenericPath:
231228
"""Download sample or all brainweb subjects.
232229
233230
Parameters
234231
----------
235232
subject : int | list | Literal["all"]
236233
subject id or list of subject id to download.
237234
If "all", download all subjects.
238-
brainweb_dir : AnyPath
235+
brainweb_dir : GenericPath
239236
brainweb_directory to download the data.
240237
force : bool
241238
force download even if the file already exists.
242239
243240
Returns
244241
-------
245-
list[AnyPath]
242+
list[GenericPath]
246243
list of downloaded files.
247244
"""
248245
if subject == "all":
@@ -256,7 +253,7 @@ def get_brainweb20_multiple(
256253
"subject must be int, a str, a list of int or string or 'all'"
257254
)
258255
if len(_subject) > 1:
259-
f: list[AnyPath] = []
256+
f: list[GenericPath] = []
260257
pbar = tqdm(total=len(_subject), desc="Downloading Brainweb phantoms")
261258
for s in _subject:
262259
f.append(get_brainweb20(_sub_id(s), brainweb_dir, force, segmentation))
@@ -272,14 +269,14 @@ def get_brainweb20(
272269
force: bool = False,
273270
segmentation: Segmentation = Segmentation.CRISP,
274271
extension: Literal["nii.gz", "nii"] = "nii.gz",
275-
) -> AnyPath:
272+
) -> GenericPath:
276273
"""Download one subject of brainweb dataset.
277274
278275
Parameters
279276
----------
280277
s : int
281278
subject id.
282-
brainweb_dir : AnyPath
279+
brainweb_dir : GenericPath
283280
brainweb_directory to download the data.
284281
force : bool
285282
force download even if the file already exists.
@@ -290,7 +287,7 @@ def get_brainweb20(
290287
291288
Returns
292289
-------
293-
AnyPath
290+
GenericPath
294291
Path to downloaded file.
295292
"""
296293
s = _sub_id(s)
@@ -346,14 +343,14 @@ def get_brainweb20_T1(
346343
brainweb_dir: BrainWebDirType = None,
347344
force: bool = False,
348345
extension: Literal["nii.gz", "nii"] = "nii.gz",
349-
) -> AnyPath:
346+
) -> GenericPath:
350347
"""Download the Brainweb20 T1 Phantom.
351348
352349
Parameters
353350
----------
354351
s : int
355352
subject id.
356-
brainweb_dir : AnyPath
353+
brainweb_dir : GenericPath
357354
brainweb_directory to download the data.
358355
force : bool
359356
force download even if the file already exists.
@@ -362,7 +359,7 @@ def get_brainweb20_T1(
362359
363360
Returns
364361
-------
365-
AnyPath
362+
GenericPath
366363
Path to downloaded file.
367364
368365
Notes
@@ -397,7 +394,7 @@ def get_brainweb1(
397394
brainweb_dir: BrainWebDirType = None,
398395
force: bool = False,
399396
extension: Literal["nii.gz", "nii"] = "nii.gz",
400-
) -> AnyPath:
397+
) -> GenericPath:
401398
"""Download the Brainweb1 phantom as a nifti file.
402399
403400
Parameters
@@ -411,7 +408,7 @@ def get_brainweb1(
411408
{0, 1, 3, 5, 7, 9}
412409
field_value : int
413410
RF field value in the phantom. Must be in {0, 20, 40}
414-
brainweb_dir : AnyPath
411+
brainweb_dir : GenericPath
415412
brainweb_directory to download the data.
416413
force : bool
417414
force download even if the file already exists.
@@ -420,7 +417,7 @@ def get_brainweb1(
420417
421418
Returns
422419
-------
423-
AnyPath
420+
GenericPath
424421
Path to downloaded file.
425422
426423
Notes
@@ -462,7 +459,7 @@ def get_brainweb1_seg(
462459
extension: Literal["nii.gz", "nii"] = "nii.gz",
463460
brainweb_dir: BrainWebDirType = None,
464461
force: bool = False,
465-
) -> AnyPath:
462+
) -> GenericPath:
466463
"""Download the Brainweb1 phantom segmentation as a nifti file."""
467464
brainweb_dir = get_brainweb_dir(brainweb_dir)
468465
try:
@@ -562,7 +559,7 @@ def _request_get_brainweb_affine(download_cmd: str) -> NDArray:
562559

563560
def _request_get_brainweb(
564561
download_command: str,
565-
path: AnyPath,
562+
path: GenericPath,
566563
force: bool = False,
567564
dtype: DTypeLike = np.float32,
568565
shape: tuple = STD_RES_SHAPE,
@@ -573,7 +570,7 @@ def _request_get_brainweb(
573570
----------
574571
do_download_alias : str
575572
Formatted request code to download a volume from brainweb.
576-
path : AnyPath
573+
path : GenericPath
577574
Path to save the downloaded file.
578575
force : bool
579576
Force download even if the file already exists.
@@ -586,7 +583,7 @@ def _request_get_brainweb(
586583
587584
Returns
588585
-------
589-
AnyPath
586+
GenericPath
590587
Path to downloaded file.
591588
592589
Raises
@@ -653,12 +650,12 @@ def _request_get_brainweb(
653650
return (data, affine)
654651

655652

656-
def _load_tissue_map(tissue_map: AnyPath) -> list[dict]:
653+
def _load_tissue_map(tissue_map: GenericPath) -> list[dict]:
657654
with open(tissue_map) as csvfile:
658655
return list(csv.DictReader(csvfile))
659656

660657

661-
def save_array(data: NDArray, affine: NDArray | None, path: AnyPath) -> os.PathLike:
658+
def save_array(data: NDArray, affine: NDArray | None, path: GenericPath) -> GenericPath:
662659
path_ = Path(path)
663660
if path_.suffix == ".npy":
664661
np.save(path_, data)
@@ -669,7 +666,7 @@ def save_array(data: NDArray, affine: NDArray | None, path: AnyPath) -> os.PathL
669666
return path
670667

671668

672-
def load_array(path: AnyPath, dtype: DTypeLike = None) -> tuple[NDArray, NDArray]:
669+
def load_array(path: GenericPath, dtype: DTypeLike = None) -> tuple[NDArray, NDArray]:
673670
path_ = Path(path)
674671
if path_.suffix == ".npy":
675672
data = np.load(path_)

src/brainweb_dl/mri.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626

2727
logger = logging.getLogger("brainweb_dl")
28+
GenericPath = os.PathLike[str] | str
2829

2930

3031
def _get_mri_sub0(
@@ -34,7 +35,7 @@ def _get_mri_sub0(
3435
noise: int = 0,
3536
field_value: int = 0,
3637
force: bool = False,
37-
tissue_map: os.PathLike = BrainWebTissueMap.v1,
38+
tissue_map: GenericPath = BrainWebTissueMap.v1,
3839
rng: int | np.random.Generator | None = None,
3940
) -> np.ndarray:
4041
if contrast in [Contrast.T1, Contrast.T2, Contrast.PD]:
@@ -76,27 +77,27 @@ def _get_mri_sub20(
7677
sub_id: int | str,
7778
brainweb_dir: BrainWebDirType = None,
7879
force: bool = False,
79-
tissue_map: os.PathLike = BrainWebTissueMap.v2,
80+
tissue_map: GenericPath = BrainWebTissueMap.v2,
8081
rng: int | np.random.Generator | None = None,
8182
) -> tuple[NDArray, NDArray]:
8283
if contrast is Contrast.T1:
8384
filename = get_brainweb20_T1(sub_id, brainweb_dir=brainweb_dir, force=force)
8485
nft = nifti.Nifti1Image.from_filename(filename)
85-
data, affine = nft.get_fdata(), nft.affine
86+
data, affine = np.asarray(nft.get_fdata()), np.asarray(nft.affine)
8687
elif contrast in Segmentation:
8788
filename = get_brainweb20(
8889
sub_id, segmentation=Segmentation(contrast), force=force
8990
)
9091
nft = nifti.Nifti1Image.from_filename(filename)
9192
data = np.asanyarray(nft.dataobj, dtype=np.uint16)
92-
affine = nft.affine
93+
affine = np.asarray(nft.affine)
9394
if contrast is Segmentation.FUZZY:
9495
data = data.astype(np.float32) / 4095
9596
else:
9697
filename = get_brainweb20(sub_id, segmentation=Segmentation.FUZZY, force=force)
9798
tissue_map = tissue_map or BrainWebTissueMap.v2
9899
data = _apply_contrast(filename, tissue_map, Contrast(contrast), rng)
99-
affine = nifti.Nifti1Image.from_filename(filename).affine
100+
affine = np.asarray(nifti.Nifti1Image.from_filename(filename).affine)
100101
return (data, affine)
101102

102103

@@ -112,7 +113,7 @@ def get_mri(
112113
field_value: int = 0,
113114
force: bool = False,
114115
with_affine: bool = False,
115-
tissue_map: os.PathLike | None = None,
116+
tissue_map: GenericPath | None = None,
116117
rng: int | np.random.Generator | None = None,
117118
) -> tuple[NDArray, NDArray] | NDArray:
118119
"""Get MRI data from a brainweb fuzzy segmentation.
@@ -246,8 +247,8 @@ def _crop_data(data: np.ndarray, bbox: tuple[float | None, ...]) -> np.ndarray:
246247

247248

248249
def _apply_contrast(
249-
file_fuzzy: os.PathLike,
250-
tissue_map: os.PathLike,
250+
file_fuzzy: GenericPath,
251+
tissue_map: GenericPath,
251252
contrast: Contrast,
252253
rng: int | np.random.Generator | None,
253254
) -> np.ndarray:

0 commit comments

Comments
 (0)