Skip to content

Commit 1ecf5b6

Browse files
authored
rename n_classes to num_classes (#2842)
* rename n_classes Signed-off-by: Jirka <[email protected]> * back compatibility Signed-off-by: Jirka <[email protected]>
1 parent 9c7b71f commit 1ecf5b6

22 files changed

+121
-81
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
55
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
66

77
## [Unreleased]
8+
* renamed model's `n_classes` to `num_classes`
9+
810
## [0.6.0] - 2021-07-08
911
### Added
1012
* 10 new transforms, a masked loss wrapper, and a `NetAdapter` for transfer learning

monai/losses/tversky.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
155155
if self.reduction == LossReduction.SUM.value:
156156
return torch.sum(score) # sum over the batch and channel dims
157157
if self.reduction == LossReduction.NONE.value:
158-
return score # returns [N, n_classes] losses
158+
return score # returns [N, num_classes] losses
159159
if self.reduction == LossReduction.MEAN.value:
160160
return torch.mean(score)
161161
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

monai/metrics/meandice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def compute_meandice(
114114
the predicted output. Defaults to True.
115115
116116
Returns:
117-
Dice scores per batch and per class, (shape [batch_size, n_classes]).
117+
Dice scores per batch and per class, (shape [batch_size, num_classes]).
118118
119119
Raises:
120120
ValueError: when `y_pred` and `y` have different shapes.

monai/metrics/rocauc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ def compute_roc_auc(
131131
y_pred_ndim = y_pred.ndimension()
132132
y_ndim = y.ndimension()
133133
if y_pred_ndim not in (1, 2):
134-
raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).")
134+
raise ValueError("Predictions should be of shape (batch_size, num_classes) or (batch_size, ).")
135135
if y_ndim not in (1, 2):
136-
raise ValueError("Targets should be of shape (batch_size, n_classes) or (batch_size, ).")
136+
raise ValueError("Targets should be of shape (batch_size, num_classes) or (batch_size, ).")
137137
if y_pred_ndim == 2 and y_pred.shape[1] == 1:
138138
y_pred = y_pred.squeeze(dim=-1)
139139
y_pred_ndim = 1

monai/networks/nets/netadapter.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515

1616
from monai.networks.layers import Conv, get_pool_layer
17+
from monai.utils import deprecated_arg
1718

1819

1920
class NetAdapter(torch.nn.Module):
@@ -26,7 +27,7 @@ class NetAdapter(torch.nn.Module):
2627
model: a PyTorch model, support both 2D and 3D models. typically, it can be a pretrained model in Torchvision,
2728
like: ``resnet18``, ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, etc.
2829
more details: https://pytorch.org/vision/stable/models.html.
29-
n_classes: number of classes for the last classification layer. Default to 1.
30+
num_classes: number of classes for the last classification layer. Default to 1.
3031
dim: number of spatial dimensions, default to 2.
3132
in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
3233
use_conv: whether use convolutional layer to replace the last layer, default to False.
@@ -38,17 +39,22 @@ class NetAdapter(torch.nn.Module):
3839
3940
"""
4041

42+
@deprecated_arg("n_classes", since="0.6")
4143
def __init__(
4244
self,
4345
model: torch.nn.Module,
44-
n_classes: int = 1,
46+
num_classes: int = 1,
4547
dim: int = 2,
4648
in_channels: Optional[int] = None,
4749
use_conv: bool = False,
4850
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
4951
bias: bool = True,
52+
n_classes: Optional[int] = None,
5053
):
5154
super().__init__()
55+
# in case the new num_classes is default but you still call deprecated n_classes
56+
if n_classes is not None and num_classes == 1:
57+
num_classes = n_classes
5258
layers = list(model.children())
5359
orig_fc = layers[-1]
5460
in_channels_: int
@@ -74,7 +80,7 @@ def __init__(
7480
# add 1x1 conv (it behaves like a FC layer)
7581
self.fc = Conv[Conv.CONV, dim](
7682
in_channels=in_channels_,
77-
out_channels=n_classes,
83+
out_channels=num_classes,
7884
kernel_size=1,
7985
bias=bias,
8086
)
@@ -84,7 +90,7 @@ def __init__(
8490
# replace the out_features of FC layer
8591
self.fc = torch.nn.Linear(
8692
in_features=in_channels_,
87-
out_features=n_classes,
93+
out_features=num_classes,
8894
bias=bias,
8995
)
9096
self.use_conv = use_conv

monai/networks/nets/resnet.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# limitations under the License.
1111

1212
from functools import partial
13-
from typing import Any, Callable, List, Type, Union
13+
from typing import Any, Callable, List, Optional, Type, Union
1414

1515
import torch
1616
import torch.nn as nn
@@ -20,6 +20,8 @@
2020

2121
__all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]
2222

23+
from monai.utils import deprecated_arg
24+
2325

2426
def get_inplanes():
2527
return [64, 128, 256, 512]
@@ -162,9 +164,10 @@ class ResNet(nn.Module):
162164
no_max_pool: bool argument to determine if to use maxpool layer.
163165
shortcut_type: which downsample block to use.
164166
widen_factor: widen output for each layer.
165-
n_classes: number of output (classifications)
167+
num_classes: number of output (classifications)
166168
"""
167169

170+
@deprecated_arg("n_classes", since="0.6")
168171
def __init__(
169172
self,
170173
block: Type[Union[ResNetBlock, ResNetBottleneck]],
@@ -177,11 +180,15 @@ def __init__(
177180
no_max_pool: bool = False,
178181
shortcut_type: str = "B",
179182
widen_factor: float = 1.0,
180-
n_classes: int = 400,
183+
num_classes: int = 400,
181184
feed_forward: bool = True,
185+
n_classes: Optional[int] = None,
182186
) -> None:
183187

184188
super(ResNet, self).__init__()
189+
# in case the new num_classes is default but you still call deprecated n_classes
190+
if n_classes is not None and num_classes == 400:
191+
num_classes = n_classes
185192

186193
conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
187194
norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims]
@@ -215,7 +222,7 @@ def __init__(
215222
self.avgpool = avgp_type(block_avgpool[spatial_dims])
216223

217224
if feed_forward:
218-
self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes)
225+
self.fc = nn.Linear(block_inplanes[3] * block.expansion, num_classes)
219226

220227
for m in self.modules():
221228
if isinstance(m, conv_type):
@@ -303,7 +310,7 @@ def _resnet(
303310
progress: bool,
304311
**kwargs: Any,
305312
) -> ResNet:
306-
model = ResNet(block, layers, block_inplanes, **kwargs)
313+
model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)
307314
if pretrained:
308315
# Author of paper zipped the state_dict on googledrive,
309316
# so would need to download, unzip and read (2.8gb file for a ~150mb state dict).

monai/networks/nets/torchvision_fc.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Any, Dict, Optional, Tuple, Union
1313

1414
from monai.networks.nets import NetAdapter
15-
from monai.utils import deprecated, optional_import
15+
from monai.utils import deprecated, deprecated_arg, optional_import
1616

1717
models, _ = optional_import("torchvision.models")
1818

@@ -29,7 +29,7 @@ class TorchVisionFCModel(NetAdapter):
2929
``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``,
3030
``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``.
3131
model details: https://pytorch.org/vision/stable/models.html.
32-
n_classes: number of classes for the last classification layer. Default to 1.
32+
num_classes: number of classes for the last classification layer. Default to 1.
3333
dim: number of spatial dimensions, default to 2.
3434
in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
3535
use_conv: whether use convolutional layer to replace the last layer, default to False.
@@ -41,25 +41,30 @@ class TorchVisionFCModel(NetAdapter):
4141
pretrained: whether to use the imagenet pretrained weights. Default to False.
4242
"""
4343

44+
@deprecated_arg("n_classes", since="0.6")
4445
def __init__(
4546
self,
4647
model_name: str = "resnet18",
47-
n_classes: int = 1,
48+
num_classes: int = 1,
4849
dim: int = 2,
4950
in_channels: Optional[int] = None,
5051
use_conv: bool = False,
5152
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
5253
bias: bool = True,
5354
pretrained: bool = False,
55+
n_classes: Optional[int] = None,
5456
):
57+
# in case the new num_classes is default but you still call deprecated n_classes
58+
if n_classes is not None and num_classes == 1:
59+
num_classes = n_classes
5560
model = getattr(models, model_name)(pretrained=pretrained)
5661
# check if the model is compatible, should have a FC layer at the end
5762
if not str(list(model.children())[-1]).startswith("Linear"):
5863
raise ValueError(f"Model ['{model_name}'] does not have a Linear layer at the end.")
5964

6065
super().__init__(
6166
model=model,
62-
n_classes=n_classes,
67+
num_classes=num_classes,
6368
dim=dim,
6469
in_channels=in_channels,
6570
use_conv=use_conv,
@@ -77,7 +82,7 @@ class TorchVisionFullyConvModel(TorchVisionFCModel):
7782
model_name: name of any torchvision with adaptive avg pooling and fully connected layer at the end.
7883
``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``,
7984
``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``.
80-
n_classes: number of classes for the last classification layer. Default to 1.
85+
num_classes: number of classes for the last classification layer. Default to 1.
8186
pool_size: the kernel size for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to (7, 7).
8287
pool_stride: the stride for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to 1.
8388
pretrained: whether to use the imagenet pretrained weights. Default to False.
@@ -87,17 +92,22 @@ class TorchVisionFullyConvModel(TorchVisionFCModel):
8792
8893
"""
8994

95+
@deprecated_arg("n_classes", since="0.6")
9096
def __init__(
9197
self,
9298
model_name: str = "resnet18",
93-
n_classes: int = 1,
99+
num_classes: int = 1,
94100
pool_size: Union[int, Tuple[int, int]] = (7, 7),
95101
pool_stride: Union[int, Tuple[int, int]] = 1,
96102
pretrained: bool = False,
103+
n_classes: Optional[int] = None,
97104
):
105+
# in case the new num_classes is default but you still call deprecated n_classes
106+
if n_classes is not None and num_classes == 1:
107+
num_classes = n_classes
98108
super().__init__(
99109
model_name=model_name,
100-
n_classes=n_classes,
110+
num_classes=num_classes,
101111
use_conv=True,
102112
pool=("avg", {"kernel_size": pool_size, "stride": pool_stride}),
103113
pretrained=pretrained,

monai/transforms/post/array.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from monai.networks.layers import GaussianFilter
2626
from monai.transforms.transform import Transform
2727
from monai.transforms.utils import fill_holes, get_largest_connected_component_mask
28-
from monai.utils import ensure_tuple, look_up_option
28+
from monai.utils import deprecated_arg, ensure_tuple, look_up_option
2929

3030
__all__ = [
3131
"Activations",
@@ -120,7 +120,7 @@ class AsDiscrete(Transform):
120120
Defaults to ``False``.
121121
to_onehot: whether to convert input data into the one-hot format.
122122
Defaults to ``False``.
123-
n_classes: the number of classes to convert to One-Hot format.
123+
num_classes: the number of classes to convert to One-Hot format.
124124
Defaults to ``None``.
125125
threshold_values: whether threshold the float value to int number 0 or 1.
126126
Defaults to ``False``.
@@ -131,31 +131,38 @@ class AsDiscrete(Transform):
131131
132132
"""
133133

134+
@deprecated_arg("n_classes", since="0.6")
134135
def __init__(
135136
self,
136137
argmax: bool = False,
137138
to_onehot: bool = False,
138-
n_classes: Optional[int] = None,
139+
num_classes: Optional[int] = None,
139140
threshold_values: bool = False,
140141
logit_thresh: float = 0.5,
141142
rounding: Optional[str] = None,
143+
n_classes: Optional[int] = None,
142144
) -> None:
145+
# in case the new num_classes is default but you still call deprecated n_classes
146+
if n_classes is not None and num_classes is None:
147+
num_classes = n_classes
143148
self.argmax = argmax
144149
self.to_onehot = to_onehot
145-
self.n_classes = n_classes
150+
self.num_classes = num_classes
146151
self.threshold_values = threshold_values
147152
self.logit_thresh = logit_thresh
148153
self.rounding = rounding
149154

155+
@deprecated_arg("n_classes", since="0.6")
150156
def __call__(
151157
self,
152158
img: torch.Tensor,
153159
argmax: Optional[bool] = None,
154160
to_onehot: Optional[bool] = None,
155-
n_classes: Optional[int] = None,
161+
num_classes: Optional[int] = None,
156162
threshold_values: Optional[bool] = None,
157163
logit_thresh: Optional[float] = None,
158164
rounding: Optional[str] = None,
165+
n_classes: Optional[int] = None,
159166
) -> torch.Tensor:
160167
"""
161168
Args:
@@ -165,8 +172,8 @@ def __call__(
165172
Defaults to ``self.argmax``.
166173
to_onehot: whether to convert input data into the one-hot format.
167174
Defaults to ``self.to_onehot``.
168-
n_classes: the number of classes to convert to One-Hot format.
169-
Defaults to ``self.n_classes``.
175+
num_classes: the number of classes to convert to One-Hot format.
176+
Defaults to ``self.num_classes``.
170177
threshold_values: whether threshold the float value to int number 0 or 1.
171178
Defaults to ``self.threshold_values``.
172179
logit_thresh: the threshold value for thresholding operation..
@@ -175,13 +182,16 @@ def __call__(
175182
available options: ["torchrounding"].
176183
177184
"""
185+
# in case the new num_classes is default but you still call deprecated n_classes
186+
if n_classes is not None and num_classes is None:
187+
num_classes = n_classes
178188
if argmax or self.argmax:
179189
img = torch.argmax(img, dim=0, keepdim=True)
180190

181191
if to_onehot or self.to_onehot:
182-
_nclasses = self.n_classes if n_classes is None else n_classes
192+
_nclasses = self.num_classes if num_classes is None else num_classes
183193
if not isinstance(_nclasses, int):
184-
raise AssertionError("One of self.n_classes or n_classes must be an integer")
194+
raise AssertionError("One of self.num_classes or num_classes must be an integer")
185195
img = one_hot(img, num_classes=_nclasses, dim=0)
186196

187197
if threshold_values or self.threshold_values:

0 commit comments

Comments
 (0)