Skip to content

Commit 98f38ae

Browse files
authored
Merge pull request #14243 from panyx0718/fix2
fix test
2 parents 34e9e59 + 9735e30 commit 98f38ae

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,18 @@ def run_trainer(self, args):
9898
strategy.allow_op_delay = False
9999

100100
build_stra = fluid.BuildStrategy()
101-
if args.batch_merge_repeat > 1:
102-
pass_builder = build_stra._create_passes_from_strategy()
103-
mypass = pass_builder.insert_pass(
104-
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
105-
mypass.set_int("num_repeats", args.batch_merge_repeat)
106101

107102
if args.use_reduce:
108103
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
109104
else:
110105
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
111106

107+
if args.batch_merge_repeat > 1:
108+
pass_builder = build_stra._create_passes_from_strategy()
109+
mypass = pass_builder.insert_pass(
110+
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
111+
mypass.set_int("num_repeats", args.batch_merge_repeat)
112+
112113
exe = fluid.ParallelExecutor(
113114
args.use_cuda,
114115
loss_name=avg_cost.name,

0 commit comments

Comments
 (0)