1- """Defines PatchPredictor Engine."""
1+ """Defines the PatchPredictor engine for patch-level inference in digital pathology.
2+
3+ This module implements the PatchPredictor class, which extends the EngineABC base
4+ class to support patch-based and whole slide image (WSI) inference using deep learning
5+ models from TIAToolbox. It provides utilities for model initialization, post-processing,
6+ and output management, including support for multiple output formats.
7+
8+ Classes:
9+ - PatchPredictor:
10+ Engine for performing patch-level predictions.
11+ - PredictorRunParams:
12+ TypedDict for configuring runtime parameters.
13+
14+ Example:
15+ >>> images = [np.ndarray, np.ndarray]
16+ >>> predictor = PatchPredictor(model="resnet18-kather100k")
17+ >>> output = predictor.run(images, patch_mode=True)
18+
19+ """
220
321from __future__ import annotations
422
@@ -255,7 +273,7 @@ def __init__(
255273 Number of workers for data loading. Default is 0.
256274 weights (str | Path | None): Path to model weights.
257275 If None, default weights are used.
258- device (str): D
276+ device (str):
259277 device to run the model on (e.g., "cpu", "cuda"). Default is "cpu".
260278 verbose (bool):
261279 Whether to enable verbose logging. Default is True.
@@ -272,7 +290,7 @@ def __init__(
272290
273291 def post_process_patches (
274292 self : PatchPredictor ,
275- raw_predictions : da .Array | np . ndarray ,
293+ raw_predictions : da .Array ,
276294 prediction_shape : tuple [int , ...],
277295 prediction_dtype : type ,
278296 ** kwargs : Unpack [PredictorRunParams ],
@@ -284,15 +302,14 @@ def post_process_patches(
284302 efficient computation and memory handling.
285303
286304 Args:
287- raw_predictions (dask.array .Array | np.ndarray):
305+ raw_predictions (da .Array | np.ndarray):
288306 Raw model predictions.
289307 prediction_shape (tuple[int, ...]):
290308 Expected shape of the prediction output.
291309 prediction_dtype (type):
292310 Data type of the prediction output.
293311 **kwargs (PredictorRunParams):
294- Additional runtime parameters, including
295- `return_probabilities`.
312+ Additional runtime parameters, including `return_probabilities`.
296313
297314 Returns:
298315 dask.array.Array: Post-processed predictions as a Dask array.
@@ -305,7 +322,7 @@ def post_process_patches(
305322
306323 def post_process_wsi (
307324 self : PatchPredictor ,
308- raw_predictions : da .Array | np . ndarray ,
325+ raw_predictions : da .Array ,
309326 prediction_shape : tuple [int , ...],
310327 prediction_dtype : type ,
311328 ** kwargs : Unpack [PredictorRunParams ],
@@ -318,15 +335,14 @@ def post_process_wsi(
318335 `post_process_patches()`.
319336
320337 Args:
321- raw_predictions (dask.array.Array | np.ndarray ):
338+ raw_predictions (dask.array.Array):
322339 Raw model predictions.
323340 prediction_shape (tuple[int, ...]):
324341 Expected shape of the prediction output.
325342 prediction_dtype (type):
326343 Data type of the prediction output.
327344 **kwargs (PredictorRunParams):
328- Additional runtime parameters, including
329- `return_probabilities`.
345+ Additional runtime parameters, including `return_probabilities`.
330346
331347 Returns:
332348 dask.array.Array: Post-processed predictions as a Dask array.
0 commit comments