Skip to content

Commit 18eb3b4

Browse files
authored
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-instance-segmentor
2 parents a911d3c + 80e7af5 commit 18eb3b4

19 files changed

+2352
-228
lines changed

tests/engines/test_engine_abc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def test_engine_run() -> NoReturn:
345345
on_gpu=False,
346346
)
347347

348+
eng = TestEngineABC(model="alexnet-kather100k")
348349
with pytest.raises(
349350
ValueError,
350351
match=r".*The shape of the numpy array should be NHWC*",
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
"""Test for feature extractor."""
2+
3+
import shutil
4+
from collections.abc import Callable
5+
from pathlib import Path
6+
7+
import numpy as np
8+
import pytest
9+
import torch
10+
import zarr
11+
from click.testing import CliRunner
12+
13+
from tiatoolbox import cli
14+
from tiatoolbox.models import IOPatchPredictorConfig
15+
from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone
16+
from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor
17+
from tiatoolbox.utils import env_detection as toolbox_env
18+
from tiatoolbox.wsicore.wsireader import WSIReader
19+
20+
ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()
21+
22+
# -------------------------------------------------------------------------------------
23+
# Engine
24+
# -------------------------------------------------------------------------------------
25+
26+
device = "cuda" if toolbox_env.has_gpu() else "cpu"
27+
28+
29+
def test_feature_extractor_patches(
30+
remote_sample: Callable,
31+
) -> None:
32+
"""Tests DeepFeatureExtractor on image patches."""
33+
extractor = DeepFeatureExtractor(
34+
model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
35+
)
36+
37+
sample_image = remote_sample("thumbnail-1k-1k")
38+
39+
inputs = [sample_image, sample_image]
40+
41+
assert not extractor.patch_mode
42+
output = extractor.run(
43+
images=inputs,
44+
return_probabilities=True,
45+
return_labels=False,
46+
device=device,
47+
patch_mode=True,
48+
)
49+
50+
assert 0.48 < np.mean(output["features"][:]) < 0.52
51+
52+
with pytest.raises(
53+
ValueError,
54+
match=r".*output_type: `annotationstore` is not supported "
55+
r"for `DeepFeatureExtractor` engine",
56+
):
57+
_ = extractor.run(
58+
images=inputs,
59+
return_probabilities=True,
60+
return_labels=False,
61+
device=device,
62+
patch_mode=True,
63+
output_type="annotationstore",
64+
)
65+
66+
67+
def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> None:
68+
"""Test feature extraction with DeepFeatureExtractor engine."""
69+
save_dir = track_tmp_path / "output"
70+
# # convert to pathlib Path to prevent wsireader complaint
71+
mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))
72+
73+
# * test providing pretrained from torch vs pretrained_model.yaml
74+
shutil.rmtree(save_dir, ignore_errors=True) # default output dir test
75+
76+
extractor = DeepFeatureExtractor(batch_size=1, model="fcn-tissue_mask")
77+
output = extractor.run(
78+
images=[mini_wsi_svs],
79+
return_probabilities=False,
80+
return_labels=False,
81+
device=device,
82+
patch_mode=False,
83+
save_dir=track_tmp_path / "wsi_out_check",
84+
batch_size=1,
85+
output_type="zarr",
86+
memory_threshold=1,
87+
)
88+
89+
output_ = zarr.open(output[mini_wsi_svs], mode="r")
90+
assert len(output_["coordinates"].shape) == 2
91+
assert len(output_["features"].shape) == 3
92+
93+
94+
@pytest.mark.parametrize(
95+
"model",
96+
[
97+
CNNBackbone("resnet18"),
98+
TimmBackbone("efficientnet_b0", pretrained=True),
99+
"resnet18",
100+
"efficientnet_b0",
101+
],
102+
)
103+
def test_full_inference(
104+
remote_sample: Callable, track_tmp_path: Path, model: Callable
105+
) -> None:
106+
"""Test full inference with CNNBackbone and TimmBackbone models."""
107+
save_dir = track_tmp_path / "output"
108+
# pre-emptive clean up
109+
shutil.rmtree(save_dir, ignore_errors=True) # default output dir test
110+
111+
mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
112+
113+
ioconfig = IOPatchPredictorConfig(
114+
input_resolutions=[
115+
{"units": "mpp", "resolution": 0.25},
116+
],
117+
patch_input_shape=[512, 512],
118+
stride_shape=[256, 256],
119+
)
120+
121+
extractor = DeepFeatureExtractor(batch_size=4, model=model)
122+
output = extractor.run(
123+
images=[mini_wsi_svs],
124+
device=device,
125+
save_dir=track_tmp_path / "wsi_out_check",
126+
batch_size=4,
127+
output_type="zarr",
128+
ioconfig=ioconfig,
129+
patch_mode=False,
130+
)
131+
132+
output_ = zarr.open(output[mini_wsi_svs], mode="r")
133+
134+
positions = output_["coordinates"]
135+
features = output_["features"]
136+
137+
reader = WSIReader.open(mini_wsi_svs)
138+
patches = [
139+
reader.read_bounds(
140+
positions[patch_idx],
141+
resolution=0.25,
142+
units="mpp",
143+
pad_constant_values=255,
144+
coord_space="resolution",
145+
)
146+
for patch_idx in range(4)
147+
]
148+
patches = np.array(patches)
149+
patches = torch.from_numpy(patches) # NHWC
150+
patches = patches.permute(0, 3, 1, 2).contiguous() # NCHW
151+
patches = patches.to(device).type(torch.float32)
152+
model = extractor.model
153+
# Inference mode
154+
model.eval()
155+
with torch.inference_mode():
156+
_features = model(patches).cpu().numpy()
157+
# ! must maintain same batch size and likely same ordering
158+
# ! else the output values will not exactly be the same (still < 1.0e-4
159+
# ! of epsilon though)
160+
assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1
161+
162+
163+
@pytest.mark.skipif(
164+
toolbox_env.running_on_ci() or not ON_GPU,
165+
reason="Local test on machine with GPU.",
166+
)
167+
def test_multi_gpu_feature_extraction(
168+
remote_sample: Callable, track_tmp_path: Path
169+
) -> None:
170+
"""Local functionality test for feature extraction using multiple GPUs."""
171+
save_dir = track_tmp_path / "output"
172+
mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
173+
shutil.rmtree(save_dir, ignore_errors=True)
174+
175+
wsi_ioconfig = IOPatchPredictorConfig(
176+
input_resolutions=[{"units": "mpp", "resolution": 0.5}],
177+
patch_input_shape=[224, 224],
178+
stride_shape=[224, 224],
179+
)
180+
181+
extractor = DeepFeatureExtractor(
182+
model="UNI",
183+
batch_size=32,
184+
num_workers=4,
185+
)
186+
187+
output = extractor.run(
188+
[mini_wsi_svs],
189+
patch_mode=False,
190+
device=device,
191+
ioconfig=wsi_ioconfig,
192+
save_dir=save_dir,
193+
auto_get_mask=True,
194+
output_type="zarr",
195+
)
196+
output_ = zarr.open(output[mini_wsi_svs], mode="r")
197+
198+
positions = output_["coordinates"]
199+
features = output_["features"]
200+
assert len(positions.shape) == 2
201+
assert len(features.shape) == 2
202+
203+
204+
# -------------------------------------------------------------------------------------
205+
# Command Line Interface
206+
# -------------------------------------------------------------------------------------
207+
208+
209+
def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None:
210+
"""Test for feature extractor CLI single file."""
211+
runner = CliRunner()
212+
213+
models_wsi_result = runner.invoke(
214+
cli.main,
215+
[
216+
"deep-feature-extractor",
217+
"--img-input",
218+
str(sample_svs),
219+
"--model",
220+
"resnet18",
221+
"--patch-mode",
222+
"False",
223+
"--output-path",
224+
str(track_tmp_path / "output"),
225+
"--patch-input-shape",
226+
"224",
227+
"224",
228+
"--input-resolutions",
229+
'[{"units": "mpp", "resolution": 0.25}]',
230+
],
231+
)
232+
233+
assert models_wsi_result.exit_code == 0
234+
assert (track_tmp_path / "output" / (sample_svs.stem + ".zarr")).exists()
235+
236+
output = zarr.open(
237+
track_tmp_path / "output" / (sample_svs.stem + ".zarr"), mode="r"
238+
)
239+
240+
# Output shape should be # of patches x feature size
241+
assert output["features"].shape == (255, 512)

