Skip to content

Commit 7864a75

Browse files
Jiaqi LvJiaqi Lv
authored andcommitted
refactor postproc to accept 'raw_prediictions'
1 parent 9e15154 commit 7864a75

File tree

4 files changed

+14
-11
lines changed

4 files changed

+14
-11
lines changed

tests/engines/test_engine_abc.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99
from typing import NoReturn
1010

11+
import dask.array as da
1112
import numpy as np
1213
import pytest
1314
import torch
@@ -353,11 +354,13 @@ def test_engine_run() -> NoReturn:
353354
assert "probabilities" in out
354355
assert "labels" in out
355356

357+
raw_output = np.zeros((3, 3, 3))
356358
pred = eng.post_process_wsi(
357-
raw_predictions=Path("/path/to/raw_predictions.npy"),
359+
raw_predictions={"probabilities": da.from_array(raw_output)},
358360
save_path=Path("/path/to/save_predictions.zarr"),
359361
)
360-
assert str(pred) == "/path/to/raw_predictions.npy"
362+
pred = np.array(pred)
363+
np.testing.assert_array_equal(pred, raw_output)
361364

362365

363366
def test_engine_run_with_verbose() -> NoReturn:

tiatoolbox/models/engine/engine_abc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ def post_process_patches( # skipcq: PYL-R0201
625625
Post-processed predictions as a Dask array.
626626
627627
"""
628-
return raw_predictions
628+
return raw_predictions["probabilities"]
629629

630630
def save_predictions(
631631
self: EngineABC,
@@ -933,7 +933,7 @@ def post_process_wsi( # skipcq: PYL-R0201
933933
Post-processed predictions as a Dask array.
934934
935935
"""
936-
return raw_predictions
936+
return raw_predictions["probabilities"]
937937

938938
def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfigABC:
939939
"""Load or validate the IO configuration for the engine.
@@ -1367,7 +1367,7 @@ def _run_patch_mode(
13671367
)
13681368

13691369
raw_predictions["predictions"] = self.post_process_patches(
1370-
raw_predictions=raw_predictions["probabilities"],
1370+
raw_predictions=raw_predictions,
13711371
prediction_shape=raw_predictions["probabilities"].shape[:-1],
13721372
prediction_dtype=raw_predictions["probabilities"].dtype,
13731373
**kwargs,
@@ -1543,7 +1543,7 @@ def get_path(image: Path | WSIReader) -> Path:
15431543
)
15441544

15451545
raw_predictions["predictions"] = self.post_process_wsi(
1546-
raw_predictions=raw_predictions["probabilities"],
1546+
raw_predictions=raw_predictions,
15471547
save_path=save_path[get_path(image)],
15481548
prediction_shape=raw_predictions["probabilities"].shape[:-1],
15491549
prediction_dtype=raw_predictions["probabilities"].dtype,

tiatoolbox/models/engine/nucleus_detector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,8 @@ def post_process_patches(
356356
probs = []
357357

358358
# Process each patch's predictions
359-
for i in range(raw_predictions.shape[0]):
360-
probs_prediction_patch = raw_predictions[i].compute()
359+
for i in range(raw_predictions["probabilities"].shape[0]):
360+
probs_prediction_patch = raw_predictions["probabilities"][i].compute()
361361
centroids_map_patch = self.model.postproc(
362362
probs_prediction_patch,
363363
min_distance=min_distance,
@@ -462,13 +462,13 @@ def post_process_wsi(
462462
depth = {0: depth_h, 1: depth_w, 2: 0}
463463

464464
# Re-chunk to post-processing tile shape for more efficient processing
465-
rechunked_prediction_map = raw_predictions.rechunk(
465+
rechunked_probability_map = raw_predictions["probabilities"].rechunk(
466466
(postproc_tile_shape[0], postproc_tile_shape[1], -1)
467467
)
468468

469469
centroid_maps = da.map_overlap(
470470
self.model.postproc,
471-
rechunked_prediction_map,
471+
rechunked_probability_map,
472472
min_distance=min_distance,
473473
threshold_abs=threshold_abs,
474474
threshold_rel=threshold_rel,

tiatoolbox/models/engine/patch_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def post_process_patches(
381381
_ = kwargs.get("return_probabilities")
382382
_ = prediction_shape
383383
_ = prediction_dtype
384-
raw_predictions = self.model.postproc_func(raw_predictions)
384+
raw_predictions = self.model.postproc_func(raw_predictions["probabilities"])
385385
return cast_to_min_dtype(raw_predictions)
386386

387387
def post_process_wsi(

0 commit comments

Comments
 (0)