|
24 | 24 | if TYPE_CHECKING: # pragma: no cover |
25 | 25 | import os |
26 | 26 |
|
| 27 | + from torch.utils.data import DataLoader |
| 28 | + |
27 | 29 | from tiatoolbox.annotation import AnnotationStore |
28 | 30 | from tiatoolbox.models.engine.io_config import IOSegmentorConfig |
29 | 31 | from tiatoolbox.models.models_abc import ModelABC |
30 | 32 | from tiatoolbox.type_hints import Resolution |
31 | 33 | from tiatoolbox.wsicore import WSIReader |
32 | 34 |
|
| 35 | + from .engine_abc import EngineABC, EngineABCRunParams |
| 36 | + |
33 | 37 |
|
34 | 38 | class SemanticSegmentorRunParams(PredictorRunParams): |
35 | 39 | """Class describing the input parameters for the :func:`EngineABC.run()` method. |
@@ -364,85 +368,144 @@ def get_dataloader( |
364 | 368 | patch_mode=patch_mode, |
365 | 369 | ) |
366 | 370 |
|
367 | | - def post_process_wsi( |
368 | | - self: SemanticSegmentor, |
369 | | - raw_predictions: Path, |
370 | | - **kwargs: Unpack[PredictorRunParams], |
| 371 | + def infer_wsi( |
| 372 | +         self: SemanticSegmentor, |
| 373 | + dataloader: DataLoader, |
| 374 | +         **kwargs: Unpack[EngineABCRunParams], |
371 |     | -     ) -> Path: |
    | 375 | +     ) -> dict: |
372 | | - """Returns an array from raw predictions. |
| 376 | + """Model inference on a WSI. |
373 | 377 |
|
374 | | - Merges raw predictions from individual patches into a single prediction array if |
375 | | - patch_mode is False. |
| 378 | + Args: |
| 379 | + dataloader (DataLoader): |
| 380 | + A torch dataloader to process WSIs. |
| 381 | +
|
| 382 | +                 Each batch must provide "image" and "output_locs" |
| 383 | +                 entries; the "label" entry is read when |
| 384 | +                 ``return_labels`` is True. |
| 385 | + **kwargs (EngineABCRunParams): |
| 386 | + Keyword Args to update setup_patch_dataset() method attributes. See |
| 387 | +             :class:`EngineABCRunParams` for accepted keyword arguments. |
| 388 | +
|
| 389 | + Returns: |
| 390 | +         raw_predictions (dict): |
| 391 | +             Merged raw predictions, including "probabilities" and "coordinates". |
376 | 392 |
|
377 | 393 | """ |
378 | 394 | _ = kwargs.get("return_probabilities") |
379 | | - progress_bar = None |
| 395 | + |
380 | 396 | tqdm = get_tqdm() |
381 | 397 |
|
382 | | - if self.verbose: |
383 | | - progress_bar = tqdm( |
384 | | - total=len(self.output_locations), |
385 | | - leave=False, |
386 | | - desc="Merging Patch Outputs", |
387 | | - ) |
| 398 | + progress_bar = ( |
| 399 | + tqdm(total=len(dataloader), leave=self.patch_mode, desc="Inferring patches") |
| 400 | + if self.verbose |
| 401 | + else None |
| 402 | + ) |
388 | 403 |
|
389 | | - num_post_proc_workers = self.num_post_proc_workers |
| 404 | + keys = ["coordinates"] |
390 | 405 |
|
391 | | - if num_post_proc_workers is not None and num_post_proc_workers > 0: |
392 | | - dask.config.set(scheduler="threads", num_workers=num_post_proc_workers) |
393 | | - else: |
394 | | - dask.config.set(scheduler="threads") |
| 406 | + if self.return_labels: |
| 407 | + keys.append("labels") |
395 | 408 |
|
396 | | - dask_patch_probabilities = raw_predictions["probabilities"] |
| 409 | + raw_predictions = dict.fromkeys(keys) |
397 | 410 |
|
398 | | - # --- Calculate canvas parameters from Dask array and locations --- |
399 | 411 | max_location = np.max(self.output_locations, axis=0) |
| 412 | + |
| 413 | + out_ = self.model.infer_batch( |
| 414 | + self.model, |
| 415 | + torch.from_numpy(dataloader.dataset[0]["image"][None, :, :, :]), |
| 416 | + device=self.device, |
| 417 | + ) |
| 418 | + |
400 | 419 | merged_shape = ( |
401 | 420 | max_location[3], |
402 | 421 | max_location[2], |
403 | | - dask_patch_probabilities.shape[3], |
| 422 | + out_["probabilities"].shape[3], |
404 | 423 | ) |
405 | 424 |
|
406 | 425 | # creating dask arrays for faster processing |
407 | 426 | merged_probabilities = da.zeros( |
408 | 427 | shape=merged_shape, |
409 | | - dtype=dask_patch_probabilities.dtype, |
| 428 | + dtype=out_["probabilities"].dtype, |
410 | 429 | chunks=merged_shape, |
411 | 430 | ) |
412 | 431 |
|
413 | 432 | merged_weights = da.zeros( |
414 | | - shape=merged_shape, |
| 433 | + shape=merged_shape[:2], |
415 | 434 | dtype=int, |
416 | | - chunks=merged_shape, |
| 435 | + chunks=merged_shape[:2], |
417 | 436 | ) |
418 | 437 |
|
419 | | - for idx, location in enumerate(self.output_locations): |
420 | | - start_x, start_y, end_x, end_y = location |
421 | | - patch_probs = dask_patch_probabilities[ |
422 | | - idx, 0 : end_y - start_y, 0 : end_x - start_x, : |
423 | | - ] |
424 | | - merged_probabilities[start_y:end_y, start_x:end_x, :] = ( |
425 | | - merged_probabilities[start_y:end_y, start_x:end_x, :] + patch_probs |
| 438 | +         for batch_data in dataloader: |
| 439 | + batch_output = self.model.infer_batch( |
| 440 | + self.model, |
| 441 | + batch_data["image"], |
| 442 | + device=self.device, |
426 | 443 | ) |
427 | | - merged_weights[start_y:end_y, start_x:end_x] = ( |
428 | | - merged_weights[start_y:end_y, start_x:end_x] + 1 |
| 444 | + |
| 445 | + batch_output["coordinates"] = self._get_coordinates(batch_data) |
| 446 | + |
| 447 | +             if self.return_labels:  # dataset key is "label"; output key is "labels" |
| 448 | + if isinstance(batch_data["label"], torch.Tensor): |
| 449 | + batch_output["labels"] = batch_data["label"].numpy() |
| 450 | + else: |
| 451 | + batch_output["labels"] = np.array(batch_data["label"]) |
| 452 | + |
| 453 | + output_locs = batch_data["output_locs"] |
| 454 | + |
| 455 | + for idx, location in enumerate(output_locs.numpy()): |
| 456 | + start_x, start_y, end_x, end_y = location |
| 457 | + patch_probs = batch_output["probabilities"][ |
| 458 | + idx, 0 : end_y - start_y, 0 : end_x - start_x, : |
| 459 | + ] |
| 460 | + merged_probabilities[start_y:end_y, start_x:end_x, :] = ( |
| 461 | + merged_probabilities[start_y:end_y, start_x:end_x, :] + patch_probs |
| 462 | + ) |
| 463 | + merged_weights[start_y:end_y, start_x:end_x] = ( |
| 464 | + merged_weights[start_y:end_y, start_x:end_x] + 1 |
| 465 | + ) |
| 466 | + |
| 467 | + del batch_output["probabilities"] |
| 468 | + raw_predictions = self._update_model_output( |
| 469 | + raw_predictions=raw_predictions, |
| 470 | + raw_output=batch_output, |
429 | 471 | ) |
| 472 | + |
430 | 473 | if progress_bar: |
431 | 474 | progress_bar.update() |
432 | 475 |
|
| 476 | + merged_weights = da.maximum(merged_weights, 1) |
| 477 | + raw_predictions["probabilities"] = ( |
| 478 | + merged_probabilities / merged_weights[:, :, None] |
| 479 | + ) |
| 480 | + |
433 | 481 | if progress_bar: |
434 | 482 | progress_bar.close() |
435 | 483 |
|
436 | | - # Normalize where weight > 1 |
437 | | - final_probabilities_dask = da.where( |
438 | | - merged_weights > 1, |
439 | | - merged_probabilities / merged_weights, |
440 | | - merged_probabilities, |
441 | | - ) |
| 484 | + return raw_predictions |
| 485 | + |
| 486 | + def post_process_wsi( |
| 487 | + self: SemanticSegmentor, |
| 488 | +         raw_predictions: dict, |
| 489 | +         **kwargs: Unpack[PredictorRunParams], |
| 490 | +     ) -> dict: |
| 491 | + """Returns an array from raw predictions. |
| 492 | +
|
| 493 | + Merges raw predictions from individual patches into a single prediction array if |
| 494 | + patch_mode is False. |
| 495 | +
|
| 496 | + """ |
| 497 | + _ = kwargs.get("return_probabilities") |
| 498 | + |
| 499 | + num_post_proc_workers = self.num_post_proc_workers |
| 500 | + |
| 501 | + if num_post_proc_workers is not None and num_post_proc_workers > 0: |
| 502 | + dask.config.set(scheduler="threads", num_workers=num_post_proc_workers) |
| 503 | + else: |
| 504 | + dask.config.set(scheduler="threads") |
442 | 505 |
|
443 | 506 | # Applying Post-Processing |
444 | 507 | raw_predictions["predictions"] = self.model.postproc_func( |
445 | | - final_probabilities_dask, |
| 508 | + raw_predictions["probabilities"], |
446 | 509 | ) |
447 | 510 |
|
448 | 511 | return raw_predictions |
|
0 commit comments