14
14
15
15
import logging
16
16
from ..log_helper import get_logger
17
+ from collections .abc import Sequence , Mapping
17
18
18
- from collections . abc import Sequence
19
+ _WARNING_TO_LOG = True
19
20
20
21
21
22
class _DatasetFetcher (object ):
@@ -24,13 +25,17 @@ def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
24
25
self .auto_collate_batch = auto_collate_batch
25
26
self .collate_fn = collate_fn
26
27
self .drop_last = drop_last
27
- self ._is_warning_logged = False
28
28
29
29
def fetch (self , batch_indices ):
30
30
raise NotImplementedError ("'fetch' not implement for class {}" .format (
31
31
self .__class__ .__name__ ))
32
32
33
33
def _log_warning (self ):
34
+ # only log warning on GPU 0 when distributed launch
35
+ from ...distributed import get_world_size , get_rank
36
+ if get_world_size () >= 2 and get_rank () != 0 :
37
+ return
38
+
34
39
warn_str = "Detect dataset only contains single fileds, return format " \
35
40
"changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \
36
41
"a list surround output data(e.g. return [data]), and in " \
@@ -77,10 +82,12 @@ def fetch(self, batch_indices):
77
82
if len (data ) == 0 or (self .drop_last and
78
83
len (data ) < len (batch_indices )):
79
84
raise StopIteration
80
- if not isinstance (data [0 ],
81
- Sequence ) and not self ._is_warning_logged :
85
+
86
+ global _WARNING_TO_LOG
87
+ if not isinstance (data [0 ], (Sequence , Mapping )) \
88
+ and _WARNING_TO_LOG :
82
89
self ._log_warning ()
83
- self . _is_warning_logged = True
90
+ _WARNING_TO_LOG = False
84
91
else :
85
92
data = next (self .dataset_iter )
86
93
@@ -98,10 +105,11 @@ def fetch(self, batch_indices):
98
105
if self .auto_collate_batch :
99
106
data = [self .dataset [idx ] for idx in batch_indices ]
100
107
101
- if not isinstance (data [0 ],
102
- Sequence ) and not self ._is_warning_logged :
108
+ global _WARNING_TO_LOG
109
+ if not isinstance (data [0 ], (Sequence , Mapping )) \
110
+ and _WARNING_TO_LOG :
103
111
self ._log_warning ()
104
- self . _is_warning_logged = True
112
+ _WARNING_TO_LOG = False
105
113
else :
106
114
data = self .dataset [batch_indices ]
107
115
0 commit comments