Skip to content

Commit a742bb0

Browse files
guru4elephant and xjqbest
authored and committed
Merge pull request #16746 from xjqbest/dataset_merge_develop
move split filelist from trainer.py to fleet & fix error
1 parent 1237dfa commit a742bb0

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

python/paddle/fluid/dataset.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def local_shuffle(self):
213213
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
214214
>>> filelist = ["a.txt", "b.txt"]
215215
>>> dataset.set_filelist(filelist)
216+
>>> dataset.load_into_memory()
216217
>>> dataset.local_shuffle()
217218
"""
218219
self.dataset.local_shuffle()
@@ -230,6 +231,7 @@ def global_shuffle(self, fleet=None):
230231
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
231232
>>> filelist = ["a.txt", "b.txt"]
232233
>>> dataset.set_filelist(filelist)
234+
>>> dataset.load_into_memory()
233235
>>> dataset.global_shuffle(fleet)
234236
235237
Args:
@@ -247,6 +249,25 @@ def global_shuffle(self, fleet=None):
247249
if fleet is not None:
248250
fleet.fleet_instance.role_maker_._barrier_worker()
249251

252+
def release_memory(self):
    """Free the data this InMemoryDataset is currently holding in memory.

    Call this after training is finished and the loaded data will not be
    used again; it delegates to the underlying C++ dataset object.

    Example:
        >>> import paddle.fluid as fluid
        >>> import paddle.fluid.incubate.fleet.parameter_server as fleet
        >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
        >>> filelist = ["a.txt", "b.txt"]
        >>> dataset.set_filelist(filelist)
        >>> dataset.load_into_memory()
        >>> dataset.global_shuffle(fleet)
        >>> exe = fluid.Executor(fluid.CPUPlace())
        >>> exe.run(fluid.default_startup_program())
        >>> exe.train_from_dataset(fluid.default_main_program(), dataset)
        >>> dataset.release_memory()
    """
    # The heavy lifting lives in the wrapped core dataset object.
    self.dataset.release_memory()
270+
250271

251272
class QueueDataset(DatasetBase):
252273
"""

python/paddle/fluid/incubate/fleet/base/role_maker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _finalize(self):
128128
"""
129129
finalize the current MPI instance.
130130
"""
131-
self.comm_.finalize()
131+
pass
132132

133133

134134
class MPISymetricRoleMaker(MPIRoleMaker):

python/paddle/fluid/incubate/fleet/parameter_server/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,40 @@ def save_pserver_model(self, save_path):
228228
"""
229229
self._fleet_ptr.save_model(save_path)
230230

231+
def split_filelist(self, filelist):
    """
    Split filelist before distributed training, so each trainer reads a
    disjoint slice of the global file list.

    For example, if filelist is [a, b, c, d, e] and trainer_num = 2,
    then trainer 0 gets [a, b, c] and trainer 1 gets [d, e].

    Example:
        >>> all_filelist = ["a.txt", "b.txt", "c.txt"]
        >>> my_filelist = fleet.split_filelist(all_filelist)
        >>> dataset = fluid.DatasetFactory().create_dataset()
        >>> dataset.set_filelist(my_filelist)

    Args:
        filelist(list): list of filename, can be local or hdfs/afs.

    Returns:
        list of filename which belongs to this trainer.

    Raises:
        ValueError: if there are more trainers than files (some trainer
            would otherwise receive an empty slice).
    """
    file_num = len(filelist)
    trainer_id = self.get_worker_index()
    trainer_num = self.get_worker_num()
    if trainer_num > file_num:
        raise ValueError("trainer_num should be <= file_num : "
                         "%s > %s" % (trainer_num, file_num))
    # Compute this trainer's half-open interval [start, end).
    # The first (file_num % trainer_num) trainers each get one extra file.
    # NOTE: use floor division '//' so the slice bounds stay ints under
    # Python 3 — plain '/' yields floats and filelist[start:end] raises
    # TypeError. Under Python 2 int '//' behaves identically to '/'.
    start = 0
    end = 0
    for i in range(0, trainer_id + 1):
        length = file_num // trainer_num + (i < (file_num % trainer_num))
        start = end
        end += length
    my_filelist = filelist[start:end]
    return my_filelist
264+
231265
def _set_opt_info(self, opt_info):
232266
"""
233267
this function saves the result from DistributedOptimizer.minimize()
@@ -324,3 +358,4 @@ def minimize(self,
324358
worker_num = fleet_instance.get_worker_num
325359
server_num = fleet_instance.get_server_num
326360
worker_index = fleet_instance.get_worker_index
361+
split_filelist = fleet_instance.split_filelist

0 commit comments

Comments (0)