Skip to content

Commit d44d943

Browse files
committed
fix deepsource
1 parent d2eae54 commit d44d943

File tree

1 file changed

+67
-39
lines changed

1 file changed

+67
-39
lines changed

tiatoolbox/models/engine/engine_abc.py

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -700,45 +700,11 @@ def save_predictions(
700700
keys_to_compute = [k for k in processed_predictions if k not in self.drop_keys]
701701

702702
if output_type.lower() == "zarr":
703-
if is_zarr(save_path):
704-
zarr_group = zarr.open(save_path, mode="r")
705-
keys_to_compute = [k for k in keys_to_compute if k not in zarr_group]
706-
write_tasks = []
707-
for key in keys_to_compute:
708-
dask_output = processed_predictions[key]
709-
if isinstance(dask_output, da.Array):
710-
dask_output = dask_output.rechunk("auto")
711-
task = dask_output.to_zarr(
712-
url=save_path, component=key, compute=False, object_codec=None
713-
)
714-
write_tasks.append(task)
715-
716-
if isinstance(dask_output, list) and all(
717-
isinstance(dask_array, da.Array) for dask_array in dask_output
718-
):
719-
for i, dask_array in enumerate(dask_output):
720-
object_codec = (
721-
Pickle() if dask_array.dtype == "object" else None
722-
)
723-
task = dask_array.to_zarr(
724-
url=save_path,
725-
component=f"{key}/{i}",
726-
compute=False,
727-
object_codec=object_codec,
728-
)
729-
write_tasks.append(task)
730-
731-
msg = f"Saving output to {save_path}."
732-
logger.info(msg=msg)
733-
with ProgressBar():
734-
compute(*write_tasks)
735-
736-
zarr_group = zarr.open(save_path, mode="r+")
737-
for key in self.drop_keys:
738-
if key in zarr_group:
739-
del zarr_group[key]
740-
741-
return save_path
703+
return self.save_predictions_as_zarr(
704+
processed_predictions=processed_predictions,
705+
save_path=save_path,
706+
keys_to_compute=keys_to_compute,
707+
)
742708

743709
values_to_compute = [processed_predictions[k] for k in keys_to_compute]
744710

@@ -771,6 +737,68 @@ def save_predictions(
771737
msg = f"Unsupported output type: {output_type}"
772738
raise TypeError(msg)
773739

740+
def save_predictions_as_zarr(
741+
self: EngineABC,
742+
processed_predictions: dict,
743+
save_path: Path,
744+
keys_to_compute: list,
745+
) -> Path:
746+
"""Save model predictions as a zarr file.
747+
748+
This method saves the processed predictions to a zarr file at the specified
749+
path.
750+
751+
Args:
752+
processed_predictions (dict):
753+
Dictionary containing processed model predictions.
754+
save_path (Path):
755+
Path to save the zarr file.
756+
keys_to_compute (list):
757+
List of keys in processed_predictions to save.
758+
759+
Returns:
760+
save_path (Path):
761+
Path to the saved zarr file.
762+
763+
"""
764+
if is_zarr(save_path):
765+
zarr_group = zarr.open(save_path, mode="r")
766+
keys_to_compute = [k for k in keys_to_compute if k not in zarr_group]
767+
write_tasks = []
768+
for key in keys_to_compute:
769+
dask_output = processed_predictions[key]
770+
if isinstance(dask_output, da.Array):
771+
dask_output = dask_output.rechunk("auto")
772+
task = dask_output.to_zarr(
773+
url=save_path, component=key, compute=False, object_codec=None
774+
)
775+
write_tasks.append(task)
776+
777+
if isinstance(dask_output, list) and all(
778+
isinstance(dask_array, da.Array) for dask_array in dask_output
779+
):
780+
for i, dask_array in enumerate(dask_output):
781+
object_codec = Pickle() if dask_array.dtype == "object" else None
782+
task = dask_array.to_zarr(
783+
url=save_path,
784+
component=f"{key}/{i}",
785+
compute=False,
786+
object_codec=object_codec,
787+
)
788+
write_tasks.append(task)
789+
790+
msg = f"Saving output to {save_path}."
791+
logger.info(msg=msg)
792+
with ProgressBar():
793+
compute(*write_tasks)
794+
795+
zarr_group = zarr.open(save_path, mode="r+")
796+
for key in self.drop_keys:
797+
if key in zarr_group:
798+
del zarr_group[key]
799+
800+
return save_path
801+
774802
def infer_wsi(
775803
self: EngineABC,
776804
dataloader: DataLoader,

0 commit comments

Comments
 (0)