
Commit f1b622e

⚡ Implement hybrid approach
1 parent d202756 commit f1b622e

File tree

1 file changed: +92 −23 lines changed


tiatoolbox/models/engine/semantic_segmentor.py

Lines changed: 92 additions & 23 deletions
@@ -8,14 +8,14 @@
 import dask.array as da
 import numpy as np
 import torch
-from dask import delayed
 from typing_extensions import Unpack

 from tiatoolbox import logger
 from tiatoolbox.models.dataset.dataset_abc import WSIPatchDataset
 from tiatoolbox.utils.misc import (
     dict_to_store_semantic_segmentor,
     dict_to_zarr,
+    get_tqdm,
 )

 from .patch_predictor import PatchPredictor, PredictorRunParams
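
For context, the dropped `from dask import delayed` import served the old whole-slide merge, which wrapped `merge_all` in a single delayed call and materialised it with `da.from_delayed` (see the removed lines in the last hunk). A minimal, self-contained sketch of that pattern on toy data follows; it is for contrast only, and `slow_merge` is a stand-in, not the previous tiatoolbox code verbatim.

    import dask.array as da
    import numpy as np
    from dask import delayed

    def slow_merge(x: np.ndarray) -> np.ndarray:
        """Stand-in for the old whole-WSI merge step."""
        return x.sum(axis=0)

    # Build the result lazily; nothing runs until .compute().
    lazy = delayed(slow_merge)(np.ones((4, 2, 2)))
    merged = da.from_delayed(lazy, shape=(2, 2), dtype=np.float64)
    print(merged.compute())  # [[4. 4.] [4. 4.]]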
@@ -33,19 +33,21 @@


 def merge_all(
-    blocks: da.Array,
+    blocks: np.ndarray,
     output_locations: np.ndarray,
     merged_shape: tuple,
     dtype_: type,
-) -> np.ndarray:
+) -> tuple[np.ndarray, np.ndarray]:
     """Helper function to merge predictions."""
     canvas = np.zeros(merged_shape, dtype=dtype_)
-    count = np.zeros(merged_shape, dtype=np.uint8)
+    count = np.zeros(merged_shape[:2], dtype=np.uint8)
     for i, block in enumerate(blocks):
         xs, ys, xe, ye = output_locations[i]
-        canvas[ys:ye, xs:xe, :] += block
-        count[ys:ye, xs:xe, :] += 1
-    return canvas / np.maximum(count, 1)
+        # To deal with edge cases
+        ye, xe = min(ye, canvas.shape[0]), min(xe, canvas.shape[1])
+        canvas[ys:ye, xs:xe, :] += block[0 : ye - ys, 0 : xe - xs, :]
+        count[ys:ye, xs:xe] += 1
+    return canvas, count


 class SemanticSegmentorRunParams(PredictorRunParams):
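
The reworked `merge_all` now returns the raw accumulation plus a per-pixel hit count instead of the normalised average, and clips patches that overhang the right/bottom edge of the canvas. A minimal sketch of those semantics on toy data; the helper name `merge_blocks`, the shapes, and the values below are illustrative, not tiatoolbox code.

    import numpy as np

    def merge_blocks(blocks, output_locations, merged_shape, dtype_):
        """Mirror of the patched merge_all: accumulate blocks and a hit count."""
        canvas = np.zeros(merged_shape, dtype=dtype_)
        count = np.zeros(merged_shape[:2], dtype=np.uint8)
        for block, (xs, ys, xe, ye) in zip(blocks, output_locations):
            # Clip to the canvas so patches overhanging the edges still fit.
            ye, xe = min(ye, canvas.shape[0]), min(xe, canvas.shape[1])
            canvas[ys:ye, xs:xe, :] += block[: ye - ys, : xe - xs, :]
            count[ys:ye, xs:xe] += 1
        return canvas, count

    # Two overlapping 4x4 blocks with 3 channels on a 4x6 canvas.
    blocks = np.ones((2, 4, 4, 3), dtype=np.float32)
    locations = np.array([[0, 0, 4, 4], [2, 0, 6, 4]])  # (xs, ys, xe, ye)
    canvas, count = merge_blocks(blocks, locations, (4, 6, 3), np.float32)

    # Overlap-aware average, as done at the end of infer_wsi.
    probabilities = canvas / np.maximum(count[:, :, np.newaxis], 1)
    print(count)             # columns 2-3 counted twice
    print(probabilities[0])  # averages back to 1.0 everywhere covered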
@@ -404,33 +406,100 @@ def infer_wsi(

         """
         _ = kwargs.get("return_probabilities", False)
-        raw_predictions = self.infer_patches(
-            dataloader=dataloader,
-            return_coordinates=True,
+
+        keys = ["probabilities", "coordinates"]
+        coordinates = []
+
+        if self.return_labels:
+            keys.append("labels")
+            labels = []
+
+        # Main output dictionary
+        raw_predictions = dict(zip(keys, [[]] * len(keys)))
+
+        # sample for calculating shape for dask arrays
+        sample = self.dataloader.dataset[0]
+        sample_output = self.model.infer_batch(
+            self.model,
+            torch.Tensor(sample["image"][np.newaxis, ...]),
+            device=self.device,
         )

+        # Create canvas and counts
         max_location = np.max(self.output_locations, axis=0)
         merged_shape = (
             max_location[3],
             max_location[2],
-            raw_predictions["probabilities"].shape[3],
+            sample_output.shape[3],
+        )
+        canvas = da.zeros(merged_shape, dtype=sample_output.dtype)
+        count = da.zeros(merged_shape[:2], dtype=np.uint8)
+
+        # Inference loop
+        tqdm = get_tqdm()
+        tqdm_loop = (
+            tqdm(dataloader, leave=False, desc="Inferring patches")
+            if self.verbose
+            else self.dataloader
         )

-        raw_probs = raw_predictions["probabilities"].rechunk((1, 512, 512, 5))
-        dtype_ = raw_predictions["probabilities"].dtype
+        for batch_data in tqdm_loop:
+            batch_output = self.model.infer_batch(
+                self.model,
+                batch_data["image"],
+                device=self.device,
+            )

-        merged = delayed(merge_all)(
-            raw_probs,
-            self.output_locations,
-            merged_shape,
-            dtype_,
-        )
+            output_locs = batch_data["output_locs"].numpy()
+
+            batch_xs, batch_ys = np.min(output_locs[:, 0:2], axis=0)
+            batch_xe, batch_ye = np.max(output_locs[:, 2:4], axis=0)
+
+            merged_shape_batch = (
+                batch_ye - batch_ys,
+                batch_xe - batch_xs,
+                sample_output.shape[3],
+            )
+
+            merged_output, merged_count = merge_all(
+                batch_output,
+                output_locs - np.array([batch_xs, batch_ys, batch_xs, batch_ys]),
+                merged_shape_batch,
+                sample_output.dtype,
+            )
+
+            batch_ye, batch_xe = (
+                min(batch_ye, canvas.shape[0]),
+                min(batch_xe, canvas.shape[1]),
+            )
+
+            canvas[
+                batch_ys:batch_ye,
+                batch_xs:batch_xe,
+                :,
+            ] += merged_output
+
+            count[
+                batch_ys:batch_ye,
+                batch_xs:batch_xe,
+            ] += merged_count
+
+            coordinates.append(
+                da.from_array(
+                    self._get_coordinates(batch_data),
+                )
+            )
+
+            if self.return_labels:
+                labels.append(da.from_array(np.array(batch_data["label"])))

-        raw_predictions["probabilities"] = da.from_delayed(
-            merged,
-            shape=merged_shape,
-            dtype=dtype_,
+        raw_predictions["probabilities"] = canvas / da.maximum(
+            count[:, :, np.newaxis], 1
         )
+        raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0)
+        if self.return_labels:
+            labels = [label.reshape(-1) for label in labels]
+            raw_predictions["labels"] = da.concatenate(labels, axis=0)

         return raw_predictions
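
The net effect is the "hybrid approach" named in the commit title: each batch is merged eagerly in NumPy via `merge_all`, the per-batch results are accumulated into a lazy Dask canvas and count, and the final probabilities are the overlap-aware average `canvas / da.maximum(count[:, :, np.newaxis], 1)` (the `maximum` guards against division by zero in uncovered pixels). A stripped-down sketch of that flow follows; `fake_infer`, the 8x8 canvas, and the hand-written batches are all illustrative stand-ins, not tiatoolbox API.

    import dask.array as da
    import numpy as np

    def fake_infer(batch):
        """Stand-in for model.infer_batch: per-pixel class probabilities."""
        n, h, w, _ = batch.shape
        return np.full((n, h, w, 2), 0.5, dtype=np.float32)

    canvas = da.zeros((8, 8, 2), dtype=np.float32)  # lazy slide-sized canvas
    count = da.zeros((8, 8), dtype=np.uint8)

    batches = [
        (np.zeros((1, 4, 4, 3)), np.array([[0, 0, 4, 4]])),  # (images, output_locs)
        (np.zeros((1, 4, 4, 3)), np.array([[2, 2, 6, 6]])),
    ]

    for images, locs in batches:
        preds = fake_infer(images)           # eager NumPy work per batch
        xs, ys, xe, ye = locs[0]
        canvas[ys:ye, xs:xe, :] += preds[0]  # lazy accumulation in Dask
        count[ys:ye, xs:xe] += 1

    probabilities = canvas / da.maximum(count[:, :, np.newaxis], 1)
    print(probabilities.compute()[3, 3])     # overlap averaged back to 0.5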
