1212from typing import Any , Dict , Optional , Tuple , Union
1313
1414from monai .networks .nets import NetAdapter
15- from monai .utils import deprecated , optional_import
15+ from monai .utils import deprecated , deprecated_arg , optional_import
1616
1717models , _ = optional_import ("torchvision.models" )
1818
@@ -29,7 +29,7 @@ class TorchVisionFCModel(NetAdapter):
2929 ``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``,
3030 ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``.
3131 model details: https://pytorch.org/vision/stable/models.html.
32- n_classes : number of classes for the last classification layer. Default to 1.
32+ num_classes : number of classes for the last classification layer. Default to 1.
3333 dim: number of spatial dimensions, default to 2.
3434 in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
3535 use_conv: whether use convolutional layer to replace the last layer, default to False.
@@ -41,25 +41,30 @@ class TorchVisionFCModel(NetAdapter):
4141 pretrained: whether to use the imagenet pretrained weights. Default to False.
4242 """
4343
44+ @deprecated_arg ("n_classes" , since = "0.6" )
4445 def __init__ (
4546 self ,
4647 model_name : str = "resnet18" ,
47- n_classes : int = 1 ,
48+ num_classes : int = 1 ,
4849 dim : int = 2 ,
4950 in_channels : Optional [int ] = None ,
5051 use_conv : bool = False ,
5152 pool : Optional [Tuple [str , Dict [str , Any ]]] = ("avg" , {"kernel_size" : 7 , "stride" : 1 }),
5253 bias : bool = True ,
5354 pretrained : bool = False ,
55+ n_classes : Optional [int ] = None ,
5456 ):
57+ # in case the new num_classes is default but you still call deprecated n_classes
58+ if n_classes is not None and num_classes == 1 :
59+ num_classes = n_classes
5560 model = getattr (models , model_name )(pretrained = pretrained )
5661 # check if the model is compatible, should have a FC layer at the end
5762 if not str (list (model .children ())[- 1 ]).startswith ("Linear" ):
5863 raise ValueError (f"Model ['{ model_name } '] does not have a Linear layer at the end." )
5964
6065 super ().__init__ (
6166 model = model ,
62- n_classes = n_classes ,
67+ num_classes = num_classes ,
6368 dim = dim ,
6469 in_channels = in_channels ,
6570 use_conv = use_conv ,
@@ -77,7 +82,7 @@ class TorchVisionFullyConvModel(TorchVisionFCModel):
7782 model_name: name of any torchvision with adaptive avg pooling and fully connected layer at the end.
7883 ``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``,
7984 ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``.
80- n_classes : number of classes for the last classification layer. Default to 1.
85+ num_classes : number of classes for the last classification layer. Default to 1.
8186 pool_size: the kernel size for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to (7, 7).
8287 pool_stride: the stride for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to 1.
8388 pretrained: whether to use the imagenet pretrained weights. Default to False.
@@ -87,17 +92,22 @@ class TorchVisionFullyConvModel(TorchVisionFCModel):
8792
8893 """
8994
95+ @deprecated_arg ("n_classes" , since = "0.6" )
9096 def __init__ (
9197 self ,
9298 model_name : str = "resnet18" ,
93- n_classes : int = 1 ,
99+ num_classes : int = 1 ,
94100 pool_size : Union [int , Tuple [int , int ]] = (7 , 7 ),
95101 pool_stride : Union [int , Tuple [int , int ]] = 1 ,
96102 pretrained : bool = False ,
103+ n_classes : Optional [int ] = None ,
97104 ):
105+ # in case the new num_classes is default but you still call deprecated n_classes
106+ if n_classes is not None and num_classes == 1 :
107+ num_classes = n_classes
98108 super ().__init__ (
99109 model_name = model_name ,
100- n_classes = n_classes ,
110+ num_classes = num_classes ,
101111 use_conv = True ,
102112 pool = ("avg" , {"kernel_size" : pool_size , "stride" : pool_stride }),
103113 pretrained = pretrained ,
0 commit comments