Skip to content

Commit 2fdbc1c

Browse files
author
Yancey
authored
hidden bcast_params call in dist train (#11575)
1 parent b94f784 commit 2fdbc1c

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

benchmark/fluid/fluid_benchmark.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
264264
break
265265
else:
266266
loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
267-
if args.update_method == "pserver":
268-
exe.bcast_params()
269267
if args.use_reader_op:
270268
num_samples += args.batch_size * args.gpus
271269
else:

python/paddle/fluid/parallel_executor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def __init__(self,
7171
num_trainers=1,
7272
trainer_id=0,
7373
**kwargs):
74-
7574
if len(kwargs) != 0:
7675
err_msg = ""
7776
for key in kwargs:
@@ -130,6 +129,11 @@ def __init__(self,
130129
main = main_program
131130
main = main if main else framework.default_main_program()
132131
scope = executor.global_scope()
132+
# FIXME(Yancey1989): it's a temporary approach to determine the distributed
133+
# train program, call self.bcast_params() at the end of each mini-batch.
134+
self.is_dist = True if "recv" in [
135+
op.type for op in main.global_block().ops
136+
] else False
133137

134138
if share_vars_from and not isinstance(share_vars_from,
135139
ParallelExecutor):
@@ -262,6 +266,10 @@ def run(self, fetch_list, feed=None, feed_dict=None):
262266
fetch_var_name = '@FETCHED_VAR_NAME@'
263267
self.executor.run(fetch_list, fetch_var_name)
264268
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
269+
270+
if self.is_dist:
271+
self.bcast_params()
272+
265273
return [arr[i] for i in range(len(arr))]
266274

267275
def bcast_params(self):

0 commit comments

Comments
 (0)