|
8 | 8 | import dask.array as da |
9 | 9 | import numpy as np |
10 | 10 | import torch |
11 | | -from dask import delayed |
12 | 11 | from typing_extensions import Unpack |
13 | 12 |
|
14 | 13 | from tiatoolbox import logger |
15 | 14 | from tiatoolbox.models.dataset.dataset_abc import WSIPatchDataset |
16 | 15 | from tiatoolbox.utils.misc import ( |
17 | 16 | dict_to_store_semantic_segmentor, |
18 | 17 | dict_to_zarr, |
| 18 | + get_tqdm, |
19 | 19 | ) |
20 | 20 |
|
21 | 21 | from .patch_predictor import PatchPredictor, PredictorRunParams |
|
33 | 33 |
|
34 | 34 |
|
35 | 35 | def merge_all( |
36 | | - blocks: da.Array, |
| 36 | + blocks: np.ndarray, |
37 | 37 | output_locations: np.ndarray, |
38 | 38 | merged_shape: tuple, |
39 | 39 | dtype_: type, |
40 | | -) -> np.ndarray: |
| 40 | +) -> tuple[np.ndarray, np.ndarray]: |
41 | 41 | """Helper function to merge predictions.""" |
42 | 42 | canvas = np.zeros(merged_shape, dtype=dtype_) |
43 | | - count = np.zeros(merged_shape, dtype=np.uint8) |
| 43 | + count = np.zeros(merged_shape[:2], dtype=np.uint8) |
44 | 44 | for i, block in enumerate(blocks): |
45 | 45 | xs, ys, xe, ye = output_locations[i] |
46 | | - canvas[ys:ye, xs:xe, :] += block |
47 | | - count[ys:ye, xs:xe, :] += 1 |
48 | | - return canvas / np.maximum(count, 1) |
 | 46 | + # Clip to the canvas bounds for patches at the right/bottom image edge
| 47 | + ye, xe = min(ye, canvas.shape[0]), min(xe, canvas.shape[1]) |
| 48 | + canvas[ys:ye, xs:xe, :] += block[0 : ye - ys, 0 : xe - xs, :] |
| 49 | + count[ys:ye, xs:xe] += 1 |
| 50 | + return canvas, count |
49 | 51 |
|
50 | 52 |
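As a quick sanity check of the new `merge_all` signature: overlapping blocks are summed onto the canvas while `count` records per-pixel coverage, so the caller recovers averaged predictions by dividing by the clamped count. A minimal sketch, assuming `merge_all` from this module is in scope; the toy shapes and values are made up for illustration:

```python
import numpy as np

# Two 4x4 single-channel blocks overlapping on columns 2-3 of a 4x6 canvas.
blocks = np.ones((2, 4, 4, 1), dtype=np.float32)
output_locations = np.array(
    [
        [0, 0, 4, 4],  # (xs, ys, xe, ye) for block 0
        [2, 0, 6, 4],  # block 1 overlaps block 0 on columns 2-3
    ]
)

canvas, count = merge_all(blocks, output_locations, (4, 6, 1), np.float32)
merged = canvas / np.maximum(count[:, :, np.newaxis], 1)
assert merged.max() == 1.0  # overlapping pixels are averaged, not doubled
```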
|
51 | 53 | class SemanticSegmentorRunParams(PredictorRunParams): |
@@ -404,33 +406,100 @@ def infer_wsi( |
404 | 406 |
|
405 | 407 | """ |
406 | 408 | _ = kwargs.get("return_probabilities", False) |
407 | | - raw_predictions = self.infer_patches( |
408 | | - dataloader=dataloader, |
409 | | - return_coordinates=True, |
| 409 | + |
| 410 | + keys = ["probabilities", "coordinates"] |
| 411 | + coordinates = [] |
| 412 | + |
| 413 | + if self.return_labels: |
| 414 | + keys.append("labels") |
| 415 | + labels = [] |
| 416 | + |
| 417 | + # Main output dictionary |
 | 418 | + raw_predictions = {key: [] for key in keys}  # avoid [[]] * n, which aliases one shared list
| 419 | + |
 | 420 | + # Run a single sample through the model to get the output shape and dtype for the dask arrays
 | 421 | + sample = dataloader.dataset[0]
| 422 | + sample_output = self.model.infer_batch( |
| 423 | + self.model, |
| 424 | + torch.Tensor(sample["image"][np.newaxis, ...]), |
| 425 | + device=self.device, |
410 | 426 | ) |
411 | 427 |
|
 | 428 | + # Create the full-size canvas and coverage-count arrays lazily with dask
412 | 429 | max_location = np.max(self.output_locations, axis=0) |
413 | 430 | merged_shape = ( |
414 | 431 | max_location[3], |
415 | 432 | max_location[2], |
416 | | - raw_predictions["probabilities"].shape[3], |
| 433 | + sample_output.shape[3], |
| 434 | + ) |
| 435 | + canvas = da.zeros(merged_shape, dtype=sample_output.dtype) |
| 436 | + count = da.zeros(merged_shape[:2], dtype=np.uint8) |
| 437 | + |
| 438 | + # Inference loop |
| 439 | + tqdm = get_tqdm() |
| 440 | + tqdm_loop = ( |
| 441 | + tqdm(dataloader, leave=False, desc="Inferring patches") |
| 442 | + if self.verbose |
 | 443 | + else dataloader
417 | 444 | ) |
418 | 445 |
|
419 | | - raw_probs = raw_predictions["probabilities"].rechunk((1, 512, 512, 5)) |
420 | | - dtype_ = raw_predictions["probabilities"].dtype |
| 446 | + for batch_data in tqdm_loop: |
| 447 | + batch_output = self.model.infer_batch( |
| 448 | + self.model, |
| 449 | + batch_data["image"], |
| 450 | + device=self.device, |
| 451 | + ) |
421 | 452 |
|
422 | | - merged = delayed(merge_all)( |
423 | | - raw_probs, |
424 | | - self.output_locations, |
425 | | - merged_shape, |
426 | | - dtype_, |
427 | | - ) |
| 453 | + output_locs = batch_data["output_locs"].numpy() |
| 454 | + |
| 455 | + batch_xs, batch_ys = np.min(output_locs[:, 0:2], axis=0) |
| 456 | + batch_xe, batch_ye = np.max(output_locs[:, 2:4], axis=0) |
| 457 | + |
| 458 | + merged_shape_batch = ( |
| 459 | + batch_ye - batch_ys, |
| 460 | + batch_xe - batch_xs, |
| 461 | + sample_output.shape[3], |
| 462 | + ) |
| 463 | + |
| 464 | + merged_output, merged_count = merge_all( |
| 465 | + batch_output, |
| 466 | + output_locs - np.array([batch_xs, batch_ys, batch_xs, batch_ys]), |
| 467 | + merged_shape_batch, |
| 468 | + sample_output.dtype, |
| 469 | + ) |
| 470 | + |
 | 471 | + batch_ye, batch_xe = (  # clip to the canvas bounds at the image edge
 | 472 | + min(batch_ye, canvas.shape[0]),
 | 473 | + min(batch_xe, canvas.shape[1]),
 | 474 | + )
 | 475 | + # Trim the merged block to the clipped region so the shapes always match
 | 476 | + canvas[
 | 477 | + batch_ys:batch_ye,
 | 478 | + batch_xs:batch_xe,
 | 479 | + :,
 | 480 | + ] += merged_output[: batch_ye - batch_ys, : batch_xe - batch_xs, :]
 | 481 | +
 | 482 | + count[
 | 483 | + batch_ys:batch_ye,
 | 484 | + batch_xs:batch_xe,
 | 485 | + ] += merged_count[: batch_ye - batch_ys, : batch_xe - batch_xs]
| 486 | + |
| 487 | + coordinates.append( |
| 488 | + da.from_array( |
| 489 | + self._get_coordinates(batch_data), |
| 490 | + ) |
| 491 | + ) |
| 492 | + |
| 493 | + if self.return_labels: |
| 494 | + labels.append(da.from_array(np.array(batch_data["label"]))) |
428 | 495 |
|
429 | | - raw_predictions["probabilities"] = da.from_delayed( |
430 | | - merged, |
431 | | - shape=merged_shape, |
432 | | - dtype=dtype_, |
| 496 | + raw_predictions["probabilities"] = canvas / da.maximum( |
| 497 | + count[:, :, np.newaxis], 1 |
433 | 498 | ) |
| 499 | + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) |
| 500 | + if self.return_labels: |
| 501 | + labels = [label.reshape(-1) for label in labels] |
| 502 | + raw_predictions["labels"] = da.concatenate(labels, axis=0) |
434 | 503 |
|
435 | 504 | return raw_predictions |
436 | 505 |
|
|