Commit 4004f85

Add @override for files in src/lightning/pytorch/overrides (#19316)
Parent: 97d71ab

File tree

1 file changed (+3, -1)


src/lightning/pytorch/overrides/distributed.py

Lines changed: 3 additions & 1 deletion
@@ -18,7 +18,7 @@
 from torch import Tensor
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils.data import DistributedSampler, Sampler
-from typing_extensions import Self
+from typing_extensions import Self, override
 
 from lightning.fabric.utilities.distributed import _DatasetSamplerWrapper
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info
@@ -200,6 +200,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         # have at least one batch, or the DistributedDataParallel could lock up.
         assert self.num_samples >= 1 or self.total_size == 0
 
+    @override
     def __iter__(self) -> Iterator[List[int]]:
         if not isinstance(self.dataset, Sized):
             raise TypeError("The given dataset must implement the `__len__` method.")
@@ -226,6 +227,7 @@ class UnrepeatedDistributedSamplerWrapper(UnrepeatedDistributedSampler):
     def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None:
         super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs)
 
+    @override
     def __iter__(self) -> Iterator:
         self.dataset.reset()
         return (self.dataset[index] for index in super().__iter__())
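
For context on the change: `typing_extensions.override` (PEP 698) is a no-op at runtime, but it lets static type checkers verify that a decorated method really does override a base-class method, so an upstream rename or signature change is caught at check time instead of failing silently. Below is a minimal sketch of the pattern this commit applies; the class names are illustrative stand-ins, not the actual Lightning samplers.

from typing import Iterator, List

from typing_extensions import override


class BaseSampler:
    def __iter__(self) -> Iterator[List[int]]:
        # Stand-in for the base sampler's real iteration logic.
        yield [0, 1, 2]


class WrappedSampler(BaseSampler):
    @override  # a type checker errors here if BaseSampler stops defining __iter__
    def __iter__(self) -> Iterator[List[int]]:
        # Delegate to the parent, as the wrappers in this diff do via super().
        return super().__iter__()

If the base method were later renamed, checkers such as pyright (or recent mypy) would flag the decorated override, which is the regression safety this commit adds to the two sampler wrappers.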
