Skip to content

Commit 679a4c2

Browse files
Fix loss of learning rate variable in distillation when using lr decay. (#16471)
test=develop
1 parent 57dc3c1 commit 679a4c2

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from ..core.strategy import Strategy
16-
from ....framework import Program, program_guard
16+
from ....framework import Program, Variable, program_guard
1717
from .... import Executor
1818
import logging
1919

@@ -74,8 +74,17 @@ def _create_distillation_graph(self, context):
7474
startup_program = Program()
7575
with program_guard(graph.program, startup_program):
7676
context.distiller_optimizer._name = 'distillation_optimizer'
77-
context.distiller_optimizer.minimize(
78-
graph.var(graph.out_nodes['loss'])._var)
77+
78+
# The learning rate variable may be created in other program.
79+
# Update information in optimizer to make
80+
# learning rate variable being accessible in current program.
81+
optimizer = context.distiller_optimizer
82+
if isinstance(optimizer._learning_rate, Variable):
83+
optimizer._learning_rate_map[
84+
graph.program] = optimizer._learning_rate
85+
86+
optimizer.minimize(graph.var(graph.out_nodes['loss'])._var)
87+
7988
exe = Executor(context.place)
8089
exe.run(startup_program, scope=context.scope)
8190

python/paddle/fluid/contrib/slim/graph/graph_wrapper.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,12 @@ def get_optimize_graph(self, optimizer, place, scope, no_grad_var_names=[]):
402402
elif 'cost' in graph.out_nodes:
403403
target_name = graph.out_nodes['cost']
404404
target = graph.var(target_name)._var
405+
# The learning rate variable may be created in other program.
406+
# Update information in optimizer to make
407+
# learning rate variable being accessible in current program.
408+
if isinstance(optimizer._learning_rate, Variable):
409+
optimizer._learning_rate_map[
410+
graph.program] = optimizer._learning_rate
405411
optimizer.minimize(target, no_grad_set=no_grad_var_names)
406412

407413
exe = Executor(place)

python/paddle/fluid/contrib/slim/tests/test_distillation_strategy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ def test_compression(self):
4141

4242
cost = fluid.layers.cross_entropy(input=out, label=label)
4343
avg_cost = fluid.layers.mean(x=cost)
44+
4445
optimizer = fluid.optimizer.Momentum(
4546
momentum=0.9,
46-
learning_rate=0.01,
47+
learning_rate=fluid.layers.piecewise_decay(
48+
boundaries=[5, 10], values=[0.01, 0.001, 0.0001]),
4749
regularization=fluid.regularizer.L2Decay(4e-5))
4850

4951
place = fluid.CUDAPlace(0)

0 commit comments

Comments
 (0)