tests/engines/test_patch_predictor.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -268,17 +268,14 @@ def test_wsi_predictor_api(
268268
kwargs = {
269269
"patch_input_shape": patch_size,
270270
"stride_shape": patch_size,
271-
"input_resolutions": [{"units": "baseline", "resolution": 1.0}],
272271
"save_dir": save_dir,
273272
}
274-
# ! add this test back once the read at `baseline` is fixed
275-
# sanity check, both output should be the same with same resolution read args
276-
# remove previously generated data
277273

278274
_kwargs = copy.deepcopy(kwargs)
279275
# test reading of multiple whole-slide images
280276
output = predictor.run(
281277
images=[mini_wsi_svs, str(mini_wsi_jpg)],
278+
input_resolutions=[{"units": "baseline", "resolution": 1.0}],
282279
masks=[mini_wsi_msk, mini_wsi_msk],
283280
patch_mode=False,
284281
return_probabilities=True,
@@ -308,21 +305,7 @@ def test_patch_predictor_kather100k_output(
308305
pretrained_info = {
309306
"alexnet-kather100k": [1.0, 0.9999735355377197],
310307
"resnet18-kather100k": [1.0, 0.9999911785125732],
311-
"resnet34-kather100k": [1.0, 0.9979840517044067],
312-
"resnet50-kather100k": [1.0, 0.9999986886978149],
313-
"resnet101-kather100k": [1.0, 0.9999932050704956],
314-
"resnext50_32x4d-kather100k": [1.0, 0.9910059571266174],
315-
"resnext101_32x8d-kather100k": [1.0, 0.9999971389770508],
316-
"wide_resnet50_2-kather100k": [1.0, 0.9953408241271973],
317-
"wide_resnet101_2-kather100k": [1.0, 0.9999831914901733],
318-
"densenet121-kather100k": [1.0, 1.0],
319-
"densenet161-kather100k": [1.0, 0.9999959468841553],
320-
"densenet169-kather100k": [1.0, 0.9999934434890747],
321-
"densenet201-kather100k": [1.0, 0.9999983310699463],
322-
"mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593],
323-
"mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658],
324308
"mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209],
325-
"googlenet-kather100k": [1.0, 0.9999639987945557],
326309
}
327310
for model, expected_prob in pretrained_info.items():
328311
_test_predictor_output(

tests/engines/test_semantic_segmentor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def test_wsi_segmentor_annotationstore(
498498

499499

500500
def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None:
501-
"""Test for models CLI single file."""
501+
"""Test semantic segmentor CLI single file."""
502502
runner = CliRunner()
503503
models_wsi_result = runner.invoke(
504504
cli.main,

tests/models/test_models_abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,5 +174,5 @@ def test_get_pretrained_model_not_str() -> None:
174174

175175
def test_get_pretrained_model_not_in_info() -> None:
176176
"""Test ValueError is raised if input is not in info."""
177-
with pytest.raises(ValueError, match=r"Pretrained model `alexnet` does not exist."):
178-
_ = get_pretrained_model("alexnet")
177+
with pytest.raises(ValueError, match=r"Pretrained model `random` does not exist."):
178+
_ = get_pretrained_model("random")

0 commit comments

Comments
 (0)