Skip to content

Commit 2644ef3

Browse files
authored
Merge pull request #16923 from xjqbest/my_cherry_pick_16746
Merge pull request #16746 from xjqbest/dataset_merge_develop
2 parents 1ffbfc4 + a742bb0 commit 2644ef3

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

python/paddle/fluid/dataset.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def local_shuffle(self):
213213
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
214214
>>> filelist = ["a.txt", "b.txt"]
215215
>>> dataset.set_filelist(filelist)
216+
>>> dataset.load_into_memory()
216217
>>> dataset.local_shuffle()
217218
"""
218219
self.dataset.local_shuffle()
@@ -230,6 +231,7 @@ def global_shuffle(self, fleet=None):
230231
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
231232
>>> filelist = ["a.txt", "b.txt"]
232233
>>> dataset.set_filelist(filelist)
234+
>>> dataset.load_into_memory()
233235
>>> dataset.global_shuffle(fleet)
234236
235237
Args:
@@ -249,6 +251,25 @@ def global_shuffle(self, fleet=None):
249251
if fleet is not None:
250252
fleet.fleet_instance.role_maker_._barrier_worker()
251253

254+
def release_memory(self):
    """
    Release the data this InMemoryDataset currently holds in memory.

    Call this once the loaded data will not be used again (for
    example, after training has finished) to free the memory held
    by the underlying dataset object.

    Example:
        >>> import paddle.fluid as fluid
        >>> import paddle.fluid.incubate.fleet.parameter_server as fleet
        >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
        >>> filelist = ["a.txt", "b.txt"]
        >>> dataset.set_filelist(filelist)
        >>> dataset.load_into_memory()
        >>> dataset.global_shuffle(fleet)
        >>> exe = fluid.Executor(fluid.CPUPlace())
        >>> exe.run(fluid.default_startup_program())
        >>> exe.train_from_dataset(fluid.default_main_program(), dataset)
        >>> dataset.release_memory()
    """
    # Delegate straight to the wrapped core dataset object.
    self.dataset.release_memory()
272+
252273

253274
class QueueDataset(DatasetBase):
254275
"""

python/paddle/fluid/incubate/fleet/base/role_maker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _finalize(self):
128128
"""
129129
finalize the current MPI instance.
130130
"""
131-
self.comm_.finalize()
131+
pass
132132

133133

134134
class MPISymetricRoleMaker(MPIRoleMaker):

python/paddle/fluid/incubate/fleet/parameter_server/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,40 @@ def save_pserver_model(self, save_path):
241241
"""
242242
self._fleet_ptr.save_model(save_path)
243243

244+
def split_filelist(self, filelist):
    """
    Split filelist before distributed training, so each trainer
    processes a disjoint, contiguous slice of the files.

    For example, if filelist is [a, b, c, d, e] and trainer_num = 2,
    then trainer 0 gets [a, b, c] and trainer 1 gets [d, e].

    Example:
        >>> all_filelist = ["a.txt", "b.txt", "c.txt"]
        >>> my_filelist = fleet.split_filelist(all_filelist)
        >>> dataset = fluid.DatasetFactory().create_dataset()
        >>> dataset.set_filelist(my_filelist)

    Args:
        filelist(list): list of filenames, can be local or hdfs/afs.

    Returns:
        list of filenames which belong to this trainer.

    Raises:
        ValueError: if there are more trainers than files.
    """
    file_num = len(filelist)
    trainer_id = self.get_worker_index()
    trainer_num = self.get_worker_num()
    if trainer_num > file_num:
        raise ValueError("trainer_num should be <= file_num : "
                         "%s > %s" % (trainer_num, file_num))
    # Use floor division: under Python 3, "/" is true division and
    # would make the slice bounds floats, raising TypeError on the
    # list slice below.
    blocks = file_num // trainer_num
    remainder = file_num % trainer_num
    # Each of the first `remainder` trainers takes one extra file.
    # Closed form of the half-open interval [start, end) instead of
    # accumulating lengths in a loop over all preceding trainers.
    start = blocks * trainer_id + min(trainer_id, remainder)
    end = start + blocks + (1 if trainer_id < remainder else 0)
    return filelist[start:end]
277+
244278
def _set_opt_info(self, opt_info):
245279
"""
246280
this function saves the result from DistributedOptimizer.minimize()
@@ -337,3 +371,4 @@ def minimize(self,
337371
worker_num = fleet_instance.get_worker_num
338372
server_num = fleet_instance.get_server_num
339373
worker_index = fleet_instance.get_worker_index
374+
split_filelist = fleet_instance.split_filelist

0 commit comments

Comments
 (0)