
Commit 5c84eac

[cherry-pick] make default_collate_fn visible (#25324)
* make default_collate_fn visible. test=develop. test=release/1.8
1 parent 914fd81 commit 5c84eac

4 files changed, +228 -7 lines changed


python/paddle/fluid/dataloader/dataloader_iter.py

Lines changed: 22 additions & 2 deletions

@@ -38,7 +38,27 @@
 MP_INDICES_CHECK_INTERVAL = 5


-def _default_collate_fn(batch):
+def default_collate_fn(batch):
+    """
+    Default batch collating function for :code:`fluid.io.DataLoader`,
+    batch should be a list of samples, and each sample should be a list
+    of fields as follows:
+
+    [[field1, field2, ...], [field1, field2, ...], ...]
+
+    This default collate function zips each field together and stacks
+    each field as a batch field as follows:
+
+    [batch_field1, batch_field2, ...]
+
+    Args:
+        batch(list of list of numpy array): the batch data, each field
+            should be a numpy array, each sample should be a list of
+            fields, and batch should be a list of samples.
+
+    Returns:
+        a list of numpy array: collated batch
+    """
     sample = batch[0]
     # dataset has only 1 field
     if isinstance(sample, np.ndarray):

@@ -82,7 +102,7 @@ def __init__(self, loader):
         self._return_list = loader.return_list
         self._batch_sampler = loader.batch_sampler
         self._sampler_iter = iter(loader.batch_sampler)
-        self._collate_fn = loader.collate_fn or _default_collate_fn
+        self._collate_fn = loader.collate_fn or default_collate_fn
         self._num_workers = loader.num_workers
         self._use_buffer_reader = loader.use_buffer_reader
         self._use_shared_memory = loader.use_shared_memory
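
For reference, the zip-and-stack behavior documented in the docstring above can be reproduced with plain numpy. The following is a minimal stand-alone sketch of that behavior, not the Paddle implementation itself; the 784-element image and single-element label shapes mirror the test added in this commit and are illustrative only:

import numpy as np

def collate_sketch(batch):
    # batch: [[field1, field2, ...], [field1, field2, ...], ...]
    # zip the samples field by field, then stack each field along axis 0
    return [np.stack(field, axis=0) for field in zip(*batch)]

# four samples, each an (image, label) pair
batch = [[np.random.random([784]).astype('float32'),
          np.random.randint(0, 9, (1, )).astype('int64')]
         for _ in range(4)]
batch_image, batch_label = collate_sketch(batch)
print(batch_image.shape, batch_label.shape)  # (4, 784) (4, 1)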

python/paddle/fluid/reader.py

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@
 from .data_feeder import DataFeeder, BatchedTensorProvider
 from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler
 from .dataloader import BatchSampler, Dataset
-from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess
+from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess, default_collate_fn
 from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
 from .unique_name import UniqueNameGenerator
 import logging

@@ -43,7 +43,7 @@
 # NOTE: [ avoid hanging & failed quickly ] These value is used in getting data from another process
 QUEUE_GET_TIMEOUT = 60

-__all__ = ['PyReader', 'DataLoader']
+__all__ = ['PyReader', 'DataLoader', 'default_collate_fn']

 data_loader_unique_name_generator = UniqueNameGenerator()
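
With default_collate_fn added to reader.py's __all__, user code can import it and compose it into a custom collate function. A minimal usage sketch, assuming paddle.fluid.io re-exports reader.py's public names (it already exposes DataLoader this way in the test added below); my_collate_fn is a hypothetical wrapper, not part of this commit:

from paddle.fluid.io import default_collate_fn

def my_collate_fn(sample_list):
    # hypothetical wrapper: preprocess each sample, then fall back to the
    # default zip-and-stack batching made visible by this commit
    sample_list = [[image / 255.0, label] for image, label in sample_list]
    return default_collate_fn(sample_list)

If the DataLoader constructor exposes a collate_fn argument (the loader.collate_fn read in dataloader_iter.py above suggests it does), passing my_collate_fn there replaces the default fallback.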

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 5 additions & 3 deletions

@@ -210,6 +210,7 @@ if (APPLE OR WIN32)
     list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_fds_clear)
     list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_exit_func)
     list(REMOVE_ITEM TEST_OPS test_imperative_signal_handler)
+    list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_exception)
 endif()

 if(NOT WITH_GPU OR WIN32 OR APPLE)

@@ -378,7 +379,8 @@ set_tests_properties(test_parallel_executor_crf test_sync_batch_norm_op test_inp
         PROPERTIES LABELS "RUN_TYPE=DIST" RUN_SERIAL TRUE)

 if(NOT WIN32 AND NOT APPLE)
-    set_tests_properties(test_imperative_data_loader_base PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
-    set_tests_properties(test_imperative_data_loader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
-    set_tests_properties(test_imperative_data_loader_fds_clear PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
+    set_tests_properties(test_imperative_data_loader_base PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
+    set_tests_properties(test_imperative_data_loader_fds_clear PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
+    set_tests_properties(test_imperative_data_loader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
+    set_tests_properties(test_multiprocess_dataloader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
 endif()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py

Lines changed: 199 additions & 0 deletions

@@ -0,0 +1,199 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import os
+import sys
+import six
+import time
+import unittest
+import multiprocessing
+import numpy as np
+
+import paddle.fluid as fluid
+from paddle.fluid.io import Dataset, BatchSampler, DataLoader
+from paddle.fluid.dygraph.nn import Linear
+from paddle.fluid.dygraph.base import to_variable
+
+
+class RandomDataset(Dataset):
+    def __init__(self, sample_num):
+        self.sample_num = sample_num
+
+    def __getitem__(self, idx):
+        np.random.seed(idx)
+        image = np.random.random([784]).astype('float32')
+        label = np.random.randint(0, 9, (1, )).astype('int64')
+        return image, label
+
+    def __len__(self):
+        return self.sample_num
+
+
+class TestDataLoaderAssert(unittest.TestCase):
+    def test_main(self):
+        place = fluid.cpu_places()[0]
+        with fluid.dygraph.guard(place):
+            dataset = RandomDataset(100)
+            batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
+
+            # dataset is not instance of Dataset
+            try:
+                loader = DataLoader(dataset=batch_sampler, places=place)
+                self.assertTrue(False)
+            except AssertionError:
+                pass
+
+            # places is None
+            try:
+                loader = DataLoader(dataset=dataset, places=None)
+                self.assertTrue(False)
+            except AssertionError:
+                pass
+
+            # num_workers < 0
+            try:
+                loader = DataLoader(
+                    dataset=dataset, places=place, num_workers=-1)
+                self.assertTrue(False)
+            except AssertionError:
+                pass
+
+            # timeout < 0
+            try:
+                loader = DataLoader(dataset=dataset, places=place, timeout=-1)
+                self.assertTrue(False)
+            except AssertionError:
+                pass
+
+            # batch_sampler is not instance of BatchSampler
+            try:
+                loader = DataLoader(
+                    dataset=dataset, places=place, batch_sampler=dataset)
+                self.assertTrue(False)
+            except AssertionError:
+                pass
+
+            # set batch_sampler and shuffle/batch_size/drop_last
+            try:
+                loader = DataLoader(
+                    dataset=dataset,
+                    places=place,
+                    batch_sampler=batch_sampler,
+                    shuffle=True,
+                    drop_last=True)
+                self.assertTrue(False)
+            except AssertionError:
+                pass
+
+            # set batch_sampler correctly
+            try:
+                loader = DataLoader(
+                    dataset=dataset, places=place, batch_sampler=batch_sampler)
+                self.assertTrue(True)
+            except AssertionError:
+                self.assertTrue(False)
+
+
+# CI Coverage cannot record stub in subprocess,
+# HACK a _worker_loop in main process call here
+class TestDataLoaderWorkerLoop(unittest.TestCase):
+    def run_without_worker_done(self, use_shared_memory=True):
+        try:
+            place = fluid.cpu_places()[0]
+            with fluid.dygraph.guard(place):
+                dataset = RandomDataset(800)
+
+                # test init_fn
+                def _init_fn(worker_id):
+                    pass
+
+                # test collate_fn
+                def _collate_fn(sample_list):
+                    return [
+                        np.stack(
+                            s, axis=0) for s in list(zip(*sample_list))
+                    ]
+
+                loader = DataLoader(
+                    dataset,
+                    num_workers=1,
+                    places=place,
+                    use_shared_memory=use_shared_memory)
+                assert loader.num_workers > 0, \
+                    "go to AssertionError and pass in Mac and Windows"
+                loader = iter(loader)
+                print("loader length", len(loader))
+                indices_queue = multiprocessing.Queue()
+                for i in range(10):
+                    indices_queue.put([i, i + 10])
+                indices_queue.put(None)
+                loader._worker_loop(
+                    loader._dataset, indices_queue, loader._data_queue,
+                    loader._workers_done_event, _collate_fn, _init_fn, 0)
+                self.assertTrue(False)
+        except AssertionError:
+            pass
+        except Exception:
+            self.assertTrue(False)
+
+    def run_with_worker_done(self, use_shared_memory=True):
+        try:
+            place = fluid.cpu_places()[0]
+            with fluid.dygraph.guard(place):
+                dataset = RandomDataset(800)
+
+                # test init_fn
+                def _init_fn(worker_id):
+                    pass
+
+                # test collate_fn
+                def _collate_fn(sample_list):
+                    return [
+                        np.stack(
+                            s, axis=0) for s in list(zip(*sample_list))
+                    ]
+
+                loader = DataLoader(
+                    dataset,
+                    num_workers=1,
+                    places=place,
+                    use_shared_memory=use_shared_memory)
+                assert loader.num_workers > 0, \
+                    "go to AssertionError and pass in Mac and Windows"
+                loader = iter(loader)
+                print("loader length", len(loader))
+                indices_queue = multiprocessing.Queue()
+                for i in range(10):
+                    indices_queue.put([i, i + 10])
+                indices_queue.put(None)
+                loader._workers_done_event.set()
+                loader._worker_loop(
+                    loader._dataset, indices_queue, loader._data_queue,
+                    loader._workers_done_event, _collate_fn, _init_fn, 0)
+                self.assertTrue(True)
+        except AssertionError:
+            pass
+        except Exception:
+            self.assertTrue(False)
+
+    def test_main(self):
+        for use_shared_memory in [True, False]:
+            self.run_without_worker_done(use_shared_memory)
+            self.run_with_worker_done(use_shared_memory)
+
+
+if __name__ == '__main__':
+    unittest.main()
