1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import logging
16+ from ..log_helper import get_logger
17+
18+ from collections .abc import Sequence
19+
1520
class _DatasetFetcher(object):
    """Base class for objects that fetch one batch of samples from a dataset.

    The constructor only stores its arguments; ``fetch`` must be overridden
    by subclasses. The class also provides ``_log_warning``, which reports
    the Paddle 2.1 change in the return format of single-field datasets.

    Args:
        dataset: the dataset samples are read from.
        auto_collate_batch: stored as-is for use by subclasses.
        collate_fn: stored as-is for use by subclasses.
        drop_last: stored as-is for use by subclasses.
    """

    # Text emitted by ``_log_warning``. Built once at class-definition time
    # instead of on every call. The embedded example keeps real Python
    # indentation so that it can be copied from the log output and run.
    _SINGLE_FIELD_WARNING = (
        "Detect dataset only contains single fields, return format "
        "changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add "
        "a list surround output data(e.g. return [data]), and in "
        "Paddle >= 2.1, DataLoader return the single field directly "
        "(e.g. return data). For example, in following code:\n\n"
        "import numpy as np\n"
        "from paddle.io import DataLoader, Dataset\n\n"
        "class RandomDataset(Dataset):\n"
        "    def __getitem__(self, idx):\n"
        "        data = np.random.random((2, 3)).astype('float32')\n\n"
        "        return data\n\n"
        "    def __len__(self):\n"
        "        return 10\n\n"
        "dataset = RandomDataset()\n"
        "loader = DataLoader(dataset, batch_size=1)\n"
        "data = next(loader())\n\n"
        "In Paddle <= 2.0, data is in format '[Tensor(shape=(1, 2, 3), "
        "dtype=float32)]', and in Paddle >= 2.1, data is in format"
        " 'Tensor(shape=(1, 2, 3), dtype=float32)'\n"
    )

    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collate_batch = auto_collate_batch
        self.collate_fn = collate_fn
        self.drop_last = drop_last
        # One-shot guard for the format-change warning. This class never
        # flips it itself; code that calls ``_log_warning`` is expected to
        # set it to True afterwards so the message is logged only once.
        self._is_warning_logged = False

    def fetch(self, batch_indices):
        """Fetch the samples selected by ``batch_indices``.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("'fetch' not implement for class {}".format(
            self.__class__.__name__))

    def _log_warning(self):
        """Log the single-field return-format warning at WARNING level.

        The message is sent through the logger obtained from
        ``get_logger("DataLoader", logging.INFO, ...)``. This method does not
        consult or update ``_is_warning_logged``.
        """
        logger = get_logger(
            "DataLoader", logging.INFO, fmt='%(levelname)s: %(message)s')
        logger.warning(self._SINGLE_FIELD_WARNING)
59+
2760
2861class _IterableDatasetFetcher (_DatasetFetcher ):
2962 def __init__ (self , dataset , auto_collate_batch , collate_fn , drop_last ):
@@ -40,9 +73,14 @@ def fetch(self, batch_indices):
4073 data .append (next (self .dataset_iter ))
4174 except StopIteration :
4275 break
76+
4377 if len (data ) == 0 or (self .drop_last and
4478 len (data ) < len (batch_indices )):
4579 raise StopIteration
80+ if not isinstance (data [0 ],
81+ Sequence ) and not self ._is_warning_logged :
82+ self ._log_warning ()
83+ self ._is_warning_logged = True
4684 else :
4785 data = next (self .dataset_iter )
4886
@@ -59,6 +97,11 @@ def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
5997 def fetch (self , batch_indices ):
6098 if self .auto_collate_batch :
6199 data = [self .dataset [idx ] for idx in batch_indices ]
100+
101+ if not isinstance (data [0 ],
102+ Sequence ) and not self ._is_warning_logged :
103+ self ._log_warning ()
104+ self ._is_warning_logged = True
62105 else :
63106 data = self .dataset [batch_indices ]
64107
0 commit comments