
Commit e89f46a

Add @override for files in src/lightning/pytorch/utilities (#19315)
1 parent 75510dd commit e89f46a
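
For background on what the decorator adds: typing_extensions.override (PEP 698) marks a method as intentionally overriding a base-class method, so a static type checker such as pyright or mypy can report the subclass method when the base method is renamed, removed, or never existed. A minimal sketch of the idea, using hypothetical _Base/_Child classes that are not part of this commit:

from typing_extensions import override


class _Base:
    def reset(self) -> None:
        ...


class _Child(_Base):
    @override
    def reset(self) -> None:  # fine: _Base.reset exists
        ...

    @override
    def restart(self) -> None:  # flagged by a type checker: no _Base.restart to override
        ...

At runtime nothing changes; the benefit is that a later rename of _Base.reset turns every stale subclass override into a type-checking error instead of silently dead code.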

File tree

src/lightning/pytorch/utilities/combined_loader.py
src/lightning/pytorch/utilities/migration/utils.py
src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py

3 files changed: +25 -1 lines changed

src/lightning/pytorch/utilities/combined_loader.py

Lines changed: 16 additions & 1 deletion
@@ -16,7 +16,7 @@
 from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union
 
 from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter
-from typing_extensions import Self, TypedDict
+from typing_extensions import Self, TypedDict, override
 
 from lightning.fabric.utilities.data import sized_len
 from lightning.pytorch.utilities._pytree import _map_and_unflatten, _tree_flatten, tree_unflatten
@@ -33,9 +33,11 @@ def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, f
         self._idx = 0  # what would be batch_idx
         self.limits = limits
 
+    @override
     def __next__(self) -> _ITERATOR_RETURN:
         raise NotImplementedError
 
+    @override
     def __iter__(self) -> Self:
         self.iterators = [iter(iterable) for iterable in self.iterables]
         self._idx = 0
@@ -66,6 +68,7 @@ def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, f
         super().__init__(iterables, limits)
         self._consumed: List[bool] = []
 
+    @override
     def __next__(self) -> _ITERATOR_RETURN:
         n = len(self.iterators)
         out = [None] * n  # values per iterator
@@ -83,29 +86,34 @@ def __next__(self) -> _ITERATOR_RETURN:
         self._idx += 1
         return out, index, 0
 
+    @override
     def __iter__(self) -> Self:
         super().__iter__()
         self._consumed = [False] * len(self.iterables)
         return self
 
+    @override
     def __len__(self) -> int:
         lengths = _get_iterables_lengths(self.iterables)
         if self.limits is not None:
             return max(min(length, limit) for length, limit in zip(lengths, self.limits))  # type: ignore[return-value]
         return max(lengths)  # type: ignore[return-value]
 
+    @override
     def reset(self) -> None:
         super().reset()
         self._consumed = []
 
 
 class _MinSize(_ModeIterator):
+    @override
     def __next__(self) -> _ITERATOR_RETURN:
         out = [next(it) for it in self.iterators]
         index = self._idx
         self._idx += 1
         return out, index, 0
 
+    @override
     def __len__(self) -> int:
         lengths = _get_iterables_lengths(self.iterables)
         return min(lengths + self.limits) if self.limits is not None else min(lengths)  # type: ignore[return-value]
@@ -116,6 +124,7 @@ def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, f
         super().__init__(iterables, limits)
         self._iterator_idx = 0  # what would be dataloader_idx
 
+    @override
     def __next__(self) -> _ITERATOR_RETURN:
         n = len(self.iterables)
         if n == 0 or self._iterator_idx >= n:
@@ -138,18 +147,21 @@ def __next__(self) -> _ITERATOR_RETURN:
         self._idx += 1
         return out, index, self._iterator_idx
 
+    @override
     def __iter__(self) -> Self:
         self._iterator_idx = 0
         self._idx = 0
         self._load_current_iterator()
         return self
 
+    @override
     def __len__(self) -> int:
         lengths = _get_iterables_lengths(self.iterables)
         if self.limits is not None:
             return sum(min(length, limit) for length, limit in zip(lengths, self.limits))  # type: ignore[misc]
         return sum(lengths)  # type: ignore[arg-type]
 
+    @override
     def reset(self) -> None:
         super().reset()
         self._iterator_idx = 0
@@ -169,6 +181,7 @@ def _use_next_iterator(self) -> None:
 
 
 class _MaxSize(_ModeIterator):
+    @override
     def __next__(self) -> _ITERATOR_RETURN:
         n = len(self.iterators)
         out = [None] * n
@@ -183,6 +196,7 @@ def __next__(self) -> _ITERATOR_RETURN:
         self._idx += 1
         return out, index, 0
 
+    @override
     def __len__(self) -> int:
         lengths = _get_iterables_lengths(self.iterables)
         if self.limits is not None:
@@ -329,6 +343,7 @@ def __next__(self) -> _ITERATOR_RETURN:
         out, batch_idx, dataloader_idx = out
         return tree_unflatten(out, self._spec), batch_idx, dataloader_idx
 
+    @override
     def __iter__(self) -> Self:
         cls = _SUPPORTED_MODES[self._mode]["iterator"]
         iterator = cls(self.flattened, self._limits)
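
In combined_loader.py the decorator is mostly applied to the iterator protocol methods (__next__, __iter__, __len__) and reset, which each _ModeIterator subclass redefines. A self-contained sketch of that pattern, with hypothetical _BaseIter/_MinIter classes standing in for the real ones; the __override__ check at the end reflects how typing_extensions.override behaves at runtime (it is otherwise a pass-through):

from typing import Iterator, List

from typing_extensions import override


class _BaseIter:
    """Hypothetical stand-in for the _ModeIterator base class."""

    def __next__(self) -> int:
        raise NotImplementedError

    def __iter__(self) -> Iterator[int]:
        return self


class _MinIter(_BaseIter):
    def __init__(self, values: List[int]) -> None:
        self._values = iter(values)

    @override
    def __next__(self) -> int:  # overrides _BaseIter.__next__, which only raises
        return next(self._values)


# The decorator also tries to set an `__override__` attribute on the function,
# so the marker stays introspectable at runtime (where the attribute can be set).
print(list(_MinIter([1, 2, 3])))                           # [1, 2, 3]
print(getattr(_MinIter.__next__, "__override__", False))   # True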

src/lightning/pytorch/utilities/migration/utils.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 from packaging.version import Version
+from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.enums import LightningEnum
@@ -188,6 +189,7 @@ class _RedirectingUnpickler(pickle._Unpickler):
 
     """
 
+    @override
     def find_class(self, module: str, name: str) -> Any:
         new_module = _patch_pl_to_mirror_if_necessary(module)
         # this warning won't trigger for standalone as these imports are identical

src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,7 @@
 import torch
 from lightning_utilities.core.imports import RequirementCache
 from torch.nn import Parameter
+from typing_extensions import override
 
 from lightning.pytorch.utilities.model_summary.model_summary import (
     NOT_APPLICABLE,
@@ -36,6 +37,7 @@ def deepspeed_param_size(p: torch.nn.Parameter) -> int:
 
 class DeepSpeedLayerSummary(LayerSummary):
     @property
+    @override
     def num_parameters(self) -> int:
         """Returns the number of parameters in this module."""
         return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
@@ -51,6 +53,7 @@ def partitioned_size(p: Parameter) -> int:
 
 
 class DeepSpeedSummary(ModelSummary):
+    @override
     def summarize(self) -> Dict[str, DeepSpeedLayerSummary]:  # type: ignore[override]
         summary = OrderedDict((name, DeepSpeedLayerSummary(module)) for name, module in self.named_modules)
         if self._model.example_input_array is not None:
@@ -66,10 +69,12 @@ def summarize(self) -> Dict[str, DeepSpeedLayerSummary]:  # type: ignore[overrid
         return summary
 
     @property
+    @override
     def total_parameters(self) -> int:
         return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
 
     @property
+    @override
     def trainable_parameters(self) -> int:
         return sum(
             deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0
@@ -81,6 +86,7 @@ def trainable_parameters(self) -> int:
     def parameters_per_layer(self) -> List[int]:
         return [layer.average_shard_parameters for layer in self._layer_summary.values()]
 
+    @override
     def _get_summary_data(self) -> List[Tuple[str, List[str]]]:
         """Makes a summary listing with:
 
@@ -104,6 +110,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]:
 
         return arrays
 
+    @override
     def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None:
         """Add summary of params not associated with module or layer to model summary."""
         super()._add_leftover_params_to_summary(arrays, total_leftover_params)
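
A small detail from the model-summary diff: for the overridden properties (num_parameters, total_parameters, trainable_parameters) the commit places @override beneath @property, so the marker is applied to the raw getter before property() wraps it. A minimal sketch of that ordering, with hypothetical _Summary/_DeepSpeedSummary classes rather than the real Lightning ones:

from typing_extensions import override


class _Summary:
    """Hypothetical stand-in for ModelSummary."""

    @property
    def total_parameters(self) -> int:
        return 0


class _DeepSpeedSummary(_Summary):
    # Same decorator order as in this commit: @property outermost,
    # @override applied directly to the getter function.
    @property
    @override
    def total_parameters(self) -> int:
        return 42


print(_DeepSpeedSummary().total_parameters)  # 42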
