Skip to content

Commit c24a59c

Browse files
[Adam/AdamW] Update adamw.py (#1426)
1 parent c8372d5 commit c24a59c

File tree

7 files changed

+13
-13
lines changed

7 files changed

+13
-13
lines changed

backends/gcu/tests/unittests_legacy/test_merged_adam_op_gcu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def run_adam_op(
5858

5959
if not use_merged:
6060
for i in range(len(param_vars)):
61-
_, _, _, _, _, _ = _legacy_C_ops.adam(
61+
_, _, _, _, _, *_ = _legacy_C_ops.adam(
6262
param_vars[i],
6363
grad_vars[i],
6464
lr_vars[i],

backends/mlu/tests/unittests/test_adamw_op_mlu.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def adamw_wrapper(
8181
with_decay=True,
8282
lazy_mode=False,
8383
):
84-
_, _, _, _, _, _ = paddle._C_ops.adamw_(
84+
_, _, _, _, _, *_ = paddle._C_ops.adamw_(
8585
param,
8686
grad,
8787
lr,
@@ -375,7 +375,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad(
375375
ref_moment_2 = moment2.astype(paddle.float32)
376376

377377
# reference code
378-
_, _, _, _, _, _ = paddle._C_ops.adamw_(
378+
_, _, _, _, _, *_ = paddle._C_ops.adamw_(
379379
ref_param,
380380
main_grad,
381381
lr,
@@ -398,7 +398,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad(
398398
)
399399

400400
if use_main_grad:
401-
_, _, _, _, _, _ = paddle._C_ops.adamw_(
401+
_, _, _, _, _, *_ = paddle._C_ops.adamw_(
402402
param,
403403
main_grad,
404404
lr,
@@ -426,7 +426,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad(
426426
master_weight.numpy(), ref_param.numpy(), atol=1e-5
427427
)
428428
else:
429-
_, _, _, _, _, _ = paddle._C_ops.adamw_(
429+
_, _, _, _, _, *_ = paddle._C_ops.adamw_(
430430
param,
431431
grad,
432432
lr,
@@ -973,7 +973,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad(
973973
ref_moment_2 = moment2.astype(paddle.float32)
974974

975975
# reference code
976-
_, _, _, _, _, _ = paddle._C_ops.adamw_(
976+
_, _, _, _, _, *_ = paddle._C_ops.adamw_(
977977
ref_param,
978978
main_grad,
979979
lr,
@@ -996,7 +996,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad(
996996
)
997997

998998
if use_main_grad:
999-
_, _, _, _, _, _ = paddle._C_ops.adamw_(
999+
_, _, _, _, _, *_ = paddle._C_ops.adamw_(
10001000
param,
10011001
main_grad,
10021002
lr,
@@ -1024,7 +1024,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad(
10241024
master_weight.numpy(), ref_param.numpy(), atol=1e-4
10251025
)
10261026
else:
1027-
_, _, _, _, _, _ = paddle._C_ops.adamw_(
1027+
_, _, _, _, _, *_ = paddle._C_ops.adamw_(
10281028
param,
10291029
grad,
10301030
lr,

backends/mlu/tests/unittests/test_merged_adam_op_mlu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def run_adam_op(
5555

5656
if not use_merged:
5757
for i in range(len(param_vars)):
58-
_, _, _, _, _, _ = _C_ops.adamw_(
58+
_, _, _, _, _, *_ = _C_ops.adamw_(
5959
param_vars[i],
6060
grad_vars[i],
6161
lr_vars[i],

backends/sdaa/sdaa_ext/python/custom_parallel/Adam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def _append_optimize_op(self, block, param_and_grad):
292292
moment1_ = moment1
293293
moment2_ = moment2
294294

295-
_, _, _, _, _, _ = paddle._C_ops.adam_(
295+
_, _, _, _, _, *_ = paddle._C_ops.adam_(
296296
param_,
297297
grad_,
298298
lr,

backends/sdaa/sdaa_ext/python/custom_parallel/AdamW.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def _append_optimize_op(self, block, param_and_grad):
294294
moment1_ = moment1
295295
moment2_ = moment2
296296

297-
_, _, _, _, _, _ = paddle._C_ops.adamw_(
297+
_, _, _, _, _, *_ = paddle._C_ops.adamw_(
298298
param_,
299299
grad_,
300300
lr,

backends/sdaa/tests/unittests/test_adam_op_sdaa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def adam_wrapper(
3737
epsilon=1e-4,
3838
lazy_mode=False,
3939
):
40-
_, _, _, _, _, _ = paddle._C_ops.adam_(
40+
_, _, _, _, _, *_ = paddle._C_ops.adam_(
4141
param,
4242
grad,
4343
LearningRate,

backends/sdaa/tests/unittests/test_merged_adam_op_sdaa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def run_adam_op(
6969
if not use_merged:
7070
paddle.set_device("cpu")
7171
for i in range(len(param_vars)):
72-
_, _, _, _, _, _ = _legacy_C_ops.adam(
72+
_, _, _, _, _, *_ = _legacy_C_ops.adam(
7373
param_vars[i],
7474
grad_vars[i],
7575
lr_vars[i],

0 commit comments

Comments (0)