Skip to content

Commit 817ac3c

Browse files
authored
Merge branch 'develop' into feature-extractor-example
2 parents 5d35f1c + 4a1940d commit 817ac3c

File tree

4 files changed

+54
-8
lines changed

4 files changed

+54
-8
lines changed

tests/test_patch_extraction.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,3 +663,23 @@ def test_mask_based_patch_extractor_ndpi(
663663
len_region1 = len(patches)
664664

665665
assert len_all > len_region2 > len_region1
666+
667+
668+
def test_invalid_points_type() -> None:
669+
"""Test invalid locations_list type for PointsPatchExtractor."""
670+
img = np.zeros((256, 256, 3))
671+
coords = [[10, 10]]
672+
msg = "Please input correct locations_list. "
673+
msg += "Supported types: np.ndarray, DataFrame, str, Path."
674+
with pytest.raises(
675+
TypeError,
676+
match=msg,
677+
):
678+
patchextraction.get_patch_extractor(
679+
"point", input_img=img, locations_list=coords, patch_size=38
680+
)
681+
682+
patches = patchextraction.get_patch_extractor(
683+
"point", input_img=img, locations_list=np.array(coords), patch_size=38
684+
)
685+
assert len(patches) > 0

tiatoolbox/models/architecture/utils.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import sys
6-
from typing import NoReturn
76

87
import numpy as np
98
import torch
@@ -12,13 +11,17 @@
1211
from tiatoolbox import logger
1312

1413

15-
def is_torch_compile_compatible() -> NoReturn:
14+
def is_torch_compile_compatible() -> bool:
1615
"""Check if the current GPU is compatible with torch-compile.
1716
17+
Returns:
18+
True if current GPU is compatible with torch-compile, False otherwise.
19+
1820
Raises:
1921
Warning if GPU is not compatible with `torch.compile`.
2022
2123
"""
24+
gpu_compatibility = True
2225
if torch.cuda.is_available(): # pragma: no cover
2326
device_cap = torch.cuda.get_device_capability()
2427
if device_cap not in ((7, 0), (8, 0), (9, 0)):
@@ -28,13 +31,17 @@ def is_torch_compile_compatible() -> NoReturn:
2831
"Speedup numbers may be lower than expected.",
2932
stacklevel=2,
3033
)
34+
gpu_compatibility = False
3135
else:
3236
logger.warning(
3337
"No GPU detected or cuda not installed, "
3438
"torch.compile is only supported on selected NVIDIA GPUs. "
3539
"Speedup numbers may be lower than expected.",
3640
stacklevel=2,
3741
)
42+
gpu_compatibility = False
43+
44+
return gpu_compatibility
3845

3946

4047
def compile_model(
@@ -68,12 +75,24 @@ def compile_model(
6875
return model
6976

7077
# Check if GPU is compatible with torch.compile
71-
is_torch_compile_compatible()
78+
gpu_compatibility = is_torch_compile_compatible()
79+
80+
if not gpu_compatibility:
81+
return model
82+
83+
if sys.platform == "win32": # pragma: no cover
84+
msg = (
85+
"`torch.compile` is not supported on Windows. Please see "
86+
"https://github.com/pytorch/pytorch/issues/122094."
87+
)
88+
logger.warning(msg=msg)
89+
return model
7290

7391
# This check will be removed when torch.compile is supported in Python 3.12+
7492
if sys.version_info > (3, 12): # pragma: no cover
93+
msg = "torch-compile is currently not supported in Python 3.12+."
7594
logger.warning(
76-
("torch-compile is currently not supported in Python 3.12+. ",),
95+
msg=msg,
7796
)
7897
return model
7998

tiatoolbox/tools/patchextraction.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from tiatoolbox import logger
1212
from tiatoolbox.utils import misc
13-
from tiatoolbox.utils.exceptions import MethodNotSupportedError
13+
from tiatoolbox.utils.exceptions import FileNotSupportedError, MethodNotSupportedError
1414
from tiatoolbox.utils.visualization import AnnotationRenderer
1515
from tiatoolbox.wsicore import wsireader
1616

@@ -772,7 +772,12 @@ def __init__(
772772
pad_constant_values=pad_constant_values,
773773
within_bound=within_bound,
774774
)
775-
self.locations_df = misc.read_locations(input_table=locations_list)
775+
try:
776+
self.locations_df = misc.read_locations(input_table=locations_list)
777+
except (TypeError, FileNotSupportedError) as exc:
778+
msg = "Please input correct locations_list. "
779+
msg += "Supported types: np.ndarray, DataFrame, str, Path."
780+
raise TypeError(msg) from exc
776781
self.locations_df["x"] = self.locations_df["x"] - int(
777782
(self.patch_size[1] - 1) / 2,
778783
)

tiatoolbox/utils/misc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def read_locations(
531531
out_table = pd.read_json(input_table)
532532
return __assign_unknown_class(out_table)
533533

534-
msg = "File type not supported."
534+
msg = "File type not supported. Supported types: .npy, .csv, .json"
535535
raise FileNotSupportedError(msg)
536536

537537
if isinstance(input_table, np.ndarray):
@@ -540,7 +540,9 @@ def read_locations(
540540
if isinstance(input_table, pd.DataFrame):
541541
return __assign_unknown_class(input_table)
542542

543-
msg = "Please input correct image path or an ndarray image."
543+
msg = "File type not supported. "
544+
msg += "Supported types: str, Path, PathLike, np.ndarray, pd.DataFrame"
545+
544546
raise TypeError(msg)
545547

546548

0 commit comments

Comments
 (0)