Skip to content

Commit c3c3653

Browse files
cyyever authored and pytorchmergebot committed
[1/N] Add return types of Python functions (pytorch#167162)
This PR adds return types of some Python functions. Most of them return `None`. The types were added automatically by ruff `ANN` rules.

Pull Request resolved: pytorch#167162
Approved by: https://github.com/Lucaskabela
1 parent f72772b commit c3c3653

35 files changed

+134 -128 lines changed

torch/nn/attention/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _cur_sdpa_kernel_backends(with_priority: bool = False):
9090
return backends
9191

9292

93-
def _sdpa_kernel(backends: Iterable, set_priority: bool = False):
93+
def _sdpa_kernel(backends: Iterable, set_priority: bool = False) -> None:
9494
for name, val in _backend_names.items():
9595
enabled = getattr(SDPBackend, val) in backends
9696
getattr(torch._C, f"_set_sdp_use_{name}")(enabled)

torch/nn/attention/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _validate_sdpa_input(
4040
dropout_p=0.0,
4141
is_causal=False,
4242
scale=None,
43-
):
43+
) -> None:
4444
if query.dtype != key.dtype or query.dtype != value.dtype:
4545
raise ValueError(
4646
f"Expected query, key, and value to have the same dtype, "

torch/nn/attention/bias.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class CausalBias(torch.Tensor):
117117
.. warning:: This class is a prototype and subject to change.
118118
"""
119119

120-
def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
120+
def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int) -> None:
121121
"""
122122
Initializes the CausalBias instance with a specified variant and sequence lengths.
123123
@@ -296,7 +296,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
296296
return cls._dispatch(*args, **kwargs)
297297
return super().__torch_function__(func, types, args, kwargs)
298298

299-
def __repr__(self): # type:ignore[override]
299+
def __repr__(self) -> str: # type:ignore[override]
300300
return self._materialize().__repr__()
301301

302302

torch/nn/attention/experimental/_paged_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
page_size: int,
4141
max_batch_size: int,
4242
device: str = "cuda",
43-
):
43+
) -> None:
4444
# number of pages
4545
self.n_pages = n_pages
4646

torch/nn/attention/flex_attention.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def __init__(
550550
full_q_indices: Optional[Tensor],
551551
BLOCK_SIZE: tuple[int, int],
552552
mask_mod: _mask_mod_signature,
553-
):
553+
) -> None:
554554
if kv_indices.dim() < 2:
555555
raise RuntimeError("BlockMask must have at least 2 dimensions")
556556
assert kv_num_blocks is not None, "kv_num_blocks must be provided"
@@ -682,7 +682,7 @@ def shape(self):
682682
*batch_dims, _, _ = self.kv_indices.shape
683683
return tuple(batch_dims) + self.seq_lengths
684684

685-
def __str__(self):
685+
def __str__(self) -> str:
686686
s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n"
687687
mask_str = self.to_string().strip()
688688
s += mask_str
@@ -760,7 +760,7 @@ def causal_mask(b, h, q_idx, kv_idx):
760760
compute_q_blocks=self.q_indices is not None,
761761
)
762762

763-
def __repr__(self):
763+
def __repr__(self) -> str:
764764
def shape_or_none(x: Optional[torch.Tensor]):
765765
return x.shape if x is not None else None
766766

@@ -864,7 +864,7 @@ def create_block_vis(*batch_idx):
864864

865865
vis = ", ".join(reversed(descriptors)) + "\n"
866866

867-
def summarize_section(section):
867+
def summarize_section(section) -> str:
868868
percentage = section.float().mean().item()
869869
if percentage == 1:
870870
return "█"
@@ -1289,15 +1289,15 @@ def _apply_kernel_options(
12891289
return kernel_options
12901290

12911291

1292-
def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor):
1292+
def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor) -> None:
12931293
if query.size(-1) != key.size(-1):
12941294
raise ValueError(
12951295
f"Expect query and key/value to have the same embedding dimension "
12961296
f"but got E={query.size(-1)} and E={key.size(-1)}."
12971297
)
12981298

12991299

1300-
def _validate_device(query: Tensor, key: Tensor, value: Tensor):
1300+
def _validate_device(query: Tensor, key: Tensor, value: Tensor) -> None:
13011301
"""TODO: Remove once non cuda/cpu devices support is added
13021302
We only need to check query since we have already that q,k,v are on the same device
13031303
"""

