Commit 5ce2df9

Merge pull request #10566 from reyoung/feature/train_by_pe
Parallel Executor revised feeder
2 parents (177324b + 9c8383c) · commit 5ce2df9

2 files changed: 95 additions & 0 deletions

python/paddle/fluid/data_feeder.py

Lines changed: 58 additions & 0 deletions
@@ -16,6 +16,7 @@
 import core
 import numpy
 import six.moves as six
+import multiprocessing

 from framework import Variable, default_main_program

@@ -116,3 +117,60 @@ def feed(self, iterable):
         for each_name, each_converter in six.zip(self.feed_names, converter):
             ret_dict[each_name] = each_converter.done()
         return ret_dict
+
+    def feed_parallel(self, iterable, num_places=None):
+        if isinstance(self.place, core.CUDAPlace):
+            places = [
+                core.CUDAPlace(i)
+                for i in six.xrange(self._get_number_of_places_(num_places))
+            ]
+        else:
+            places = [
+                core.CPUPlace()
+                for _ in six.xrange(self._get_number_of_places_(num_places))
+            ]
+
+        if len(iterable) != len(places):
+            raise ValueError("feed_parallel takes multiple mini-batches, one "
+                             "to be fed on each device. The number of "
+                             "devices and the number of mini-batches must be "
+                             "the same.")
+
+        place = self.place
+        for p, batch in six.zip(places, iterable):
+            self.place = p
+            yield self.feed(batch)
+        self.place = place
+
+    def _get_number_of_places_(self, num_places):
+        if num_places is not None:
+            return int(num_places)
+        elif isinstance(self.place, core.CUDAPlace):
+            return core.get_cuda_device_count()
+        else:
+            return multiprocessing.cpu_count()
+
+    def decorate_reader(self,
+                        reader,
+                        multi_devices,
+                        num_places=None,
+                        drop_last=True):
+        def __reader_creator__():
+            if not multi_devices:
+                for item in reader():
+                    yield self.feed(item)
+            else:
+                num = self._get_number_of_places_(num_places)
+                item = []
+                for batch in reader():
+                    item.append(batch)
+                    if len(item) == num:
+                        yield list(self.feed_parallel(item, num))
+                        item = []
+                if not drop_last and len(item) != 0:
+                    raise ValueError(
+                        "Mini-batches that cannot fill all devices are "
+                        "dropped; keeping them (drop_last=False) is not "
+                        "implemented.")
+
+        return __reader_creator__
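
For readers unfamiliar with the new API, a minimal usage sketch (not part of this commit) of feed_parallel on two GPUs; the image/label variables and the random mini-batches below are hypothetical placeholders:

import numpy
import paddle.fluid as fluid

# Hypothetical feed list: one image variable and one label variable.
image = fluid.layers.data(name='image', shape=[3, 224, 224], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

feeder = fluid.DataFeeder(place=fluid.CUDAPlace(0), feed_list=[image, label])

# feed_parallel is a generator: the i-th mini-batch is converted on the
# i-th device and yielded as one feed dict per place.
batches = [
    [(numpy.random.random([3, 224, 224]).astype('float32'), 1)]
    for _ in range(2)
]
feed_dicts = list(feeder.feed_parallel(batches, num_places=2))
assert len(feed_dicts) == 2  # one feed dict per CUDAPlace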

python/paddle/fluid/tests/unittests/test_parallel_executor.py

Lines changed: 37 additions & 0 deletions
@@ -796,5 +796,42 @@ def test_update_sparse_parameter(self):
         self.parallel_exe(train_inputs, seed=1)


+class TestFeedParallel(unittest.TestCase):
+    def test_main(self):
+        main = fluid.Program()
+        startup = fluid.Program()
+        startup.random_seed = 1
+        with fluid.scope_guard(fluid.core.Scope()):
+            with fluid.program_guard(main, startup):
+                data = fluid.layers.data(
+                    name='image', shape=[3, 224, 224], dtype='float32')
+                label = fluid.layers.data(
+                    name='label', shape=[1], dtype='int64')
+                out = Lenet(data, class_dim=102)
+                loss = fluid.layers.cross_entropy(input=out, label=label)
+                loss = fluid.layers.mean(loss)
+                opt = fluid.optimizer.Momentum(
+                    learning_rate=0.1,
+                    momentum=0.9,
+                    regularization=fluid.regularizer.L2Decay(1e-4))
+
+                opt.minimize(loss)
+        place = fluid.CUDAPlace(0)
+        feeder = fluid.DataFeeder(place=place, feed_list=[data, label])
+        reader = feeder.decorate_reader(
+            paddle.batch(
+                flowers.train(), batch_size=16), multi_devices=True)
+        exe = fluid.Executor(place)
+        exe.run(startup)
+        pe = fluid.ParallelExecutor(
+            use_cuda=True, loss_name=loss.name, main_program=main)
+
+        for batch_id, data in enumerate(reader()):
+            loss_np = np.array(pe.run(feed=data, fetch_list=[loss.name])[0])
+            print batch_id, loss_np
+            if batch_id == 2:
+                break
+
+
 if __name__ == '__main__':
     unittest.main()
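
As a complement to the GPU test above, a minimal sketch (not part of this commit) of decorate_reader on CPU; the variables and the toy reader are hypothetical, and num_places is passed explicitly so the grouping does not depend on the multiprocessing.cpu_count() default:

import numpy
import paddle.fluid as fluid

image = fluid.layers.data(name='image', shape=[3, 224, 224], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

feeder = fluid.DataFeeder(place=fluid.CPUPlace(), feed_list=[image, label])

# A toy in-memory reader standing in for paddle.batch(flowers.train(), ...).
def fake_batch_reader():
    for _ in range(4):
        yield [(numpy.random.random([3, 224, 224]).astype('float32'), 1)]

# With multi_devices=True, decorate_reader groups num_places mini-batches per
# iteration and yields a list of per-device feed dicts.
reader = feeder.decorate_reader(
    fake_batch_reader, multi_devices=True, num_places=2)

for feed_dicts in reader():
    assert len(feed_dicts) == 2  # one feed dict per (CPU) place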
