@@ -401,11 +401,8 @@ def get_pserver_program(self, endpoint):
401
401
# HACK: optimization global ops only used to scale beta1 and beta2
402
402
# replace it with dependency engine.
403
403
for op in self .optimize_ops :
404
- if op .type == "scale" :
405
- for in_name in op .input_arg_names :
406
- if in_name .startswith ("beta1_pow_acc" ) or \
407
- in_name .startswith ("beta2_pow_acc" ):
408
- global_ops .append (op )
404
+ if self ._is_adam_connected_op (op ):
405
+ global_ops .append (op )
409
406
410
407
def __append_optimize_op__ (op , block , grad_to_block_id ):
411
408
if self ._is_opt_op (op ):
@@ -1152,13 +1149,20 @@ def _get_optimize_pass(self):
1152
1149
op .input ("Param" )[0 ]),
1153
1150
self .origin_program .global_block ().var (
1154
1151
op .input ("Grad" )[0 ])))
1155
- elif op .type == "scale" :
1156
- # for adam optimize op
1157
- for in_name in op .input_arg_names :
1158
- if in_name .startswith ("beta1_pow_acc" ) or \
1159
- in_name .startswith ("beta2_pow_acc" ):
1160
- opt_ops .append (op )
1161
- break
1152
+ elif self ._is_adam_connected_op (op ):
1153
+ opt_ops .append (op )
1162
1154
else :
1163
1155
pass
1164
1156
return opt_ops , params_grads
1157
+
1158
+ def _is_adam_connected_op (self , op ):
1159
+ """
1160
+ A hack function to determinate whether the input operator
1161
+ is connected to optimize operator.
1162
+ """
1163
+ if op .type == "scale" :
1164
+ for in_name in op .input_arg_names :
1165
+ if in_name .startswith ("beta1_pow_acc" ) or \
1166
+ in_name .startswith ("beta2_pow_acc" ):
1167
+ return True
1168
+ return False
0 commit comments