Skip to content

Commit ed7903c

Browse files
authored
make DataLoader warning less noisy. test=develop (#34001)
1 parent 8417ad6 commit ed7903c

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

python/paddle/fluid/dataloader/fetcher.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,9 @@
1414

1515
import logging
1616
from ..log_helper import get_logger
17+
from collections.abc import Sequence, Mapping
1718

18-
from collections.abc import Sequence
19+
_WARNING_TO_LOG = True
1920

2021

2122
class _DatasetFetcher(object):
@@ -24,13 +25,17 @@ def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
2425
self.auto_collate_batch = auto_collate_batch
2526
self.collate_fn = collate_fn
2627
self.drop_last = drop_last
27-
self._is_warning_logged = False
2828

2929
def fetch(self, batch_indices):
3030
raise NotImplementedError("'fetch' not implement for class {}".format(
3131
self.__class__.__name__))
3232

3333
def _log_warning(self):
34+
# only log the warning on rank 0 (GPU 0) in a distributed launch
35+
from ...distributed import get_world_size, get_rank
36+
if get_world_size() >= 2 and get_rank() != 0:
37+
return
38+
3439
warn_str = "Detect dataset only contains single fileds, return format " \
3540
"changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \
3641
"a list surround output data(e.g. return [data]), and in " \
@@ -77,10 +82,12 @@ def fetch(self, batch_indices):
7782
if len(data) == 0 or (self.drop_last and
7883
len(data) < len(batch_indices)):
7984
raise StopIteration
80-
if not isinstance(data[0],
81-
Sequence) and not self._is_warning_logged:
85+
86+
global _WARNING_TO_LOG
87+
if not isinstance(data[0], (Sequence, Mapping)) \
88+
and _WARNING_TO_LOG:
8289
self._log_warning()
83-
self._is_warning_logged = True
90+
_WARNING_TO_LOG = False
8491
else:
8592
data = next(self.dataset_iter)
8693

@@ -98,10 +105,11 @@ def fetch(self, batch_indices):
98105
if self.auto_collate_batch:
99106
data = [self.dataset[idx] for idx in batch_indices]
100107

101-
if not isinstance(data[0],
102-
Sequence) and not self._is_warning_logged:
108+
global _WARNING_TO_LOG
109+
if not isinstance(data[0], (Sequence, Mapping)) \
110+
and _WARNING_TO_LOG:
103111
self._log_warning()
104-
self._is_warning_logged = True
112+
_WARNING_TO_LOG = False
105113
else:
106114
data = self.dataset[batch_indices]
107115

0 commit comments

Comments
 (0)