Skip to content

Commit a46b2b1

Browse files
carmocca authored and lantiga committed
Skip length checks for non-sized iterables (#17218)
1 parent 9a97fcc commit a46b2b1

File tree

1 file changed

+6
-13
lines changed
  • src/lightning/pytorch/utilities

1 file changed

+6
-13
lines changed

src/lightning/pytorch/utilities/data.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import torch
1919
from lightning_utilities.core.apply_func import is_dataclass_instance
20-
from lightning_utilities.core.rank_zero import rank_prefixed_message
2120
from torch import Tensor
2221
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
2322
from typing_extensions import TypeGuard
@@ -93,13 +92,12 @@ def has_len_all_ranks(
9392
strategy: "pl.strategies.Strategy",
9493
allow_zero_length_dataloader_with_multiple_devices: bool = False,
9594
) -> TypeGuard[Sized]:
96-
"""Checks if a given object has ``__len__`` method implemented on all aranks."""
95+
"""Checks if a given object has ``__len__`` method implemented on all ranks."""
9796
local_length = sized_len(dataloader)
98-
has_len = True
9997
if local_length is None:
100-
# if one rank does not define a length, the reduction after would fail, default to 0
101-
local_length = 0
102-
has_len = False
98+
# __len__ is not defined, skip these checks
99+
return False
100+
103101
total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
104102
if total_length == 0:
105103
rank_zero_warn(
@@ -108,10 +106,6 @@ def has_len_all_ranks(
108106
)
109107
if total_length > 0 and local_length == 0:
110108
dataloader_cls_name = type(dataloader).__name__
111-
if not has_len:
112-
raise RuntimeError(
113-
rank_prefixed_message(f"The `{dataloader_cls_name}` does not define a length.", strategy.global_rank)
114-
)
115109
if not allow_zero_length_dataloader_with_multiple_devices:
116110
raise RuntimeError(
117111
f"`{dataloader_cls_name}` within local rank has zero length."
@@ -121,16 +115,15 @@ def has_len_all_ranks(
121115
f"Total length of `{dataloader_cls_name}` across ranks is zero, but local rank has zero"
122116
" length. Please be cautious of uneven batch length."
123117
)
124-
has_len = False
125118

126-
if has_len and has_iterable_dataset(dataloader):
119+
if has_iterable_dataset(dataloader):
127120
rank_zero_warn(
128121
"Your `IterableDataset` has `__len__` defined."
129122
" In combination with multi-process data loading (when num_workers > 1),"
130123
" `__len__` could be inaccurate if each worker is not configured independently"
131124
" to avoid having duplicate data."
132125
)
133-
return has_len
126+
return True
134127

135128

136129
def _update_dataloader(

0 commit comments

Comments (0)