Skip to content

Commit 9f8c041

Browse files
gozdegpre-commit-ci[bot]shaneahmed
authored
📝 Update 05-patch-prediction notebook for the New API (#977)
## Summary This PR updates the patch‑prediction example to align with the new `PatchPredictor` engine and fixes a long‑standing issue in `EngineABC` related to model‑attribute retrieval when using `DataParallel`. --- ## What’s Changed ### 🔧 Example Notebook Updates - Updated **`examples/05-patch-prediction.ipynb`** to use the new `PatchPredictor` engine API. - Added a new **“Visualize in TIAViz”** section, allowing readers to directly inspect prediction results inside **TIAViz** for a smoother, more interactive workflow. ### 🐛 EngineABC Bug Fix - Fixed a bug in **`EngineABC`** where model attributes were incorrectly retrieved from a `DataParallel` wrapper. - Introduced `_get_model_attr()` to safely unwrap the underlying model when needed. - This resolves multi‑GPU crashes caused by attributes living on the wrapped module instead of the actual model. --- ## Why This Matters - Ensures the patch‑prediction example stays up‑to‑date with the latest engine design. - Improves multi‑GPU stability and prevents confusing attribute‑access errors. - Enhances the user experience by integrating TIAViz visualization directly into the example workflow. --- ## Testing - Verified that the updated notebook runs end‑to‑end with the new engine. - Confirmed that multi‑GPU training and inference no longer crash when accessing model attributes. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
1 parent d40dc64 commit 9f8c041

File tree

8 files changed

+608
-191
lines changed

8 files changed

+608
-191
lines changed

examples/05-patch-prediction.ipynb

Lines changed: 578 additions & 183 deletions
Large diffs are not rendered by default.

pre-commit/notebook_markdown_format.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def main(files: list[Path]) -> None:
5050
5151
"""
5252
for path in files:
53-
notebook = json.loads(path.read_text())
53+
with Path.open(path, encoding="utf-8", errors="ignore") as f:
54+
notebook = json.load(f)
5455
formatted_notebook = format_notebook(copy.deepcopy(notebook))
5556
changed = any(
5657
cell != formatted_cell

requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ flask-cors>=4.0.0
1010
glymur>=0.12.7
1111
huggingface_hub>=0.33.3
1212
imagecodecs>=2022.9.26
13+
ipywidgets>=8.1.7
1314
joblib>=1.1.1
1415
jupyterlab>=3.5.2
1516
matplotlib>=3.6.2

tests/engines/test_engine_abc.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,13 @@ def test_engine_initalization() -> NoReturn:
273273
eng = TestEngineABC(model=model, weights=weights_path)
274274
assert isinstance(eng, EngineABC)
275275

276+
with pytest.raises(AttributeError):
277+
_ = eng._get_model_attr("test_attr")
278+
279+
model.test_attr = True
280+
eng = TestEngineABC(model=model, weights=weights_path)
281+
assert eng._get_model_attr("test_attr") is True
282+
276283

277284
def test_engine_run() -> NoReturn:
278285
"""Test engine run."""

tiatoolbox/models/engine/deep_feature_extractor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,9 @@ def infer_wsi(
302302
probabilities_zarr, coordinates_zarr = None, None
303303

304304
probabilities_used_percent = 0
305+
infer_batch = self._get_model_attr("infer_batch")
305306
for batch_data in tqdm_loop:
306-
batch_output = self.model.infer_batch(
307+
batch_output = infer_batch(
307308
self.model,
308309
batch_data["image"],
309310
device=self.device,

tiatoolbox/models/engine/engine_abc.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565

6666
if TYPE_CHECKING: # pragma: no cover
6767
import os
68+
from collections.abc import Callable
6869

6970
from torch.utils.data import DataLoader
7071

@@ -375,6 +376,14 @@ def _initialize_model_ioconfig(
375376

376377
return model, None
377378

379+
def _get_model_attr(self: EngineABC, attr_name: str) -> Callable:
380+
"""Return a model attribute, unwrapping DataParallel if required."""
381+
try:
382+
return getattr(self.model, attr_name)
383+
except AttributeError:
384+
module = getattr(self.model, "module", None)
385+
return getattr(module, attr_name)
386+
378387
def get_dataloader(
379388
self: EngineABC,
380389
images: str | Path | list[str | Path] | np.ndarray,
@@ -428,7 +437,7 @@ def get_dataloader(
428437
auto_get_mask=auto_get_mask,
429438
)
430439

431-
dataset.preproc_func = self.model.preproc_func
440+
dataset.preproc_func = self._get_model_attr("preproc_func")
432441

433442
# preprocessing must be defined with the dataset
434443
return torch.utils.data.DataLoader(
@@ -444,7 +453,7 @@ def get_dataloader(
444453
inputs=images, labels=labels, patch_input_shape=ioconfig.patch_input_shape
445454
)
446455

447-
dataset.preproc_func = self.model.preproc_func
456+
dataset.preproc_func = self._get_model_attr("preproc_func")
448457

449458
# preprocessing must be defined with the dataset
450459
return torch.utils.data.DataLoader(
@@ -529,8 +538,9 @@ def infer_patches(
529538
else self.dataloader
530539
)
531540

541+
infer_batch = self._get_model_attr("infer_batch")
532542
for batch_data in tqdm_loop:
533-
batch_output = self.model.infer_batch(
543+
batch_output = infer_batch(
534544
self.model,
535545
batch_data["image"],
536546
device=self.device,

tiatoolbox/models/engine/patch_predictor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,8 @@ def post_process_patches(
372372
dict[str, da.Array]: Post-processed predictions as a Dask array.
373373
374374
"""
375-
predictions = self.model.postproc_func(raw_predictions["probabilities"])
375+
postproc_func = self._get_model_attr("postproc_func")
376+
predictions = postproc_func(raw_predictions["probabilities"])
376377
raw_predictions["predictions"] = cast_to_min_dtype(predictions)
377378
return raw_predictions
378379

tiatoolbox/models/engine/semantic_segmentor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def get_dataloader(
361361
auto_get_mask=auto_get_mask,
362362
)
363363

364-
dataset.preproc_func = self.model.preproc_func
364+
dataset.preproc_func = self._get_model_attr("preproc_func")
365365
self.output_locations = dataset.outputs
366366

367367
# preprocessing must be defined with the dataset
@@ -477,8 +477,9 @@ def infer_wsi(
477477
else dataloader.dataset.outputs
478478
)
479479

480+
infer_batch = self._get_model_attr("infer_batch")
480481
for batch_idx, batch_data in enumerate(tqdm_loop):
481-
batch_output = self.model.infer_batch(
482+
batch_output = infer_batch(
482483
self.model,
483484
batch_data["image"],
484485
device=self.device,

0 commit comments

Comments
 (0)