Commit ca631ff

lookup table bug fix about lr, test=release/1.0.0 (#13946)
1 parent ceea195 commit ca631ff

2 files changed: +9 -4 lines changed


python/paddle/fluid/framework.py

Lines changed: 6 additions & 2 deletions
@@ -1522,13 +1522,17 @@ def _lr_schedule_guard(self):
             >>> with program.lr_schedule_guard():
             >>>     lr = lr * decay
         """
+
+        tmp_role = self._current_role
+        tmp_var = self._op_role_var
+
         OpRole = core.op_proto_and_checker_maker.OpRole
         self._current_role = OpRole.LRSched
         # TODO(typhoonzero): how to set target learning rate var
         self._op_role_var = []
         yield
-        self._op_role_var = []
-        self._current_role = OpRole.Forward
+        self._op_role_var = tmp_var
+        self._current_role = tmp_role
 
     def __str__(self):
         """

python/paddle/fluid/optimizer.py

Lines changed: 3 additions & 2 deletions
@@ -15,7 +15,7 @@
 from __future__ import print_function
 import re
 from collections import defaultdict
-from paddle.fluid.framework import Program, Variable, name_scope
+from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
 from . import framework
 from . import layers
 from .backward import append_backward
@@ -111,7 +111,8 @@ def _create_param_lr(self, param_and_grad):
         if param_lr == 1.0:
             return self._global_learning_rate()
         else:
-            return self._global_learning_rate() * param_lr
+            with default_main_program()._lr_schedule_guard():
+                return self._global_learning_rate() * param_lr
 
     def _create_accumulators(self, block, parameters):
         """Create all accumulators needed by the parameters
