Commit 80e7af5
🆕 Define DeepFeatureExtractor (#963)
# 🚀 Summary

This PR introduces a new **`DeepFeatureExtractor` engine** to the TIAToolbox framework, enabling extraction of intermediate CNN feature representations from whole slide images (WSIs) or image patches. These features can be used for downstream tasks such as clustering, visualization, or training other models.

The update also includes:

- A **command-line interface (CLI)** for the new engine.
- Extended **CLI utilities** for flexible input/output configurations.
- Comprehensive **unit tests** covering patch-based and WSI-based workflows, multi-GPU support, and CLI functionality.
- Integration with TIAToolbox’s model registry and CLI ecosystem.

---

## ✨ Key Features

### **New Engine: `DeepFeatureExtractor`**

- Extracts intermediate CNN features from WSIs or patches.
- Outputs feature embeddings and spatial coordinates in **Zarr** or **dict** format.
- Implements **memory-aware caching** for large-scale WSI processing.
- Compatible with:
  - TIAToolbox pretrained models.
  - Torchvision CNN backbones (e.g., ResNet, DenseNet, MobileNet).
  - **All timm architectures via `timm.list_models()`**, including HuggingFace-hosted models.
- Supports both **patch-mode** and **WSI-mode** workflows (see the usage sketch below).

### **CLI Integration**

- Adds a `deep-feature-extractor` command to the TIAToolbox CLI.
- Supports options for:
  - Input/output paths and file types.
  - Model selection (`resnet18`, `efficientnet_b0`, timm-based backbones, etc.).
  - Patch extraction parameters (`patch_input_shape`, `stride_shape`, `input_resolutions`).
  - Batch size, device selection, memory threshold, and overwrite behavior.
- Flexible JSON-based CLI options for resolutions and class mappings.

### **Extended CLI Utilities**

- New reusable options:
  - `--input-resolutions`, `--output-resolutions` (JSON list of dicts).
  - `--patch-input-shape`, `--stride-shape`, `--scale-factor`.
  - `--class-dict` for mapping class indices to names.
  - `--overwrite` and `--output-file` for fine-grained control.

### **Unit Tests**

- **Engine tests**:
  - Patch-based and WSI-based feature extraction.
  - Validation of Zarr outputs (features and coordinates).
  - Multi-GPU functionality.
- **Model compatibility**:
  - Tests with `CNNBackbone` and `TimmBackbone` models.
- **CLI tests**:
  - Single-file and parameterized runs.
  - Validation of JSON parsing for CLI options.

### **Codebase Integration**

- Registers `DeepFeatureExtractor` in `tiatoolbox.models` and the engine registry.
- Adds the CLI command in `tiatoolbox.cli.__init__.py`.
- Updates architecture utilities to support timm-based backbones and HuggingFace models.
- Introduces dictionaries for Torch and timm backbones (`torch_cnn_backbone_dict`, `timm_arch_dict`).
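For orientation, here is a minimal usage sketch of the new engine in WSI mode with Zarr output, assembled from the tests added in this commit. The WSI path and output directory are placeholders; the argument names (`images`, `patch_mode`, `ioconfig`, `save_dir`, `output_type`) mirror the new test file rather than documenting a finalized API.

```python
# Minimal usage sketch based on the tests added in this commit.
# The WSI path and output directory are placeholders.
from pathlib import Path

import zarr

from tiatoolbox.models import IOPatchPredictorConfig
from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor

wsi_path = Path("sample.svs")  # placeholder whole slide image

ioconfig = IOPatchPredictorConfig(
    input_resolutions=[{"units": "mpp", "resolution": 0.25}],
    patch_input_shape=[512, 512],
    stride_shape=[256, 256],
)

extractor = DeepFeatureExtractor(model="resnet18", batch_size=4)
output = extractor.run(
    images=[wsi_path],
    patch_mode=False,               # WSI mode; patch_mode=True for image patches
    ioconfig=ioconfig,
    save_dir=Path("features_out"),  # placeholder output directory
    output_type="zarr",
)

# Each input maps to a saved Zarr group holding "features" and "coordinates".
group = zarr.open(output[wsi_path], mode="r")
print(group["features"].shape, group["coordinates"].shape)
```

In patch mode (`patch_mode=True`), the same `run()` call takes a list of image patches and, as exercised in `test_feature_extractor_patches`, returns the feature arrays directly rather than a path to a saved Zarr group.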
1 parent b5ba794 · commit 80e7af5

19 files changed: +2352 −228 lines

tests/engines/test_engine_abc.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -313,6 +313,7 @@ def test_engine_run() -> NoReturn:
         on_gpu=False,
     )
 
+    eng = TestEngineABC(model="alexnet-kather100k")
     with pytest.raises(
         ValueError,
         match=r".*The shape of the numpy array should be NHWC*",
```
Lines changed: 241 additions & 0 deletions (new file)

```python
"""Test for feature extractor."""

import shutil
from collections.abc import Callable
from pathlib import Path

import numpy as np
import pytest
import torch
import zarr
from click.testing import CliRunner

from tiatoolbox import cli
from tiatoolbox.models import IOPatchPredictorConfig
from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone
from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()

# -------------------------------------------------------------------------------------
# Engine
# -------------------------------------------------------------------------------------

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def test_feature_extractor_patches(
    remote_sample: Callable,
) -> None:
    """Tests DeepFeatureExtractor on image patches."""
    extractor = DeepFeatureExtractor(
        model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
    )

    sample_image = remote_sample("thumbnail-1k-1k")

    inputs = [sample_image, sample_image]

    assert not extractor.patch_mode
    output = extractor.run(
        images=inputs,
        return_probabilities=True,
        return_labels=False,
        device=device,
        patch_mode=True,
    )

    assert 0.48 < np.mean(output["features"][:]) < 0.52

    with pytest.raises(
        ValueError,
        match=r".*output_type: `annotationstore` is not supported "
        r"for `DeepFeatureExtractor` engine",
    ):
        _ = extractor.run(
            images=inputs,
            return_probabilities=True,
            return_labels=False,
            device=device,
            patch_mode=True,
            output_type="annotationstore",
        )


def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> None:
    """Test feature extraction with DeepFeatureExtractor engine."""
    save_dir = track_tmp_path / "output"
    # convert to pathlib Path to prevent wsireader complaint
    mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))

    # * test providing pretrained from torch vs pretrained_model.yaml
    shutil.rmtree(save_dir, ignore_errors=True)  # default output dir test

    extractor = DeepFeatureExtractor(batch_size=1, model="fcn-tissue_mask")
    output = extractor.run(
        images=[mini_wsi_svs],
        return_probabilities=False,
        return_labels=False,
        device=device,
        patch_mode=False,
        save_dir=track_tmp_path / "wsi_out_check",
        batch_size=1,
        output_type="zarr",
        memory_threshold=1,
    )

    output_ = zarr.open(output[mini_wsi_svs], mode="r")
    assert len(output_["coordinates"].shape) == 2
    assert len(output_["features"].shape) == 3


@pytest.mark.parametrize(
    "model",
    [
        CNNBackbone("resnet18"),
        TimmBackbone("efficientnet_b0", pretrained=True),
        "resnet18",
        "efficientnet_b0",
    ],
)
def test_full_inference(
    remote_sample: Callable, track_tmp_path: Path, model: Callable
) -> None:
    """Test full inference with CNNBackbone and TimmBackbone models."""
    save_dir = track_tmp_path / "output"
    # pre-emptive clean up
    shutil.rmtree(save_dir, ignore_errors=True)  # default output dir test

    mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))

    ioconfig = IOPatchPredictorConfig(
        input_resolutions=[
            {"units": "mpp", "resolution": 0.25},
        ],
        patch_input_shape=[512, 512],
        stride_shape=[256, 256],
    )

    extractor = DeepFeatureExtractor(batch_size=4, model=model)
    output = extractor.run(
        images=[mini_wsi_svs],
        device=device,
        save_dir=track_tmp_path / "wsi_out_check",
        batch_size=4,
        output_type="zarr",
        ioconfig=ioconfig,
        patch_mode=False,
    )

    output_ = zarr.open(output[mini_wsi_svs], mode="r")

    positions = output_["coordinates"]
    features = output_["features"]

    reader = WSIReader.open(mini_wsi_svs)
    patches = [
        reader.read_bounds(
            positions[patch_idx],
            resolution=0.25,
            units="mpp",
            pad_constant_values=255,
            coord_space="resolution",
        )
        for patch_idx in range(4)
    ]
    patches = np.array(patches)
    patches = torch.from_numpy(patches)  # NHWC
    patches = patches.permute(0, 3, 1, 2).contiguous()  # NCHW
    patches = patches.to(device).type(torch.float32)
    model = extractor.model
    # Inference mode
    model.eval()
    with torch.inference_mode():
        _features = model(patches).cpu().numpy()
    # ! must maintain same batch size and likely same ordering
    # ! else the output values will not exactly be the same (still < 1.0e-4
    # ! of epsilon though)
    assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1


@pytest.mark.skipif(
    toolbox_env.running_on_ci() or not ON_GPU,
    reason="Local test on machine with GPU.",
)
def test_multi_gpu_feature_extraction(
    remote_sample: Callable, track_tmp_path: Path
) -> None:
    """Local functionality test for feature extraction using multiple GPUs."""
    save_dir = track_tmp_path / "output"
    mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
    shutil.rmtree(save_dir, ignore_errors=True)

    wsi_ioconfig = IOPatchPredictorConfig(
        input_resolutions=[{"units": "mpp", "resolution": 0.5}],
        patch_input_shape=[224, 224],
        stride_shape=[224, 224],
    )

    extractor = DeepFeatureExtractor(
        model="UNI",
        batch_size=32,
        num_workers=4,
    )

    output = extractor.run(
        [mini_wsi_svs],
        patch_mode=False,
        device=device,
        ioconfig=wsi_ioconfig,
        save_dir=save_dir,
        auto_get_mask=True,
        output_type="zarr",
    )
    output_ = zarr.open(output[mini_wsi_svs], mode="r")

    positions = output_["coordinates"]
    features = output_["features"]
    assert len(positions.shape) == 2
    assert len(features.shape) == 2


# -------------------------------------------------------------------------------------
# Command Line Interface
# -------------------------------------------------------------------------------------


def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None:
    """Test for feature extractor CLI single file."""
    runner = CliRunner()

    models_wsi_result = runner.invoke(
        cli.main,
        [
            "deep-feature-extractor",
            "--img-input",
            str(sample_svs),
            "--model",
            "resnet18",
            "--patch-mode",
            "False",
            "--output-path",
            str(track_tmp_path / "output"),
            "--patch-input-shape",
            "224",
            "224",
            "--input-resolutions",
            '[{"units": "mpp", "resolution": 0.25}]',
        ],
    )

    assert models_wsi_result.exit_code == 0
    assert (track_tmp_path / "output" / (sample_svs.stem + ".zarr")).exists()

    output = zarr.open(
        track_tmp_path / "output" / (sample_svs.stem + ".zarr"), mode="r"
    )

    # Output shape should be # of patches x feature size
    assert output["features"].shape == (255, 512)
```

tests/engines/test_patch_predictor.py

Lines changed: 1 addition & 18 deletions

```diff
@@ -268,17 +268,14 @@ def test_wsi_predictor_api(
     kwargs = {
         "patch_input_shape": patch_size,
         "stride_shape": patch_size,
-        "input_resolutions": [{"units": "baseline", "resolution": 1.0}],
         "save_dir": save_dir,
     }
-    # ! add this test back once the read at `baseline` is fixed
-    # sanity check, both output should be the same with same resolution read args
-    # remove previously generated data
 
     _kwargs = copy.deepcopy(kwargs)
     # test reading of multiple whole-slide images
     output = predictor.run(
         images=[mini_wsi_svs, str(mini_wsi_jpg)],
+        input_resolutions=[{"units": "baseline", "resolution": 1.0}],
         masks=[mini_wsi_msk, mini_wsi_msk],
         patch_mode=False,
         return_probabilities=True,
@@ -308,21 +305,7 @@ def test_patch_predictor_kather100k_output(
     pretrained_info = {
         "alexnet-kather100k": [1.0, 0.9999735355377197],
         "resnet18-kather100k": [1.0, 0.9999911785125732],
-        "resnet34-kather100k": [1.0, 0.9979840517044067],
-        "resnet50-kather100k": [1.0, 0.9999986886978149],
-        "resnet101-kather100k": [1.0, 0.9999932050704956],
-        "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174],
-        "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508],
-        "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973],
-        "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733],
-        "densenet121-kather100k": [1.0, 1.0],
-        "densenet161-kather100k": [1.0, 0.9999959468841553],
-        "densenet169-kather100k": [1.0, 0.9999934434890747],
-        "densenet201-kather100k": [1.0, 0.9999983310699463],
-        "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593],
-        "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658],
         "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209],
-        "googlenet-kather100k": [1.0, 0.9999639987945557],
     }
     for model, expected_prob in pretrained_info.items():
         _test_predictor_output(
```

tests/engines/test_semantic_segmentor.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -498,7 +498,7 @@ def test_wsi_segmentor_annotationstore(
 
 
 def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None:
-    """Test for models CLI single file."""
+    """Test semantic segmentor CLI single file."""
     runner = CliRunner()
     models_wsi_result = runner.invoke(
         cli.main,
```

tests/models/test_models_abc.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -174,5 +174,5 @@ def test_get_pretrained_model_not_str() -> None:
 
 def test_get_pretrained_model_not_in_info() -> None:
     """Test ValueError is raised if input is not in info."""
-    with pytest.raises(ValueError, match=r"Pretrained model `alexnet` does not exist."):
-        _ = get_pretrained_model("alexnet")
+    with pytest.raises(ValueError, match=r"Pretrained model `random` does not exist."):
+        _ = get_pretrained_model("random")
```
