@@ -1182,18 +1182,39 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = dict()
+
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
+        def _get_param_block(opt_op):
+            # param is already created on global program
+            param_block = None
+            for p in self.param_grad_ep_mapping[endpoint]["params"]:
+                if same_or_split_var(p.name, opt_op.input("Param")[0]):
+                    param_block = p
+                    break
+            return param_block
+
         for key in opt_op.input_names:
             if key == "Grad":
                 new_inputs[key] = merged_var
+            # For RMSProp optimizer
+            elif key == "Moment" or key == "MeanSquare":
+                param_block = _get_param_block(opt_op)
+                if not param_block:
+                    return
+                moment_var = origin_program.global_block().vars[opt_op.input(
+                    key)[0]]
+                tmpvar = pserver_block.create_var(
+                    name=moment_var.name,
+                    persistable=moment_var.persistable,
+                    dtype=moment_var.dtype,
+                    # change to use same shape as param
+                    # TODO(typhoonzero): didn't append .block in the var name,
+                    # may affect checkpoint saving? Need to verify.
+                    shape=param_block.shape)
+                new_inputs[key] = tmpvar
             elif key == "Param":
-                # param is already created on global program
-                param_block = None
-                for p in self.param_grad_ep_mapping[endpoint]["params"]:
-                    if same_or_split_var(p.name, opt_op.input(key)[0]):
-                        param_block = p
-                        break
+                param_block = _get_param_block(opt_op)
                 if not param_block:
                     return
                 tmpvar = pserver_block.create_var(
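For reference, the `same_or_split_var` helper that `_get_param_block` relies on matches either the original variable name or a split slice of it. A minimal sketch, assuming the `.blockN` suffix convention the transpiler uses for split variables:

    def same_or_split_var(p_name, var_name):
        # A pserver-side variable either keeps the original name or is a
        # split slice named "<original>.block<N>"; both count as the same
        # parameter.
        return p_name == var_name or p_name.startswith(var_name + ".block")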
@@ -1219,7 +1240,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
 
         for key in opt_op.input_names:
             new_shape = None
-            if key in ["Param", "Grad", "LearningRate"]:
+            if key in ["Param", "Grad", "LearningRate", "Moment", "MeanSquare"]:
                 continue
             var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
             # update accumulator variable shape
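Taken together, the two hunks add RMSProp support to pserver transpilation: the `Moment` and `MeanSquare` accumulators are recreated on each parameter server with the shape of the parameter slice assigned to that endpoint, and the generic accumulator-reshape loop now skips them because they have already been handled. A minimal standalone sketch of that shape rule, using a hypothetical `Var` stand-in for Fluid's variable objects rather than the real API:

    class Var(object):
        # Hypothetical stand-in for a Fluid variable; illustrative only.
        def __init__(self, name, shape, dtype="float32", persistable=True):
            self.name = name
            self.shape = shape
            self.dtype = dtype
            self.persistable = persistable

    def recreate_accumulator(moment_var, param_block):
        # Mirror the new "Moment"/"MeanSquare" branch: keep the
        # accumulator's name, dtype and persistable flag, but adopt the
        # shape of the parameter slice on this pserver so the element-wise
        # RMSProp update lines up with the split parameter.
        return Var(
            name=moment_var.name,
            shape=param_block.shape,
            dtype=moment_var.dtype,
            persistable=moment_var.persistable)

    # Example: a 10x4 parameter split across two pservers as two 5x4 slices.
    full_moment = Var("w_moment", shape=[10, 4])
    param_slice = Var("w.block0", shape=[5, 4])
    assert recreate_accumulator(full_moment, param_slice).shape == [5, 4]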