@@ -700,45 +700,11 @@ def save_predictions(
700700 keys_to_compute = [k for k in processed_predictions if k not in self .drop_keys ]
701701
702702 if output_type .lower () == "zarr" :
703- if is_zarr (save_path ):
704- zarr_group = zarr .open (save_path , mode = "r" )
705- keys_to_compute = [k for k in keys_to_compute if k not in zarr_group ]
706- write_tasks = []
707- for key in keys_to_compute :
708- dask_output = processed_predictions [key ]
709- if isinstance (dask_output , da .Array ):
710- dask_output = dask_output .rechunk ("auto" )
711- task = dask_output .to_zarr (
712- url = save_path , component = key , compute = False , object_codec = None
713- )
714- write_tasks .append (task )
715-
716- if isinstance (dask_output , list ) and all (
717- isinstance (dask_array , da .Array ) for dask_array in dask_output
718- ):
719- for i , dask_array in enumerate (dask_output ):
720- object_codec = (
721- Pickle () if dask_array .dtype == "object" else None
722- )
723- task = dask_array .to_zarr (
724- url = save_path ,
725- component = f"{ key } /{ i } " ,
726- compute = False ,
727- object_codec = object_codec ,
728- )
729- write_tasks .append (task )
730-
731- msg = f"Saving output to { save_path } ."
732- logger .info (msg = msg )
733- with ProgressBar ():
734- compute (* write_tasks )
735-
736- zarr_group = zarr .open (save_path , mode = "r+" )
737- for key in self .drop_keys :
738- if key in zarr_group :
739- del zarr_group [key ]
740-
741- return save_path
703+ return self .save_predictions_as_zarr (
704+ processed_predictions = processed_predictions ,
705+ save_path = save_path ,
706+ keys_to_compute = keys_to_compute ,
707+ )
742708
743709 values_to_compute = [processed_predictions [k ] for k in keys_to_compute ]
744710
@@ -771,6 +737,68 @@ def save_predictions(
771737 msg = f"Unsupported output type: { output_type } "
772738 raise TypeError (msg )
773739
740+ def save_predictions_as_zarr (
741+ self : EngineABC ,
742+ processed_predictions : dict ,
743+ save_path : Path ,
744+ keys_to_compute : list ,
745+ ) -> Path :
746+ """Save model predictions as a zarr file.
747+
748+ This method saves the processed predictions to a zarr file at the specified
749+ path.
750+
751+ Args:
752+ processed_predictions (dict):
753+ Dictionary containing processed model predictions.
754+ save_path (Path):
755+ Path to save the zarr file.
756+ keys_to_compute (list):
757+ List of keys in processed_predictions to save.
758+
759+ Returns:
760+ save_path (Path):
761+ Path to the saved zarr file.
762+
763+ """
764+ if is_zarr (save_path ):
765+ zarr_group = zarr .open (save_path , mode = "r" )
766+ keys_to_compute = [k for k in keys_to_compute if k not in zarr_group ]
767+ write_tasks = []
768+ for key in keys_to_compute :
769+ dask_output = processed_predictions [key ]
770+ if isinstance (dask_output , da .Array ):
771+ dask_output = dask_output .rechunk ("auto" )
772+ task = dask_output .to_zarr (
773+ url = save_path , component = key , compute = False , object_codec = None
774+ )
775+ write_tasks .append (task )
776+
777+ if isinstance (dask_output , list ) and all (
778+ isinstance (dask_array , da .Array ) for dask_array in dask_output
779+ ):
780+ for i , dask_array in enumerate (dask_output ):
781+ object_codec = Pickle () if dask_array .dtype == "object" else None
782+ task = dask_array .to_zarr (
783+ url = save_path ,
784+ component = f"{ key } /{ i } " ,
785+ compute = False ,
786+ object_codec = object_codec ,
787+ )
788+ write_tasks .append (task )
789+
790+ msg = f"Saving output to { save_path } ."
791+ logger .info (msg = msg )
792+ with ProgressBar ():
793+ compute (* write_tasks )
794+
795+ zarr_group = zarr .open (save_path , mode = "r+" )
796+ for key in self .drop_keys :
797+ if key in zarr_group :
798+ del zarr_group [key ]
799+
800+ return save_path
801+
774802 def infer_wsi (
775803 self : EngineABC ,
776804 dataloader : DataLoader ,
0 commit comments