@@ -22,23 +22,32 @@ def _get_architecture(
     weights: str or WeightsEnum = "DEFAULT",
     **kwargs: dict,
 ) -> list[nn.Sequential, ...] | nn.Sequential:
-    """Get a model.
+    """Retrieve a CNN model architecture.
 
-    Model architectures are either already defined within torchvision or
-    they can be custom-made within tiatoolbox.
+    This function fetches a Convolutional Neural Network (CNN) model architecture,
+    either predefined in torchvision or custom-made within tiatoolbox, for
+    patch classification tasks.
 
     Args:
         arch_name (str):
-            Architecture name.
+            Name of the architecture (e.g. 'resnet50', 'alexnet').
         weights (str or WeightsEnum):
-            torchvision model weights (get_model_weights).
-        kwargs (dict):
+            Pretrained torchvision model weights to use (get_model_weights).
+            Defaults to "DEFAULT".
+        **kwargs (dict):
             Key-word arguments.
 
     Returns:
-        List of PyTorch network layers wrapped with `nn.Sequential`.
-        https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html
+        list[nn.Sequential, ...] | nn.Sequential:
+            A list of PyTorch network layers wrapped with `nn.Sequential`.
 
+    Raises:
+        ValueError:
+            If `arch_name` is not supported.
+
+    Example:
+        >>> model = _get_architecture("resnet18")
+        >>> print(model)
     """
     backbone_dict = {
         "alexnet": torch_models.alexnet,
@@ -85,7 +94,10 @@ def _get_timm_architecture(
     *,
     pretrained: bool,
 ) -> list[nn.Sequential, ...] | nn.Sequential:
-    """Get architecture and weights for pathology-specific timm models.
+    """Retrieve a timm model architecture.
+
+    This function fetches a model architecture from the timm library, specifically for
+    pathology-related tasks.
 
     Args:
         arch_name (str):
@@ -94,12 +106,16 @@ def _get_timm_architecture(
             Whether to load pretrained weights.
 
     Returns:
-        A ready-to-use timm model.
+        list[nn.Sequential, ...] | nn.Sequential:
+            A ready-to-use timm model.
 
     Raises:
         ValueError:
             If the backbone architecture is not supported.
 
+    Example:
+        >>> model = _get_timm_architecture("UNI", pretrained=True)
+        >>> print(model)
     """
     if arch_name in [f"efficientnet_b{i}" for i in range(8)]:
         model = timm.create_model(arch_name, pretrained=pretrained)
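The `efficientnet_b*` branch above relies on the standard timm factory. As a hedged illustration of that API (passing `num_classes=0` to strip the classification head is standard timm usage, but it is an assumption that tiatoolbox does the same for these backbones):

# Hedged sketch: a timm backbone used as a feature extractor.
import timm
import torch

model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=0)
model.eval()
with torch.inference_mode():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 1280]) for efficientnet_b0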
@@ -177,6 +193,13 @@ def _postproc(image: np.ndarray) -> np.ndarray:
 
     This simply applies argmax along last axis of the input.
 
+    Args:
+        image (np.ndarray):
+            The input image array.
+
+    Returns:
+        np.ndarray:
+            The post-processed image array.
     """
     return np.argmax(image, axis=-1)
 
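The argmax post-processing documented above is easiest to see on a toy array:

# Toy example: per-class scores along the last axis collapse to class labels.
import numpy as np

scores = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1]])  # two patches, three classes
labels = np.argmax(scores, axis=-1)   # array([1, 0])
print(labels)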
@@ -197,8 +220,15 @@ def _infer_batch(
             A batch of data generated by
             `torch.utils.data.DataLoader`.
         device (str):
-            Transfers model to the specified device. Default is "cpu".
+            Transfers model to the specified device. Default is "cpu".
+
+    Returns:
+        dict[str, np.ndarray]:
+            The model predictions as a NumPy array.
 
+    Example:
+        >>> output = _infer_batch(model, batch_data, "cuda")
+        >>> print(output)
     """
     img_patches_device = batch_data.to(device=device).type(
         torch.float32,
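For context, a hedged sketch of what a batch-inference helper along these lines does; the real `_infer_batch` may additionally permute the batch to NCHW and post-process the raw output:

# Hedged sketch only, not the verbatim tiatoolbox implementation.
import torch

def infer_batch_sketch(model, batch_data, device="cpu"):
    model = model.to(device).eval()
    images = batch_data.to(device=device).type(torch.float32)
    with torch.inference_mode():  # no gradient tracking during inference
        output = model(images)
    return output.cpu().numpy()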
@@ -217,11 +247,14 @@ class CNNModel(ModelABC):
 class CNNModel(ModelABC):
     """Retrieve the model backbone and attach an extra FCN to perform classification.
 
+    This class initializes a Convolutional Neural Network (CNN) model with a specified
+    backbone and attaches a fully connected layer for classification tasks.
+
     Args:
         backbone (str):
-            Model name.
+            Name of the CNN model backbone (e.g. "resnet18", "densenet121").
         num_classes (int):
-            Number of classes output by model.
+            Number of classes output by model. Defaults to 1.
 
     Attributes:
         num_classes (int):
@@ -231,9 +264,12 @@ class CNNModel(ModelABC):
         pool (nn.Module):
             Type of pooling applied after feature extraction.
         classifier (nn.Module):
-            Linear classifier module used to map the features to the
-            output.
+            Linear classifier module used to map the features to the output.
 
+    Example:
+        >>> model = CNNModel("resnet18", num_classes=2)
+        >>> output = model(torch.randn(1, 3, 224, 224))
+        >>> print(output.shape)
     """
 
     def __init__(self: CNNModel, backbone: str, num_classes: int = 1) -> None:
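To make the class docstring concrete, here is a hedged sketch of how the three documented attributes (`feat_extract`, `pool`, `classifier`) compose; the 512-dimensional feature size is specific to resnet18, and the real `__init__` may differ in detail:

# Hedged sketch of the backbone + pooling + linear classifier composition.
import torch
import torch.nn as nn
from torchvision import models as torch_models

backbone = torch_models.resnet18(weights=None)
feat_extract = nn.Sequential(*list(backbone.children())[:-1])  # backbone without its head
pool = nn.AdaptiveAvgPool2d(output_size=1)                     # global average pooling
classifier = nn.Linear(512, 2)                                 # resnet18 features -> 2 classes

x = torch.randn(1, 3, 224, 224)
logits = classifier(torch.flatten(pool(feat_extract(x)), 1))
print(logits.shape)  # torch.Size([1, 2])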
@@ -257,6 +293,9 @@ def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor:
             imgs (torch.Tensor):
                 Model input.
 
+        Returns:
+            torch.Tensor:
+                The output logits after passing through the model.
         """
         feat = self.feat_extract(imgs)
         gap_feat = self.pool(feat)
@@ -270,6 +309,13 @@ def postproc(image: np.ndarray) -> np.ndarray:
 
         This simply applies argmax along last axis of the input.
 
+        Args:
+            image (np.ndarray):
+                The input image array.
+
+        Returns:
+            np.ndarray:
+                The post-processed image array.
         """
         return _postproc(image=image)
 
@@ -292,6 +338,9 @@ def infer_batch(
             device (str):
                 Transfers model to the specified device. Default is "cpu".
 
+        Example:
+            >>> output = CNNModel.infer_batch(model, batch_data, "cuda")
+            >>> print(output)
         """
         return _infer_batch(model=model, batch_data=batch_data, device=device)
 
@@ -327,6 +376,11 @@ class TimmModel(ModelABC):
         classifier (nn.Module):
             Linear classifier module used to map the features to the
             output.
+
+    Example:
+        >>> model = TimmModel("UNI", pretrained=True)
+        >>> output = model(torch.randn(1, 3, 224, 224))
+        >>> print(output.shape)
     """
 
     def __init__(
@@ -357,6 +411,9 @@ def forward(self: TimmModel, imgs: torch.Tensor) -> torch.Tensor:
             imgs (torch.Tensor):
                 Model input.
 
+        Returns:
+            torch.Tensor:
+                The output logits after passing through the model.
         """
         feat = self.feat_extract(imgs)
         feat = torch.flatten(feat, 1)
@@ -369,6 +426,14 @@ def postproc(image: np.ndarray) -> np.ndarray:
 
         This simply applies argmax along last axis of the input.
 
+        Args:
+            image (np.ndarray):
+                The input image array.
+
+        Returns:
+            np.ndarray:
+                The post-processed image array.
+
         """
         return _postproc(image=image)
 
@@ -391,6 +456,14 @@ def infer_batch(
             device (str):
                 Transfers model to the specified device. Default is "cpu".
 
+        Returns:
+            dict[str, np.ndarray]:
+                The model predictions as a NumPy array.
+
+        Example:
+            >>> output = TimmModel.infer_batch(model, batch_data, "cuda")
+            >>> print(output)
+
         """
         return _infer_batch(model=model, batch_data=batch_data, device=device)
 
@@ -423,6 +496,12 @@ class CNNBackbone(ModelABC):
             - "mobilenet_v3_large"
             - "mobilenet_v3_small"
 
+    Attributes:
+        feat_extract (nn.Module):
+            Backbone CNN model.
+        pool (nn.Module):
+            Type of pooling applied after feature extraction.
+
     Examples:
         >>> # Creating resnet50 architecture from default pytorch
         >>> # without the classification layer with its associated
@@ -452,6 +531,9 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
             imgs (torch.Tensor):
                 Model input.
 
+        Returns:
+            torch.Tensor:
+                The extracted features.
         """
         feat = self.feat_extract(imgs)
         gap_feat = self.pool(feat)
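A hedged usage sketch for the feature-extraction path shown above; the 2048-wide output is a property of resnet50, and the exact call pattern is illustrative:

# Hedged usage sketch: CNNBackbone (defined in this module) as a patch feature extractor.
import torch

model = CNNBackbone("resnet50")  # backbone without a classification head
model.eval()
with torch.inference_mode():
    feats = model(torch.randn(4, 3, 224, 224))
print(feats.shape)  # expected: torch.Size([4, 2048])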
@@ -480,6 +562,9 @@ def infer_batch(
             list[dict[str, np.ndarray]]:
                 list of dictionary values with numpy arrays.
 
+        Example:
+            >>> output = CNNBackbone.infer_batch(model, batch_data, "cuda")
+            >>> print(output)
         """
         return [_infer_batch(model=model, batch_data=batch_data, device=device)]
 
@@ -491,8 +576,7 @@ class TimmBackbone(ModelABC):
 
     Args:
         backbone (str):
-            Model name. Currently, the tool supports following
-            model names and their default associated weights from timm.
+            Model name. Supported model names include:
             - "efficientnet_b{i}" for i in [0, 1, ..., 7]
             - "UNI"
             - "prov-gigapath"
@@ -503,6 +587,10 @@ class TimmBackbone(ModelABC):
         pretrained (bool, keyword-only):
             Whether to load pretrained weights.
 
+    Attributes:
+        feat_extract (nn.Module):
+            Backbone timm model.
+
     Examples:
         >>> # Creating UNI tile encoder
         >>> model = TimmBackbone(backbone="UNI", pretrained=True)
@@ -531,6 +619,9 @@ def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor:
             imgs (torch.Tensor):
                 Model input.
 
+        Returns:
+            torch.Tensor:
+                The extracted features.
         """
         feats = self.feat_extract(imgs)
         return torch.flatten(feats, 1)
@@ -558,5 +649,8 @@ def infer_batch(
             list[dict[str, np.ndarray]]:
                 list of dictionary values with numpy arrays.
 
+        Example:
+            >>> output = TimmBackbone.infer_batch(model, batch_data, "cuda")
+            >>> print(output)
         """
         return [_infer_batch(model=model, batch_data=batch_data, device=device)]