torch/nn/backends/thnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
# this is for historical pickle deserialization, it is not used otherwise
33

44

5-
def _get_thnn_function_backend():
5+
def _get_thnn_function_backend() -> None:
66
pass

torch/nn/cpp.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class OrderedDictWrapper:
1414
so using properties does not work.
1515
"""
1616

17-
def __init__(self, cpp_module, attr):
17+
def __init__(self, cpp_module, attr) -> None:
1818
self.cpp_module = cpp_module
1919
self.attr = attr
2020

@@ -37,10 +37,10 @@ def values(self):
3737
def __iter__(self):
3838
return self.cpp_dict.__iter__()
3939

40-
def __len__(self):
40+
def __len__(self) -> int:
4141
return self.cpp_dict.__len__()
4242

43-
def __contains__(self, key):
43+
def __contains__(self, key) -> bool:
4444
return self.cpp_dict.__contains__(key)
4545

4646
def __getitem__(self, key):
@@ -50,7 +50,7 @@ def __getitem__(self, key):
5050
class ModuleWrapper(nn.Module):
5151
"""A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access."""
5252

53-
def __init__(self, cpp_module):
53+
def __init__(self, cpp_module) -> None:
5454
# Assign before the super class constructor so ``self.training`` can be
5555
# assigned to in the super class constructor.
5656
self.cpp_module = cpp_module
@@ -83,8 +83,8 @@ def training(self):
8383
return self.cpp_module.training
8484

8585
@training.setter
86-
def training(self, mode):
86+
def training(self, mode) -> None:
8787
self.cpp_module.train(mode)
8888

89-
def __repr__(self):
89+
def __repr__(self) -> str:
9090
return self.cpp_module.__repr__()

torch/nn/modules/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3040,7 +3040,7 @@ def _replicate_for_data_parallel(self):
30403040

30413041
return replica
30423042

3043-
def compile(self, *args, **kwargs):
3043+
def compile(self, *args, **kwargs) -> None:
30443044
"""
30453045
Compile this Module's forward using :func:`torch.compile`.
30463046

torch/nn/parallel/data_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
3030
device_ids = [_get_device_index(x, True) for x in device_ids]
3131
dev_props = _get_devices_properties(device_ids)
3232

33-
def warn_imbalance(get_prop):
33+
def warn_imbalance(get_prop) -> bool:
3434
values = [get_prop(props) for props in dev_props]
3535
min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
3636
max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))

torch/nn/parameter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# Metaclass to combine _TensorMeta and the instance check override for Parameter.
1919
class _ParameterMeta(torch._C._TensorMeta):
2020
# Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag.
21-
def __instancecheck__(self, instance):
21+
def __instancecheck__(self, instance) -> bool:
2222
if self is Parameter:
2323
if isinstance(instance, torch.Tensor) and getattr(
2424
instance, "_is_param", False
@@ -82,7 +82,7 @@ def __deepcopy__(self, memo):
8282
return result
8383

8484
# pyrefly: ignore [bad-override]
85-
def __repr__(self):
85+
def __repr__(self) -> str:
8686
return "Parameter containing:\n" + super().__repr__()
8787

8888
def __reduce_ex__(self, proto):
@@ -125,7 +125,7 @@ class UninitializedTensorMixin:
125125
torch._has_compatible_shallow_copy_type,
126126
]
127127

128-
def materialize(self, shape, device=None, dtype=None):
128+
def materialize(self, shape, device=None, dtype=None) -> None:
129129
r"""Create a Parameter or Tensor with the same properties of the uninitialized one.
130130
131131
Given a shape, it materializes a parameter in the same device
@@ -163,7 +163,7 @@ def share_memory_(self):
163163
"`module.share_memory()`."
164164
)
165165

166-
def __repr__(self):
166+
def __repr__(self) -> str:
167167
return f"<{self.__class__.__name__}>"
168168

169169
def __reduce_ex__(self, proto):
@@ -235,7 +235,7 @@ def __deepcopy__(self, memo):
235235
# Metaclass to combine _TensorMeta and the instance check override for Buffer.
236236
class _BufferMeta(torch._C._TensorMeta):
237237
# Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag.
238-
def __instancecheck__(self, instance):
238+
def __instancecheck__(self, instance) -> bool:
239239
if self is Buffer:
240240
if isinstance(instance, torch.Tensor) and getattr(
241241
instance, "_is_buffer", False

0 commit comments

Comments
 (0)