Commit 163f0d8

Skylion007 authored and pytorchmergebot committed
[BE][Ez]: Auto add return type annotations for methods in torch/nn/module (pytorch#157925)
Automatically type a bunch of methods in nn.Module using ruff's type inference rules.

Pull Request resolved: pytorch#157925
Approved by: https://github.com/albanD
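Why the sweep matters (context added here, not part of the original commit message): under mypy's default settings, a function with no annotations at all is skipped entirely, so adding `-> None` to `__init__` and friends is what opts those method bodies into type checking. ruff can insert such annotations automatically; for instance, its fixable ANN204 rule (missing-return-type-special-method) adds `-> None` to special methods, though the exact rule selection used for this PR is an assumption here. A minimal sketch of the effect:

    # Sketch of mypy's default behavior; illustrative, not code from this commit.
    class Unchecked:
        def __init__(self):  # fully unannotated: mypy skips this body
            self.x: int = "oops"  # the str/int mismatch goes unreported

    class Checked:
        def __init__(self) -> None:  # "-> None" opts the body into checking
            self.x: int = "oops"  # mypy: Incompatible types in assignment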
1 parent f742b32 commit 163f0d8

File tree

10 files changed, +62 -62 lines changed


torch/nn/modules/activation.py

Lines changed: 11 additions & 11 deletions
@@ -91,7 +91,7 @@ def __init__(self, threshold: float, value: float, inplace: bool = False) -> None
     def forward(self, input: Tensor) -> Tensor:
         return F.threshold(input, self.threshold, self.value, self.inplace)

-    def extra_repr(self):
+    def extra_repr(self) -> str:
         inplace_str = ", inplace=True" if self.inplace else ""
         return f"threshold={self.threshold}, value={self.value}{inplace_str}"

@@ -127,7 +127,7 @@ class ReLU(Module):
     __constants__ = ["inplace"]
     inplace: bool

-    def __init__(self, inplace: bool = False):
+    def __init__(self, inplace: bool = False) -> None:
         super().__init__()
         self.inplace = inplace

@@ -185,7 +185,7 @@ class RReLU(Module):

     def __init__(
         self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False
-    ):
+    ) -> None:
         super().__init__()
         self.lower = lower
         self.upper = upper
@@ -194,7 +194,7 @@ def __init__(
     def forward(self, input: Tensor) -> Tensor:
         return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

-    def extra_repr(self):
+    def extra_repr(self) -> str:
         inplace_str = ", inplace=True" if self.inplace else ""
         return f"lower={self.lower}, upper={self.upper}{inplace_str}"

@@ -297,7 +297,7 @@ class ReLU6(Hardtanh):
         >>> output = m(input)
     """

-    def __init__(self, inplace: bool = False):
+    def __init__(self, inplace: bool = False) -> None:
         super().__init__(0.0, 6.0, inplace)

     def extra_repr(self) -> str:
@@ -426,7 +426,7 @@ class SiLU(Module):
     __constants__ = ["inplace"]
     inplace: bool

-    def __init__(self, inplace: bool = False):
+    def __init__(self, inplace: bool = False) -> None:
         super().__init__()
         self.inplace = inplace

@@ -465,7 +465,7 @@ class Mish(Module):
     __constants__ = ["inplace"]
     inplace: bool

-    def __init__(self, inplace: bool = False):
+    def __init__(self, inplace: bool = False) -> None:
         super().__init__()
         self.inplace = inplace

@@ -1118,7 +1118,7 @@ def __init__(

         self._reset_parameters()

-    def _reset_parameters(self):
+    def _reset_parameters(self) -> None:
         if self._qkv_same_embed_dim:
             xavier_uniform_(self.in_proj_weight)
         else:
@@ -1517,7 +1517,7 @@ def __init__(
         self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
         self.reset_parameters()

-    def reset_parameters(self):
+    def reset_parameters(self) -> None:
         torch.nn.init.constant_(self.weight, self.init)

     def forward(self, input: Tensor) -> Tensor:
@@ -1619,7 +1619,7 @@ def __setstate__(self, state):
     def forward(self, input: Tensor) -> Tensor:
         return F.softmin(input, self.dim, _stacklevel=5)

-    def extra_repr(self):
+    def extra_repr(self) -> str:
         return f"dim={self.dim}"

@@ -1754,5 +1754,5 @@ def __setstate__(self, state):
     def forward(self, input: Tensor) -> Tensor:
         return F.log_softmax(input, self.dim, _stacklevel=5)

-    def extra_repr(self):
+    def extra_repr(self) -> str:
         return f"dim={self.dim}"

torch/nn/modules/batchnorm.py

Lines changed: 9 additions & 9 deletions
@@ -114,7 +114,7 @@ def _load_from_state_dict(
         missing_keys,
         unexpected_keys,
         error_msgs,
-    ):
+    ) -> None:
         version = local_metadata.get("version", None)

         if (version is None or version < 2) and self.track_running_stats:
@@ -336,7 +336,7 @@ class BatchNorm1d(_BatchNorm):
         >>> output = m(input)
     """

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() != 2 and input.dim() != 3:
             raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")

@@ -370,7 +370,7 @@ class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):

     cls_to_become = BatchNorm1d  # type: ignore[assignment]

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() != 2 and input.dim() != 3:
             raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")

@@ -447,7 +447,7 @@ class BatchNorm2d(_BatchNorm):
         >>> output = m(input)
     """

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() != 4:
             raise ValueError(f"expected 4D input (got {input.dim()}D input)")

@@ -481,7 +481,7 @@ class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):

     cls_to_become = BatchNorm2d  # type: ignore[assignment]

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() != 4:
             raise ValueError(f"expected 4D input (got {input.dim()}D input)")

@@ -558,7 +558,7 @@ class BatchNorm3d(_BatchNorm):
         >>> output = m(input)
     """

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() != 5:
             raise ValueError(f"expected 5D input (got {input.dim()}D input)")

@@ -592,7 +592,7 @@ class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):

     cls_to_become = BatchNorm3d  # type: ignore[assignment]

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() != 5:
             raise ValueError(f"expected 5D input (got {input.dim()}D input)")

@@ -717,11 +717,11 @@ def __init__(
         )
         self.process_group = process_group

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() < 2:
             raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")

-    def _check_non_zero_input_channels(self, input):
+    def _check_non_zero_input_channels(self, input) -> None:
         if input.size(1) == 0:
             raise ValueError(
                 "SyncBatchNorm number of input channels should be non-zero"

torch/nn/modules/conv.py

Lines changed: 1 addition & 1 deletion
@@ -1374,7 +1374,7 @@ class _ConvTransposeMixin(_ConvTransposeNd):
         "Please consider using public APIs.",
         category=FutureWarning,
     )
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)


torch/nn/modules/flatten.py

Lines changed: 2 additions & 2 deletions
@@ -124,7 +124,7 @@ def __init__(
         self.dim = dim
         self.unflattened_size = unflattened_size

-    def _require_tuple_tuple(self, input):
+    def _require_tuple_tuple(self, input) -> None:
         if isinstance(input, tuple):
             for idx, elem in enumerate(input):
                 if not isinstance(elem, tuple):
@@ -138,7 +138,7 @@ def _require_tuple_tuple(self, input):
             + f"but found type {type(input).__name__}"
         )

-    def _require_tuple_int(self, input):
+    def _require_tuple_int(self, input) -> None:
         if isinstance(input, (tuple, list)):
             for idx, elem in enumerate(input):
                 if not isinstance(elem, int):

torch/nn/modules/instancenorm.py

Lines changed: 13 additions & 13 deletions
@@ -64,7 +64,7 @@ def _load_from_state_dict(
         missing_keys,
         unexpected_keys,
         error_msgs,
-    ):
+    ) -> None:
         version = local_metadata.get("version", None)
         # at version 1: removed running_mean and running_var when
         # track_running_stats=False (default)
@@ -193,10 +193,10 @@ class InstanceNorm1d(_InstanceNorm):
         >>> output = m(input)
     """

-    def _get_no_batch_dim(self):
+    def _get_no_batch_dim(self) -> int:
         return 2

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() not in (2, 3):
             raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")

@@ -230,10 +230,10 @@ class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm):

     cls_to_become = InstanceNorm1d  # type: ignore[assignment]

-    def _get_no_batch_dim(self):
+    def _get_no_batch_dim(self) -> int:
         return 2

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() not in (2, 3):
             raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")

@@ -309,10 +309,10 @@ class InstanceNorm2d(_InstanceNorm):
         >>> output = m(input)
     """

-    def _get_no_batch_dim(self):
+    def _get_no_batch_dim(self) -> int:
         return 3

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() not in (3, 4):
             raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")

@@ -347,10 +347,10 @@ class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm):

     cls_to_become = InstanceNorm2d  # type: ignore[assignment]

-    def _get_no_batch_dim(self):
+    def _get_no_batch_dim(self) -> int:
         return 3

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() not in (3, 4):
             raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")

@@ -425,10 +425,10 @@ class InstanceNorm3d(_InstanceNorm):
         >>> output = m(input)
     """

-    def _get_no_batch_dim(self):
+    def _get_no_batch_dim(self) -> int:
         return 4

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() not in (4, 5):
             raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")

@@ -463,9 +463,9 @@ class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm):

     cls_to_become = InstanceNorm3d  # type: ignore[assignment]

-    def _get_no_batch_dim(self):
+    def _get_no_batch_dim(self) -> int:
         return 4

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input) -> None:
         if input.dim() not in (4, 5):
             raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")

torch/nn/modules/loss.py

Lines changed: 3 additions & 3 deletions
@@ -1690,7 +1690,7 @@ def __init__(
         size_average=None,
         reduce=None,
         reduction: str = "mean",
-    ):
+    ) -> None:
         super().__init__(size_average, reduce, reduction)
         if margin <= 0:
             raise ValueError(
@@ -1824,7 +1824,7 @@ def __init__(
         margin: float = 1.0,
         swap: bool = False,
         reduction: str = "mean",
-    ):
+    ) -> None:
         super().__init__(size_average=None, reduce=None, reduction=reduction)
         if margin <= 0:
             raise ValueError(
@@ -2004,7 +2004,7 @@ class CTCLoss(_Loss):

     def __init__(
         self, blank: int = 0, reduction: str = "mean", zero_infinity: bool = False
-    ):
+    ) -> None:
         super().__init__(reduction=reduction)
         self.blank = blank
         self.zero_infinity = zero_infinity
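A usage note on the recurring `extra_repr` changes above (example mine, not part of the diff): `extra_repr` returns the string that `nn.Module.__repr__` splices between the parentheses of a module's printed form, so `-> str` matches its single call site:

    import torch.nn as nn

    m = nn.ReLU(inplace=True)
    print(m)  # prints "ReLU(inplace=True)"; the text inside the parens comes from extra_repr()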
