Skip to content

Commit bbedca4

Browse files
authored
[cherry pick] add warning for dataloader incompatable upgrade (#33514)
* add warning log for DataLoader output format imcompatible upgrade. test=develop * add unittest. test=develop * fix ci converage. test=develop * fix ci coverage. test=develop
1 parent 0079e0b commit bbedca4

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

python/paddle/fluid/dataloader/fetcher.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,51 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
16+
from ..log_helper import get_logger
17+
18+
from collections.abc import Sequence
19+
1520

1621
class _DatasetFetcher(object):
1722
def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
1823
self.dataset = dataset
1924
self.auto_collate_batch = auto_collate_batch
2025
self.collate_fn = collate_fn
2126
self.drop_last = drop_last
27+
self._is_warning_logged = False
2228

2329
def fetch(self, batch_indices):
2430
raise NotImplementedError("'fetch' not implement for class {}".format(
2531
self.__class__.__name__))
2632

33+
def _log_warning(self):
34+
warn_str = "Detect dataset only contains single fileds, return format " \
35+
"changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \
36+
"a list surround output data(e.g. return [data]), and in " \
37+
"Paddle >= 2.1, DataLoader return the single filed directly " \
38+
"(e.g. return data). For example, in following code: \n\n"
39+
warn_str += \
40+
"import numpy as np\n" \
41+
"from paddle.io import DataLoader, Dataset\n\n" \
42+
"class RandomDataset(Dataset):\n" \
43+
" def __getitem__(self, idx):\n" \
44+
" data = np.random.random((2, 3)).astype('float32')\n\n" \
45+
" return data\n\n" \
46+
" def __len__(self):\n" \
47+
" return 10\n\n" \
48+
"dataset = RandomDataset()\n" \
49+
"loader = DataLoader(dataset, batch_size=1)\n" \
50+
"data = next(loader())\n\n"
51+
52+
warn_str += "In Paddle <= 2.0, data is in format '[Tensor(shape=(1, 2, 3), " \
53+
"dtype=float32)]', and in Paddle >= 2.1, data is in format" \
54+
" 'Tensor(shape=(1, 2, 3), dtype=float32)'\n"
55+
56+
logger = get_logger(
57+
"DataLoader", logging.INFO, fmt='%(levelname)s: %(message)s')
58+
logger.warning(warn_str)
59+
2760

2861
class _IterableDatasetFetcher(_DatasetFetcher):
2962
def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
@@ -40,9 +73,14 @@ def fetch(self, batch_indices):
4073
data.append(next(self.dataset_iter))
4174
except StopIteration:
4275
break
76+
4377
if len(data) == 0 or (self.drop_last and
4478
len(data) < len(batch_indices)):
4579
raise StopIteration
80+
if not isinstance(data[0],
81+
Sequence) and not self._is_warning_logged:
82+
self._log_warning()
83+
self._is_warning_logged = True
4684
else:
4785
data = next(self.dataset_iter)
4886

@@ -59,6 +97,11 @@ def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
5997
def fetch(self, batch_indices):
6098
if self.auto_collate_batch:
6199
data = [self.dataset[idx] for idx in batch_indices]
100+
101+
if not isinstance(data[0],
102+
Sequence) and not self._is_warning_logged:
103+
self._log_warning()
104+
self._is_warning_logged = True
62105
else:
63106
data = self.dataset[batch_indices]
64107

python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,59 @@ def test_main(self):
330330
self.run_main(num_workers)
331331

332332

333+
class SingleFieldDataset(Dataset):
334+
def __init__(self, sample_num):
335+
self.sample_num = sample_num
336+
337+
def __len__(self):
338+
return self.sample_num
339+
340+
def __getitem__(self, idx):
341+
return np.random.random((2, 3)).astype('float32')
342+
343+
344+
class TestSingleFieldDataset(unittest.TestCase):
345+
def init_dataset(self):
346+
self.sample_num = 16
347+
self.dataset = SingleFieldDataset(self.sample_num)
348+
349+
def run_main(self, num_workers):
350+
paddle.static.default_startup_program().random_seed = 1
351+
paddle.static.default_main_program().random_seed = 1
352+
place = paddle.CPUPlace()
353+
with fluid.dygraph.guard(place):
354+
self.init_dataset()
355+
dataloader = DataLoader(
356+
self.dataset,
357+
places=place,
358+
num_workers=num_workers,
359+
batch_size=2,
360+
drop_last=True)
361+
362+
for i, data in enumerate(dataloader()):
363+
assert isinstance(data, paddle.Tensor)
364+
assert data.shape == [2, 2, 3]
365+
366+
def test_main(self):
367+
for num_workers in [0, 2]:
368+
self.run_main(num_workers)
369+
370+
371+
class SingleFieldIterableDataset(IterableDataset):
372+
def __init__(self, sample_num):
373+
self.sample_num = sample_num
374+
375+
def __iter__(self):
376+
for _ in range(self.sample_num):
377+
yield np.random.random((2, 3)).astype('float32')
378+
379+
380+
class TestSingleFieldIterableDataset(TestSingleFieldDataset):
381+
def init_dataset(self):
382+
self.sample_num = 16
383+
self.dataset = SingleFieldIterableDataset(self.sample_num)
384+
385+
333386
class TestDataLoaderGenerateStates(unittest.TestCase):
334387
def setUp(self):
335388
self.inputs = [(0, 1), (0, 2), (1, 3)]

0 commit comments

Comments
 (0)