Commit 1d6b877

⚡ Improve merge performance
1 parent 621adb5 commit 1d6b877

4 files changed: +122 additions, -44 deletions

tiatoolbox/models/dataset/dataset_abc.py

Lines changed: 11 additions & 0 deletions

@@ -523,6 +523,10 @@ def __init__(  # skipcq: PY-R1000
     def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
         """Get an item from the dataset."""
         coords = self.inputs[idx]
+        output_locs = None
+        if hasattr(self, "outputs"):
+            output_locs = self.outputs[idx]
+
         # Read image patch from the whole-slide image
         patch = self.reader.read_bounds(
             coords,
@@ -535,6 +539,13 @@ def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
         # Apply preprocessing to selected patch
         patch = self._preproc(patch)

+        if output_locs is not None:
+            return {
+                "image": patch,
+                "coords": np.array(coords),
+                "output_locs": output_locs,
+            }
+
         return {"image": patch, "coords": np.array(coords)}

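Note on the dataset change: when the dataset has been given per-patch output locations, each item now carries an "output_locs" entry alongside "image" and "coords", and the default DataLoader collation stacks those arrays into a per-batch tensor that the engine's merge loop can index. A minimal runnable sketch of that pattern, where PatchDataset and its zero-filled patches are hypothetical stand-ins rather than tiatoolbox code:

import numpy as np
from torch.utils.data import DataLoader, Dataset


class PatchDataset(Dataset):
    """Toy dataset mirroring the conditional output_locs return above."""

    def __init__(self, inputs, outputs=None):
        self.inputs = inputs  # (N, 4) read bounds per patch
        if outputs is not None:
            self.outputs = outputs  # (N, 4) merge locations per patch

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        patch = np.zeros((16, 16, 3), dtype=np.float32)  # stand-in for a WSI read
        item = {"image": patch, "coords": self.inputs[idx]}
        if hasattr(self, "outputs"):  # same guard as the diff above
            item["output_locs"] = self.outputs[idx]
        return item


locs = np.arange(32, dtype=np.int64).reshape(8, 4)
loader = DataLoader(PatchDataset(locs, locs.copy()), batch_size=4)
batch = next(iter(loader))
print(batch["output_locs"].shape)  # torch.Size([4, 4]); collate stacked the locations
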
tiatoolbox/models/engine/engine_abc.py

Lines changed: 3 additions & 1 deletion

@@ -13,6 +13,7 @@
 import torch
 import zarr
 from dask import compute, delayed
+from dask.diagnostics import ProgressBar
 from torch import nn
 from typing_extensions import Unpack

@@ -669,7 +670,8 @@ def save_predictions(
             )
             write_tasks.append(task)

-        compute(*write_tasks)
+        with ProgressBar():
+            compute(*write_tasks)

         return save_path

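The engine change is a drop-in wrapper: dask.diagnostics.ProgressBar is a context manager that hooks dask's local schedulers and prints task progress to stdout while compute() runs, so the batched zarr writes are no longer silent. A self-contained sketch of the pattern, with write_chunk as a hypothetical stand-in for the real delayed write tasks:

from dask import compute, delayed
from dask.diagnostics import ProgressBar


@delayed
def write_chunk(i):
    # Stand-in for one delayed zarr write task.
    return i * i


write_tasks = [write_chunk(i) for i in range(100)]

# A live progress bar is printed while the tasks execute.
with ProgressBar():
    results = compute(*write_tasks)

ProgressBar only instruments the local (threaded, synchronous, multiprocessing) schedulers, which matches the scheduler="threads" configuration used in post_process_wsi below.
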
tiatoolbox/models/engine/semantic_segmentor.py

Lines changed: 104 additions & 41 deletions

@@ -24,12 +24,16 @@
 if TYPE_CHECKING:  # pragma: no cover
     import os

+    from torch.utils.data import DataLoader
+
     from tiatoolbox.annotation import AnnotationStore
     from tiatoolbox.models.engine.io_config import IOSegmentorConfig
     from tiatoolbox.models.models_abc import ModelABC
     from tiatoolbox.type_hints import Resolution
     from tiatoolbox.wsicore import WSIReader

+    from .engine_abc import EngineABC, EngineABCRunParams
+

 class SemanticSegmentorRunParams(PredictorRunParams):
     """Class describing the input parameters for the :func:`EngineABC.run()` method.
@@ -364,85 +368,144 @@ def get_dataloader(
             patch_mode=patch_mode,
         )

-    def post_process_wsi(
-        self: SemanticSegmentor,
-        raw_predictions: Path,
-        **kwargs: Unpack[PredictorRunParams],
+    def infer_wsi(
+        self: EngineABC,
+        dataloader: DataLoader,
+        **kwargs: EngineABCRunParams,
     ) -> Path:
-        """Returns an array from raw predictions.
+        """Model inference on a WSI.

-        Merges raw predictions from individual patches into a single prediction array if
-        patch_mode is False.
+        Args:
+            dataloader (DataLoader):
+                A torch dataloader to process WSIs.
+
+            save_path (Path):
+                Path to save the intermediate output. The intermediate output is saved
+                in a zarr file.
+            **kwargs (EngineABCRunParams):
+                Keyword Args to update setup_patch_dataset() method attributes. See
+                :class:`EngineRunParams` for accepted keyword arguments.
+
+        Returns:
+            save_path (Path):
+                Path to zarr file where intermediate output is saved.

         """
         _ = kwargs.get("return_probabilities")
-        progress_bar = None
+
         tqdm = get_tqdm()

-        if self.verbose:
-            progress_bar = tqdm(
-                total=len(self.output_locations),
-                leave=False,
-                desc="Merging Patch Outputs",
-            )
+        progress_bar = (
+            tqdm(total=len(dataloader), leave=self.patch_mode, desc="Inferring patches")
+            if self.verbose
+            else None
+        )

-        num_post_proc_workers = self.num_post_proc_workers
+        keys = ["coordinates"]

-        if num_post_proc_workers is not None and num_post_proc_workers > 0:
-            dask.config.set(scheduler="threads", num_workers=num_post_proc_workers)
-        else:
-            dask.config.set(scheduler="threads")
+        if self.return_labels:
+            keys.append("labels")

-        dask_patch_probabilities = raw_predictions["probabilities"]
+        raw_predictions = dict.fromkeys(keys)

-        # --- Calculate canvas parameters from Dask array and locations ---
         max_location = np.max(self.output_locations, axis=0)
+
+        out_ = self.model.infer_batch(
+            self.model,
+            torch.from_numpy(dataloader.dataset[0]["image"][None, :, :, :]),
+            device=self.device,
+        )
+
         merged_shape = (
             max_location[3],
             max_location[2],
-            dask_patch_probabilities.shape[3],
+            out_["probabilities"].shape[3],
         )

         # creating dask arrays for faster processing
         merged_probabilities = da.zeros(
             shape=merged_shape,
-            dtype=dask_patch_probabilities.dtype,
+            dtype=out_["probabilities"].dtype,
             chunks=merged_shape,
         )

         merged_weights = da.zeros(
-            shape=merged_shape,
+            shape=merged_shape[:2],
             dtype=int,
-            chunks=merged_shape,
+            chunks=merged_shape[:2],
         )

-        for idx, location in enumerate(self.output_locations):
-            start_x, start_y, end_x, end_y = location
-            patch_probs = dask_patch_probabilities[
-                idx, 0 : end_y - start_y, 0 : end_x - start_x, :
-            ]
-            merged_probabilities[start_y:end_y, start_x:end_x, :] = (
-                merged_probabilities[start_y:end_y, start_x:end_x, :] + patch_probs
+        for _, batch_data in enumerate(dataloader):
+            batch_output = self.model.infer_batch(
+                self.model,
+                batch_data["image"],
+                device=self.device,
             )
-            merged_weights[start_y:end_y, start_x:end_x] = (
-                merged_weights[start_y:end_y, start_x:end_x] + 1
+
+            batch_output["coordinates"] = self._get_coordinates(batch_data)
+
+            if self.return_labels:  # be careful of `s`
+                if isinstance(batch_data["label"], torch.Tensor):
+                    batch_output["labels"] = batch_data["label"].numpy()
+                else:
+                    batch_output["labels"] = np.array(batch_data["label"])
+
+            output_locs = batch_data["output_locs"]
+
+            for idx, location in enumerate(output_locs.numpy()):
+                start_x, start_y, end_x, end_y = location
+                patch_probs = batch_output["probabilities"][
+                    idx, 0 : end_y - start_y, 0 : end_x - start_x, :
+                ]
+                merged_probabilities[start_y:end_y, start_x:end_x, :] = (
+                    merged_probabilities[start_y:end_y, start_x:end_x, :] + patch_probs
+                )
+                merged_weights[start_y:end_y, start_x:end_x] = (
+                    merged_weights[start_y:end_y, start_x:end_x] + 1
+                )
+
+            del batch_output["probabilities"]
+            raw_predictions = self._update_model_output(
+                raw_predictions=raw_predictions,
+                raw_output=batch_output,
            )
+
             if progress_bar:
                 progress_bar.update()

+        merged_weights = da.maximum(merged_weights, 1)
+        raw_predictions["probabilities"] = (
+            merged_probabilities / merged_weights[:, :, None]
+        )
+
         if progress_bar:
             progress_bar.close()

-        # Normalize where weight > 1
-        final_probabilities_dask = da.where(
-            merged_weights > 1,
-            merged_probabilities / merged_weights,
-            merged_probabilities,
-        )
+        return raw_predictions
+
+    def post_process_wsi(
+        self: SemanticSegmentor,
+        raw_predictions: Path,
+        **kwargs: Unpack[PredictorRunParams],
+    ) -> Path:
+        """Returns an array from raw predictions.
+
+        Merges raw predictions from individual patches into a single prediction array if
+        patch_mode is False.
+
+        """
+        _ = kwargs.get("return_probabilities")
+
+        num_post_proc_workers = self.num_post_proc_workers
+
+        if num_post_proc_workers is not None and num_post_proc_workers > 0:
+            dask.config.set(scheduler="threads", num_workers=num_post_proc_workers)
+        else:
+            dask.config.set(scheduler="threads")

         # Applying Post-Processing
         raw_predictions["predictions"] = self.model.postproc_func(
-            final_probabilities_dask,
+            raw_predictions["probabilities"],
         )

         return raw_predictions

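The heart of the merge speed-up is in the second hunk: rather than normalising afterwards with da.where over a full-size weight volume, infer_wsi now accumulates each batch's probabilities into a single WSI-sized dask canvas plus a 2-D per-pixel coverage counter, and divides once at the end (da.maximum clamps never-covered pixels to 1 to avoid division by zero). A minimal sketch of that sum-and-count merge, using hypothetical toy shapes in place of real WSI dimensions:

import dask.array as da
import numpy as np

# Two 6x6 patches with 2 classes, overlapping on an 8x8 canvas.
num_classes = 2
output_locs = np.array([[0, 0, 6, 6], [2, 2, 8, 8]])  # (start_x, start_y, end_x, end_y)
patch_probs = np.random.rand(2, 6, 6, num_classes).astype(np.float32)

merged = da.zeros((8, 8, num_classes), dtype=np.float32)
weights = da.zeros((8, 8), dtype=int)  # 2-D: one coverage count per pixel

for idx, (start_x, start_y, end_x, end_y) in enumerate(output_locs):
    probs = patch_probs[idx, : end_y - start_y, : end_x - start_x, :]
    merged[start_y:end_y, start_x:end_x, :] = (
        merged[start_y:end_y, start_x:end_x, :] + probs
    )
    weights[start_y:end_y, start_x:end_x] = weights[start_y:end_y, start_x:end_x] + 1

# Clamp to 1 so uncovered pixels divide by 1, then average the overlapping region.
weights = da.maximum(weights, 1)
averaged = (merged / weights[:, :, None]).compute()
print(averaged.shape)  # (8, 8, 2)

Keeping the weights 2-D (the merged_shape[:2] change) avoids materialising a per-class weight volume; broadcasting through weights[:, :, None] performs the per-class division instead.
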
tiatoolbox/models/models_abc.py

Lines changed: 4 additions & 2 deletions

@@ -102,15 +102,17 @@ def forward(

     @staticmethod
     @abstractmethod
-    def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dict:
+    def infer_batch(
+        model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, device: str
+    ) -> dict:
         """Run inference on an input batch.

         Contains logic for forward operation as well as I/O aggregation.

         Args:
             model (nn.Module):
                 PyTorch defined model.
-            batch_data (np.ndarray):
+            batch_data (np.ndarray | torch.Tensor):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
             device (str):

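The widened infer_batch signature matters because infer_wsi above now calls it both with torch tensors straight from the dataloader and with a torch.from_numpy-wrapped first patch to probe the output shape. A hypothetical concrete implementation that satisfies the abstract method by accepting either type; ToyModel is illustrative only, not tiatoolbox code:

import numpy as np
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical segmentation model with a one-layer head."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=1)  # 3 channels in, 2 classes out

    def forward(self, imgs):
        return self.conv(imgs)

    @staticmethod
    def infer_batch(model, batch_data, *, device):
        # Accept np.ndarray or torch.Tensor, as the widened annotation allows.
        if isinstance(batch_data, np.ndarray):
            batch_data = torch.from_numpy(batch_data)
        imgs = batch_data.to(device).float().permute(0, 3, 1, 2)  # NHWC -> NCHW
        with torch.inference_mode():
            probs = torch.softmax(model(imgs), dim=1)
        return {"probabilities": probs.permute(0, 2, 3, 1).cpu().numpy()}  # NHWC again


model = ToyModel()
out = model.infer_batch(model, np.zeros((1, 16, 16, 3), dtype=np.float32), device="cpu")
print(out["probabilities"].shape)  # (1, 16, 16, 2)
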