Commit 2de7a7f

[cherry pick] Modify adam to adamw in Optimizer AdamW (#36028) (#36103)
PR #35521 changed the op invoked by the AdamW optimizer from adamw to adam, which was an incorrect change. This commit switches it back from adam to adamw in AdamW.
1 parent c3a0eaa commit 2de7a7f
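
For context: AdamW differs from Adam by decoupling the weight-decay term from the adaptive gradient update (Loshchilov & Hutter), so dispatching the AdamW optimizer to the plain adam op silently changes the update rule. A minimal NumPy sketch of one decoupled step, for illustration only (the function and its defaults are hypothetical, not Paddle's implementation):

    import numpy as np

    def adamw_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
                   eps=1e-8, coeff=0.01):
        # First and second moment estimates (Moment1/Moment2 in the op).
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad ** 2
        # Bias correction, tracked via Beta1Pow/Beta2Pow in the op.
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        # Decoupled weight decay: coeff * param acts on the parameter
        # directly instead of being folded into the gradient as L2 would be.
        param = param - lr * (m_hat / (np.sqrt(v_hat) + eps) + coeff * param)
        return param, m, v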

2 files changed: +10 -3 lines changed


paddle/fluid/pybind/op_function_generator.cc

Lines changed: 8 additions & 1 deletion
@@ -71,6 +71,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
     {"adam",
      {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
       "Beta2Pow", "MasterParam"}},
+    {"adamw",
+     {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
+      "Beta2Pow", "MasterParam"}},
 };
 
 // NOTE(zhiqiu): Like op_ins_map.
@@ -110,6 +113,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"adam",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
+    {"adamw",
+     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
+      "MasterParamOut"}},
 };
 
 // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
@@ -129,7 +135,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
      "MasterParamOut"}},
     {"adamw",
-     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
+     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
+      "MasterParamOut"}},
     {"average_accumulates",
      {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
       "out_old_num_accumulates", "out_num_updates"}},

python/paddle/optimizer/adamw.py

Lines changed: 2 additions & 2 deletions
@@ -298,14 +298,14 @@ def _append_optimize_op(self, block, param_and_grad):
         _beta2 = self._beta2 if not isinstance(
             self._beta2, Variable) else self._beta2.numpy().item(0)
 
-        _, _, _, _, _, _ = _C_ops.adam(
+        _, _, _, _, _, _ = _C_ops.adamw(
             param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
             beta1_pow_acc, beta2_pow_acc, master_weight, param_and_grad[0],
             moment1, moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
             'epsilon', self._epsilon, 'lazy_mode', self._lazy_mode,
             'min_row_size_to_use_multithread', 1000, 'beta1', _beta1,
             'beta2', _beta2, 'coeff', self._coeff, 'multi_precision',
-            find_master)
+            find_master, "lr_ratio", lr_ratio_)
 
         return None