Commit 5d4a110

Fix model average on multi-GPUs. (#11814) (#11826)
* Fix average_accumulate_op for parallel executor.
* Fix model average on multi-GPUs.
1 parent 494cecd commit 5d4a110
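
The caller-facing change: `ModelAverage` no longer takes a `params_grads` list built from `optimizer.minimize()`; its constructor now collects every parameter of the default main program and creates a temporary gradient buffer for each one, which, per the commit message, is what fixes averaging under the parallel executor. A minimal before/after sketch of the call site, assuming a `cost` loss variable from the surrounding network definition (the Momentum hyperparameters are illustrative):

```python
import paddle.fluid as fluid

# `cost` is assumed to be the loss variable of the surrounding program;
# the Momentum hyperparameters below are illustrative only.
optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)

# Before this commit, the gradients returned by minimize() were passed in:
#   _, params_grads = optimizer.minimize(cost)
#   model_average = fluid.optimizer.ModelAverage(params_grads, 0.15, ...)

# After this commit, minimize() is called for its side effects only, and
# ModelAverage discovers the parameters of the default main program itself.
optimizer.minimize(cost)
model_average = fluid.optimizer.ModelAverage(
    0.15,  # average_window_rate
    min_average_window=10000,
    max_average_window=20000)
```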

1 file changed: +5 −12 lines changed

python/paddle/fluid/optimizer.py

Lines changed: 5 additions & 12 deletions
```diff
@@ -1113,7 +1113,6 @@ class ModelAverage(Optimizer):
 
     Args:
         average_window_rate: The rate of average window.
-        params_grads: A list of parameter-grad variable pairs.
         min_average_window: The minimum size of average window.
         max_average_window: The maximum size of average window.
 
@@ -1122,8 +1121,8 @@ class ModelAverage(Optimizer):
       .. code-block:: python
 
         optimizer = fluid.optimizer.Momentum()
-        _, params_grads = optimizer.minimize(cost)
-        model_average = fluid.optimizer.ModelAverage(params_grads, 0.15,
+        optimizer.minimize(cost)
+        model_average = fluid.optimizer.ModelAverage(0.15,
                                         min_average_window=10000,
                                         max_average_window=20000)
         for pass_id in range(args.pass_num):
@@ -1137,7 +1136,6 @@ class ModelAverage(Optimizer):
 
     def __init__(self,
                  average_window_rate,
-                 params_grads=None,
                  min_average_window=10000,
                  max_average_window=10000,
                  **kwargs):
@@ -1146,21 +1144,16 @@ def __init__(self,
         self.min_average_window = min_average_window
         self.max_average_window = max_average_window
 
-        self.params_grads = [] if params_grads is None else params_grads
-        params = {}
-        for param, grad in self.params_grads:
-            if param.do_model_average != False:
-                params[param.name] = (param, grad)
+        self.params_grads = []
         for param in framework.default_main_program().global_block(
         ).all_parameters():
-            if param.name not in params and param.do_model_average != False:
+            if param.do_model_average != False:
                 grad = param.block.create_var(
                     name=unique_name.generate(".".join([param.name, 'tmp'])),
                     dtype=param.dtype,
                     persistable=False,
                     stop_gradient=True)
-                params[param.name] = (param, grad)
-        self.params_grads = params.values()
+                self.params_grads.append((param, grad))
 
         for param, grad in self.params_grads:
             self._append_average_accumulate_op(param)
```
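
The docstring example above is cut off by the hunk boundary at the training loop. A sketch of how that loop typically continues, under assumptions: `exe`, `feeder`, `train_reader`, `test_reader`, `inference_program`, and `args` are hypothetical placeholders for the usual fluid training setup, and `model_average.apply(exe)` is the context manager defined elsewhere in this class that temporarily applies the averaged values.

```python
# Hypothetical placeholders: exe is a fluid.Executor, feeder a
# fluid.DataFeeder, and the readers/programs come from the training script.
for pass_id in range(args.pass_num):
    for data in train_reader():
        exe.run(fluid.default_main_program(), feed=feeder.feed(data))

    # apply() temporarily swaps the sliding-window averages into the
    # parameter variables for evaluation, restoring the originals on exit.
    with model_average.apply(exe):
        for data in test_reader():
            exe.run(inference_program, feed=feeder.feed(data))
```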
