Skip to content

Commit e9ac4e2

Browse files
committed
Merge remote-tracking branch 'origin/dev-define-semantic-segmentor' into dev-define-semantic-segmentor
2 parents edde87a + d49c3d9 commit e9ac4e2

File tree

1 file changed

+0
-82
lines changed

1 file changed

+0
-82
lines changed

tiatoolbox/models/engine/engine_abc.py

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -508,39 +508,6 @@ def get_dataloader(
508508
shuffle=False,
509509
)
510510

511-
@staticmethod
def _update_model_output(raw_predictions: dict, raw_output: dict) -> dict:
    """Fold one batch of model output into the accumulated predictions.

    Each batch array is wrapped lazily as a Dask array (via ``delayed``)
    so that batches are held out-of-core and concatenated without eager
    NumPy copies.

    Args:
        raw_predictions (dict):
            Accumulator mapping each output key to a Dask array, or to
            ``None`` when no batch has been appended for that key yet.
        raw_output (dict):
            The current batch's output, one NumPy array per key.

    Returns:
        dict:
            The same accumulator, with this batch concatenated onto each
            key along axis 0.

    """
    for key, batch_value in raw_output.items():
        # Wrap lazily; shape/dtype come from the concrete batch array.
        batch_array = da.from_delayed(
            delayed(batch_value),
            shape=batch_value.shape,
            dtype=batch_value.dtype,
        )
        accumulated = raw_predictions[key]
        raw_predictions[key] = (
            batch_array
            if accumulated is None
            else da.concatenate([accumulated, batch_array], axis=0)
        )
    return raw_predictions
543-
544511
def _get_coordinates(self: EngineABC, batch_data: dict) -> np.ndarray:
545512
"""Extract coordinates for each image patch in a batch.
546513
@@ -564,53 +531,6 @@ def _get_coordinates(self: EngineABC, batch_data: dict) -> np.ndarray:
564531
return np.tile(coordinates, reps=(batch_data["image"].shape[0], 1))
565532
return np.array(batch_data["coords"])
566533

567-
@delayed
def process_batch(
    self: EngineABC,
    batch_data: dict,
    model: ModelABC,
    device: str,
    *,
    return_labels: bool,
    return_coordinates: bool,
) -> dict:
    """Run inference on one batch and assemble its output dictionary.

    Performs model inference on the batch's image patches and, when
    requested, attaches the patch coordinates and labels to the result.
    Decorated with ``@delayed`` so the whole batch is evaluated lazily
    by Dask.

    Args:
        batch_data (dict):
            Batch input with an ``"image"`` entry and, optionally,
            ``"label"`` and coordinate entries.
        model (ModelABC):
            The PyTorch or TIAToolbox model used for inference.
        device (str):
            Device on which to run inference (e.g., "cpu", "cuda").
        return_labels (bool):
            Whether to include labels in the output.
        return_coordinates (bool):
            Whether to include coordinates in the output.

    Returns:
        dict:
            Model predictions, optionally augmented with
            ``"coordinates"`` and ``"labels"`` entries.

    """
    batch_output = model.infer_batch(model, batch_data["image"], device=device)

    if return_coordinates:
        batch_output["coordinates"] = self._get_coordinates(batch_data)

    if return_labels:
        labels = batch_data["label"]
        # Labels may arrive as tensors from the dataloader or as plain
        # sequences; normalise both to NumPy.
        batch_output["labels"] = (
            labels.numpy()
            if isinstance(labels, torch.Tensor)
            else np.array(labels)
        )

    return batch_output
613-
614534
def infer_patches(
615535
self: EngineABC,
616536
dataloader: DataLoader,
@@ -792,8 +712,6 @@ def save_predictions(
792712
write_tasks = []
793713
for key in keys_to_compute:
794714
dask_array = processed_predictions[key]
795-
if dask_array is None:
796-
continue
797715
task = dask_array.to_zarr(
798716
url=save_path,
799717
component=key,

0 commit comments

Comments
 (0)