@@ -508,39 +508,6 @@ def get_dataloader(
508508 shuffle = False ,
509509 )
510510
511- @staticmethod
512- def _update_model_output (raw_predictions : dict , raw_output : dict ) -> dict :
513- """Append raw output from model inference to the prediction dictionary.
514-
515- This method wraps each batch output in a Dask array and concatenates it
516- with existing predictions for efficient memory usage and parallel computation.
517-
518- Args:
519- raw_predictions (dict):
520- Dictionary containing accumulated Dask arrays for each output key.
521- raw_output (dict):
522- Dictionary containing the current batch's output as NumPy arrays.
523-
524- Returns:
525- dict:
526- Updated dictionary with concatenated Dask arrays for each output key.
527-
528- """
529- for key , value in raw_output .items ():
530- delayed_value = delayed (value )
531- dask_array = da .from_delayed (
532- delayed_value , shape = value .shape , dtype = value .dtype
533- )
534-
535- if raw_predictions [key ] is None :
536- raw_predictions [key ] = dask_array
537- else :
538- raw_predictions [key ] = da .concatenate (
539- [raw_predictions [key ], dask_array ], axis = 0
540- )
541-
542- return raw_predictions
543-
544511 def _get_coordinates (self : EngineABC , batch_data : dict ) -> np .ndarray :
545512 """Extract coordinates for each image patch in a batch.
546513
@@ -564,53 +531,6 @@ def _get_coordinates(self: EngineABC, batch_data: dict) -> np.ndarray:
564531 return np .tile (coordinates , reps = (batch_data ["image" ].shape [0 ], 1 ))
565532 return np .array (batch_data ["coords" ])
566533
567- @delayed
568- def process_batch (
569- self : EngineABC ,
570- batch_data : dict ,
571- model : ModelABC ,
572- device : str ,
573- * ,
574- return_labels : bool ,
575- return_coordinates : bool ,
576- ) -> dict :
577- """Process a batch of images and return model predictions.
578-
579- This method performs inference on a batch of image patches,
580- optionally including coordinates and labels in the output.
581-
582- Args:
583- batch_data (dict):
584- Dictionary containing batch input data including images,
585- and optionally labels and coordinates.
586- model (ModelABC):
587- The PyTorch or TIAToolbox model used for inference.
588- device (str):
589- Device on which to run inference (e.g., "cpu", "cuda").
590- return_labels (bool):
591- Whether to include labels in the output.
592- return_coordinates (bool):
593- Whether to include coordinates in the output.
594-
595- Returns:
596- dict:
597- Dictionary containing model predictions, and optionally
598- coordinates and labels.
599-
600- """
601- batch_output = model .infer_batch (model , batch_data ["image" ], device = device )
602-
603- if return_coordinates :
604- batch_output ["coordinates" ] = self ._get_coordinates (batch_data )
605-
606- if return_labels :
607- if isinstance (batch_data ["label" ], torch .Tensor ):
608- batch_output ["labels" ] = batch_data ["label" ].numpy ()
609- else :
610- batch_output ["labels" ] = np .array (batch_data ["label" ])
611-
612- return batch_output
613-
614534 def infer_patches (
615535 self : EngineABC ,
616536 dataloader : DataLoader ,
@@ -792,8 +712,6 @@ def save_predictions(
792712 write_tasks = []
793713 for key in keys_to_compute :
794714 dask_array = processed_predictions [key ]
795- if dask_array is None :
796- continue
797715 task = dask_array .to_zarr (
798716 url = save_path ,
799717 component = key ,
0 commit comments