|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import copy |
| 6 | +import shutil |
6 | 7 | from abc import ABC, abstractmethod |
7 | 8 | from pathlib import Path |
8 | 9 | from typing import TYPE_CHECKING, TypedDict |
@@ -639,7 +640,7 @@ def post_process_patches( |
639 | 640 |
|
640 | 641 | def save_predictions( |
641 | 642 | self: EngineABC, |
642 | | - processed_predictions: dict, |
| 643 | + processed_predictions: dict | Path, |
643 | 644 | output_type: str, |
644 | 645 | save_dir: Path | None = None, |
645 | 646 | **kwargs: dict, |
@@ -681,16 +682,23 @@ def save_predictions( |
681 | 682 | # class_dict set from kwargs |
682 | 683 | class_dict = kwargs.get("class_dict") |
683 | 684 |
|
| 685 | + processed_predictions_path: str | Path | None = None |
| 686 | + |
684 | 687 | # Need to add support for zarr conversion. |
685 | 688 | if self.cache_mode: |
| 689 | + processed_predictions_path = processed_predictions |
686 | 690 | processed_predictions = zarr.open(processed_predictions, mode="r") |
687 | 691 |
|
688 | | - return dict_to_store( |
| 692 | + out_file = dict_to_store( |
689 | 693 | processed_predictions, |
690 | 694 | scale_factor, |
691 | 695 | class_dict, |
692 | 696 | save_path, |
693 | 697 | ) |
| 698 | + if processed_predictions_path is not None: |
| 699 | + shutil.rmtree(processed_predictions_path) |
| 700 | + |
| 701 | + return out_file |
694 | 702 |
|
695 | 703 | return ( |
696 | 704 | dict_to_zarr( |
@@ -1057,15 +1065,22 @@ def _run_wsi_mode( |
1057 | 1065 | dataloader_units = dataloader.dataset.units |
1058 | 1066 | dataloader_resolution = dataloader.dataset.resolution |
1059 | 1067 |
|
1060 | | - slide_resolution = (1.0, 1.0) |
1061 | | - if dataloader_units != "baseline": |
| 1068 | + # if dataloader units is baseline slide resolution is 1.0. |
| 1069 | + # in this case dataloader resolution / slide resolution will be |
| 1070 | + # equal to dataloader resolution. |
| 1071 | + scale_factor = dataloader_resolution |
| 1072 | + |
| 1073 | + if dataloader_units == "mpp": |
1062 | 1074 | wsimeta_dict = dataloader.dataset.reader.info.as_dict() |
1063 | 1075 | slide_resolution = wsimeta_dict[dataloader_units] |
| 1076 | + scale_factor = tuple(np.divide(slide_resolution, dataloader_resolution)) |
1064 | 1077 |
|
1065 | | - scale_factor = tuple(np.divide(slide_resolution, dataloader_resolution)) |
1066 | | - |
1067 | | - if dataloader_units != "mpp": |
1068 | | - scale_factor = tuple(np.divide(dataloader_resolution, slide_resolution)) |
| 1078 | + if dataloader_units == "level": |
| 1079 | + wsimeta_dict = dataloader.dataset.reader.info.as_dict() |
| 1080 | + downsample_ratio = wsimeta_dict["level_downsamples"][ |
| 1081 | + dataloader_resolution |
| 1082 | + ] |
| 1083 | + scale_factor = (1.0 / downsample_ratio, 1.0 / downsample_ratio) |
1069 | 1084 |
|
1070 | 1085 | raw_predictions = self.infer_wsi( |
1071 | 1086 | dataloader=dataloader, |
|
0 commit comments