Skip to content

Commit 0242512

Browse files
committed
📝 Update patch_predictor.py docstrings
1 parent ba28d91 commit 0242512

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

tiatoolbox/models/engine/patch_predictor.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
1-
"""Defines PatchPredictor Engine."""
1+
"""Defines the PatchPredictor engine for patch-level inference in digital pathology.
2+
3+
This module implements the PatchPredictor class, which extends the EngineABC base
4+
class to support patch-based and whole slide image (WSI) inference using deep learning
5+
models from TIAToolbox. It provides utilities for model initialization, post-processing,
6+
and output management, including support for multiple output formats.
7+
8+
Classes:
9+
- PatchPredictor:
10+
Engine for performing patch-level predictions.
11+
- PredictorRunParams:
12+
TypedDict for configuring runtime parameters.
13+
14+
Example:
15+
>>> images = [np.ndarray, np.ndarray]
16+
>>> predictor = PatchPredictor(model="resnet18-kather100k")
17+
>>> output = predictor.run(images, patch_mode=True)
18+
19+
"""
220

321
from __future__ import annotations
422

@@ -255,7 +273,7 @@ def __init__(
255273
Number of workers for data loading. Default is 0.
256274
weights (str | Path | None): Path to model weights.
257275
If None, default weights are used.
258-
device (str): D
276+
device (str):
259277
device to run the model on (e.g., "cpu", "cuda"). Default is "cpu".
260278
verbose (bool):
261279
Whether to enable verbose logging. Default is True.
@@ -272,7 +290,7 @@ def __init__(
272290

273291
def post_process_patches(
274292
self: PatchPredictor,
275-
raw_predictions: da.Array | np.ndarray,
293+
raw_predictions: da.Array,
276294
prediction_shape: tuple[int, ...],
277295
prediction_dtype: type,
278296
**kwargs: Unpack[PredictorRunParams],
@@ -284,15 +302,14 @@ def post_process_patches(
284302
efficient computation and memory handling.
285303
286304
Args:
287-
raw_predictions (dask.array.Array | np.ndarray):
305+
raw_predictions (da.Array | np.ndarray):
288306
Raw model predictions.
289307
prediction_shape (tuple[int, ...]):
290308
Expected shape of the prediction output.
291309
prediction_dtype (type):
292310
Data type of the prediction output.
293311
**kwargs (PredictorRunParams):
294-
Additional runtime parameters, including
295-
`return_probabilities`.
312+
Additional runtime parameters, including `return_probabilities`.
296313
297314
Returns:
298315
dask.array.Array: Post-processed predictions as a Dask array.
@@ -305,7 +322,7 @@ def post_process_patches(
305322

306323
def post_process_wsi(
307324
self: PatchPredictor,
308-
raw_predictions: da.Array | np.ndarray,
325+
raw_predictions: da.Array,
309326
prediction_shape: tuple[int, ...],
310327
prediction_dtype: type,
311328
**kwargs: Unpack[PredictorRunParams],
@@ -318,15 +335,14 @@ def post_process_wsi(
318335
`post_process_patches()`.
319336
320337
Args:
321-
raw_predictions (dask.array.Array | np.ndarray):
338+
raw_predictions (dask.array.Array):
322339
Raw model predictions.
323340
prediction_shape (tuple[int, ...]):
324341
Expected shape of the prediction output.
325342
prediction_dtype (type):
326343
Data type of the prediction output.
327344
**kwargs (PredictorRunParams):
328-
Additional runtime parameters, including
329-
`return_probabilities`.
345+
Additional runtime parameters, including `return_probabilities`.
330346
331347
Returns:
332348
dask.array.Array: Post-processed predictions as a Dask array.

0 commit comments

Comments
 (0)