Skip to content

Commit d45b214

Browse files
committed
ADD: Add docstrings
1 parent 8f42a87 commit d45b214

File tree

1 file changed

+111
-17
lines changed

1 file changed

+111
-17
lines changed

tiatoolbox/models/architecture/vanilla.py

Lines changed: 111 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,32 @@ def _get_architecture(
2222
weights: str or WeightsEnum = "DEFAULT",
2323
**kwargs: dict,
2424
) -> list[nn.Sequential, ...] | nn.Sequential:
25-
"""Get a model.
25+
"""Retrieve a CNN model architecture.
2626
27-
Model architectures are either already defined within torchvision or
28-
they can be custom-made within tiatoolbox.
27+
This function fetches a Convolutional Neural Network (CNN) model architecture,
28+
either predefined in torchvision or custom-made within tiatoolbox, for
29+
patch classification tasks.
2930
3031
Args:
3132
arch_name (str):
32-
Architecture name.
33+
Name of the architecture (e.g. 'resnet50', 'alexnet').
3334
weights (str or WeightsEnum):
34-
torchvision model weights (get_model_weights).
35-
kwargs (dict):
35+
Pretrained torchvision model weights to use (get_model_weights).
36+
Defaults to "DEFAULT".
37+
**kwargs (dict):
3638
Key-word arguments.
3739
3840
Returns:
39-
List of PyTorch network layers wrapped with `nn.Sequential`.
40-
https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html
41+
list[nn.Sequential, ...] | nn.Sequential:
42+
A list of PyTorch network layers wrapped with `nn.Sequential`.
4143
44+
Raises:
45+
ValueError:
46+
If `arch_name` is not supported.
47+
48+
Example:
49+
>>> model = _get_architecture("resnet18")
50+
>>> print(model)
4251
"""
4352
backbone_dict = {
4453
"alexnet": torch_models.alexnet,
@@ -85,7 +94,10 @@ def _get_timm_architecture(
8594
*,
8695
pretrained: bool,
8796
) -> list[nn.Sequential, ...] | nn.Sequential:
88-
"""Get architecture and weights for pathology-specific timm models.
97+
"""Retrieve a timm model architecture.
98+
99+
This function fetches a model architecture from the timm library, specifically for
100+
pathology-related tasks.
89101
90102
Args:
91103
arch_name (str):
@@ -94,12 +106,16 @@ def _get_timm_architecture(
94106
Whether to load pretrained weights.
95107
96108
Returns:
97-
A ready-to-use timm model.
109+
list[nn.Sequential, ...] | nn.Sequential:
110+
A ready-to-use timm model.
98111
99112
Raises:
100113
ValueError:
101114
If the backbone architecture is not supported.
102115
116+
Example:
117+
>>> model = _get_timm_architecture("UNI", pretrained=True)
118+
>>> print(model)
103119
"""
104120
if arch_name in [f"efficientnet_b{i}" for i in range(8)]:
105121
model = timm.create_model(arch_name, pretrained=pretrained)
@@ -177,6 +193,13 @@ def _postproc(image: np.ndarray) -> np.ndarray:
177193
178194
This simply applies argmax along last axis of the input.
179195
196+
Args:
197+
image (np.ndarray):
198+
The input image array.
199+
200+
Returns:
201+
np.ndarray:
202+
The post-processed image array.
180203
"""
181204
return np.argmax(image, axis=-1)
182205

@@ -197,8 +220,15 @@ def _infer_batch(
197220
A batch of data generated by
198221
`torch.utils.data.DataLoader`.
199222
device (str):
200-
Transfers model to the specified device. Default is "cpu".
223+
Transfers model to the specified device. Default is "cpu".
224+
225+
Returns:
226+
dict[str, np.ndarray]:
227+
The model predictions as a NumPy array.
201228
229+
Example:
230+
>>> output = _infer_batch(model, batch_data, "cuda")
231+
>>> print(output)
202232
"""
203233
img_patches_device = batch_data.to(device=device).type(
204234
torch.float32,
@@ -217,11 +247,14 @@ def _infer_batch(
217247
class CNNModel(ModelABC):
218248
"""Retrieve the model backbone and attach an extra FCN to perform classification.
219249
250+
This class initializes a Convolutional Neural Network (CNN) model with a specified
251+
backbone and attaches a fully connected layer for classification tasks.
252+
220253
Args:
221254
backbone (str):
222-
Model name.
255+
Name of the CNN model backbone (e.g. "resnet18", "densenet121").
223256
num_classes (int):
224-
Number of classes output by model.
257+
Number of classes output by model. Defaults to 1.
225258
226259
Attributes:
227260
num_classes (int):
@@ -231,9 +264,12 @@ class CNNModel(ModelABC):
231264
pool (nn.Module):
232265
Type of pooling applied after feature extraction.
233266
classifier (nn.Module):
234-
Linear classifier module used to map the features to the
235-
output.
267+
Linear classifier module used to map the features to the output.
236268
269+
Example:
270+
>>> model = CNNModel("resnet18", num_classes=2)
271+
>>> output = model(torch.randn(1, 3, 224, 224))
272+
>>> print(output.shape)
237273
"""
238274

239275
def __init__(self: CNNModel, backbone: str, num_classes: int = 1) -> None:
@@ -257,6 +293,9 @@ def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor:
257293
imgs (torch.Tensor):
258294
Model input.
259295
296+
Returns:
297+
torch.Tensor:
298+
The output logits after passing through the model.
260299
"""
261300
feat = self.feat_extract(imgs)
262301
gap_feat = self.pool(feat)
@@ -270,6 +309,13 @@ def postproc(image: np.ndarray) -> np.ndarray:
270309
271310
This simply applies argmax along last axis of the input.
272311
312+
Args:
313+
image (np.ndarray):
314+
The input image array.
315+
316+
Returns:
317+
np.ndarray:
318+
The post-processed image array.
273319
"""
274320
return _postproc(image=image)
275321

@@ -292,6 +338,9 @@ def infer_batch(
292338
device (str):
293339
Transfers model to the specified device. Default is "cpu".
294340
341+
Example:
342+
>>> output = _infer_batch(model, batch_data, "cuda")
343+
>>> print(output)
295344
"""
296345
return _infer_batch(model=model, batch_data=batch_data, device=device)
297346

@@ -327,6 +376,11 @@ class TimmModel(ModelABC):
327376
classifier (nn.Module):
328377
Linear classifier module used to map the features to the
329378
output.
379+
380+
Example:
381+
>>> model = TimmModel("UNI", pretrained=True)
382+
>>> output = model(torch.randn(1, 3, 224, 224))
383+
>>> print(output.shape)
330384
"""
331385

332386
def __init__(
@@ -357,6 +411,9 @@ def forward(self: TimmModel, imgs: torch.Tensor) -> torch.Tensor:
357411
imgs (torch.Tensor):
358412
Model input.
359413
414+
Returns:
415+
torch.Tensor:
416+
The output logits after passing through the model.
360417
"""
361418
feat = self.feat_extract(imgs)
362419
feat = torch.flatten(feat, 1)
@@ -369,6 +426,14 @@ def postproc(image: np.ndarray) -> np.ndarray:
369426
370427
This simply applies argmax along last axis of the input.
371428
429+
Args:
430+
image (np.ndarray):
431+
The input image array.
432+
433+
Returns:
434+
np.ndarray:
435+
The post-processed image array.
436+
372437
"""
373438
return _postproc(image=image)
374439

@@ -391,6 +456,14 @@ def infer_batch(
391456
device (str):
392457
Transfers model to the specified device. Default is "cpu".
393458
459+
Returns:
460+
dict[str, np.ndarray]:
461+
The model predictions as a NumPy array.
462+
463+
Example:
464+
>>> output = _infer_batch(model, batch_data, "cuda")
465+
>>> print(output)
466+
394467
"""
395468
return _infer_batch(model=model, batch_data=batch_data, device=device)
396469

@@ -423,6 +496,12 @@ class CNNBackbone(ModelABC):
423496
- "mobilenet_v3_large"
424497
- "mobilenet_v3_small"
425498
499+
Attributes:
500+
feat_extract (nn.Module):
501+
Backbone CNN model.
502+
pool (nn.Module):
503+
Type of pooling applied after feature extraction.
504+
426505
Examples:
427506
>>> # Creating resnet50 architecture from default pytorch
428507
>>> # without the classification layer with its associated
@@ -452,6 +531,9 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
452531
imgs (torch.Tensor):
453532
Model input.
454533
534+
Returns:
535+
torch.Tensor:
536+
The extracted features.
455537
"""
456538
feat = self.feat_extract(imgs)
457539
gap_feat = self.pool(feat)
@@ -480,6 +562,9 @@ def infer_batch(
480562
list[dict[str, np.ndarray]]:
481563
list of dictionary values with numpy arrays.
482564
565+
Example:
566+
>>> output = CNNBackbone.infer_batch(model, batch_data, "cuda")
567+
>>> print(output)
483568
"""
484569
return [_infer_batch(model=model, batch_data=batch_data, device=device)]
485570

@@ -491,8 +576,7 @@ class TimmBackbone(ModelABC):
491576
492577
Args:
493578
backbone (str):
494-
Model name. Currently, the tool supports following
495-
model names and their default associated weights from timm.
579+
Model name. Supported model names include:
496580
- "efficientnet_b{i}" for i in [0, 1, ..., 7]
497581
- "UNI"
498582
- "prov-gigapath"
@@ -503,6 +587,10 @@ class TimmBackbone(ModelABC):
503587
pretrained (bool, keyword-only):
504588
Whether to load pretrained weights.
505589
590+
Attributes:
591+
feat_extract (nn.Module):
592+
Backbone timm model.
593+
506594
Examples:
507595
>>> # Creating UNI tile encoder
508596
>>> model = TimmBackbone(backbone="UNI", pretrained=True)
@@ -531,6 +619,9 @@ def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor:
531619
imgs (torch.Tensor):
532620
Model input.
533621
622+
Returns:
623+
torch.Tensor:
624+
The extracted features.
534625
"""
535626
feats = self.feat_extract(imgs)
536627
return torch.flatten(feats, 1)
@@ -558,5 +649,8 @@ def infer_batch(
558649
list[dict[str, np.ndarray]]:
559650
list of dictionary values with numpy arrays.
560651
652+
Example:
653+
>>> output = TimmBackbone.infer_batch(model, batch_data, "cuda")
654+
>>> print(output)
561655
"""
562656
return [_infer_batch(model=model, batch_data=batch_data, device=device)]

0 commit comments

Comments
 (0)