Commit 6bf75c1

cherry-pick,fix optimizer.state_dict and LRScheduler.state_dict to save/load dygraph (#25447)
1 parent 316afbb commit 6bf75c1

File tree

5 files changed: +267 -85 lines

python/paddle/fluid/dygraph/checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -78,9 +78,9 @@ def save_dygraph(state_dict, model_path):
     for k, v in state_dict.items():
         if isinstance(v, (Variable, core.VarBase)):
             model_dict[k] = v.numpy()
+            name_table[k] = v.name
         else:
             model_dict[k] = v
-            name_table[k] = v.name
     model_dict["StructuredToParameterName@@"] = name_table

     file_name = model_path + suffix
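
The fix moves the `name_table` update into the `isinstance` branch, so `save_dygraph` records a structured-to-parameter name only for real `Variable`/`VarBase` entries, while plain Python values (such as the scalar counters an LR scheduler now contributes) are stored as-is instead of failing on a missing `.name`. A toy sketch of the corrected loop outside Paddle, with a stand-in class in place of `core.VarBase`:

import numpy as np

class FakeVar:
    # stand-in for Variable / core.VarBase: has .name and .numpy()
    def __init__(self, name, value):
        self.name, self._value = name, value
    def numpy(self):
        return np.asarray(self._value)

state_dict = {"fc.weight": FakeVar("linear_0.w_0", [1.0, 2.0]),
              "epoch_num": 3}          # plain value, e.g. a scheduler counter

model_dict, name_table = {}, {}
for k, v in state_dict.items():
    if isinstance(v, FakeVar):          # `(Variable, core.VarBase)` in the real code
        model_dict[k] = v.numpy()
        name_table[k] = v.name          # names recorded only for real variables
    else:
        model_dict[k] = v               # plain values no longer touch `v.name`
model_dict["StructuredToParameterName@@"] = name_table
# name_table == {"fc.weight": "linear_0.w_0"}; "epoch_num" is kept as a plain int.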

python/paddle/fluid/dygraph/learning_rate_scheduler.py

Lines changed: 76 additions & 15 deletions
@@ -15,6 +15,7 @@
 from __future__ import print_function

 import math
+import warnings

 from .. import unique_name
 from ..framework import Variable
@@ -66,6 +67,51 @@ def create_lr_var(self, lr):
             persistable=False)
         return lr

+    def state_dict(self):
+        """
+        Returns the state of the scheduler as a :class:`dict`.
+
+        It is a subset of self.__dict__ .
+        """
+        self._state_keys()
+        state_dict = {}
+        for key in self.keys:
+            if key not in self.__dict__:
+                continue
+            value = self.__dict__[key]
+            if isinstance(value, Variable):
+                assert value.shape == [
+                    1
+                ], "shape of Variable in state_dict must be [1] {}".format(
+                    value.shape)
+                value = value.numpy()[0]
+            state_dict[key] = value
+
+        return state_dict
+
+    def _state_keys(self):
+        """
+        set the keys in self.__dict__ that are needed to be saved.
+        """
+        self.keys = ['step_num']
+
+    def set_dict(self, state_dict):
+        """
+        Loads the schedulers state.
+        """
+        self._state_keys()
+        for key in self.keys:
+            if key in state_dict:
+                self.__dict__[key] = state_dict[key]
+            else:
+                raise RuntimeError(
+                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
+                    format(key))
+        if len(state_dict) > len(self.keys):
+            warnings.warn(
+                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
+            )
+
     def step(self):
         raise NotImplementedError()

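The new `state_dict()`/`set_dict()` pair makes any `LearningRateDecay` subclass checkpointable by value: `_state_keys()` lists the attributes to keep (just `step_num` by default), `Variable` values are unwrapped to scalars on save, and missing or extra keys raise or warn on load. A minimal round-trip sketch, assuming the Paddle 1.8-era `fluid.dygraph` API; the `step_num` poke is only there to simulate a scheduler that has already run:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    # step-based scheduler; its default state dict keeps only 'step_num'
    scheduler = fluid.dygraph.ExponentialDecay(
        learning_rate=0.1, decay_steps=1000, decay_rate=0.9)
    scheduler.step_num = 50                 # pretend 50 optimizer steps have run

    state = scheduler.state_dict()          # {'step_num': 50}

    resumed = fluid.dygraph.ExponentialDecay(
        learning_rate=0.1, decay_steps=1000, decay_rate=0.9)
    resumed.set_dict(state)                 # continues from step 50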

@@ -402,7 +448,7 @@ class PolynomialDecay(LearningRateDecay):
         learning_rate(Variable|float): The initial learning rate. If the type
             is Variable, it's a tensor with shape [1], the data type can be
             float32 or float64. It also can be set to python int number.
-        decay_steps(int32): The decay step size. It determines the decay cycle.
+        decay_steps(int): The decay step size. It determines the decay cycle.
         end_learning_rate(float, optional): The minimum final learning rate. The default value is 0.0001.
         power(float, optional): Power of polynomial. The default value is 1.0.
         cycle(bool, optional): If set true, decay the learning rate every decay_steps. The default value is False.
@@ -784,7 +830,7 @@ def __init__(self,
             raise ValueError(
                 'new_lr = origin_lr * decay_rate and decay_rate should be < 1.0.'
             )
-        self.decay_rate = decay_rate
+        self.decay_rate = self.create_lr_var(decay_rate)

         threshold_mode = threshold_mode.lower()
         if threshold_mode not in ['rel', 'abs']:
@@ -793,8 +839,10 @@
         self.threshold_mode = threshold_mode
         check_type(learning_rate, 'learning_rate', (float, int, Variable),
                    'ReduceLROnPlateau')
-        if isinstance(learning_rate, (float, int)):
-            learning_rate = self.create_lr_var(learning_rate)
+        if not isinstance(learning_rate, (float, int, Variable)):
+            raise TypeError(
+                "The type of 'learning_rate' in 'ReduceLROnPlateau' must be 'float, int, Variable', but received %s."
+                % type(learning_rate))

         self.learning_rate = learning_rate
         self.verbose = verbose
@@ -808,9 +856,17 @@ def __init__(self,
         self.cooldown_counter = 0
         self.best_loss = None
         self.num_bad_epochs = 0
-        self.epoch = 0
+        self.epoch_num = 0
+
+    def _state_keys(self):
+        self.keys = [
+            'cooldown_counter', 'best_loss', 'num_bad_epochs', 'epoch_num',
+            'learning_rate'
+        ]

     def __call__(self):
+        if not isinstance(self.learning_rate, Variable):
+            self.learning_rate = self.create_lr_var(self.learning_rate)
         return self.learning_rate

     def step(self, loss):
@@ -836,7 +892,7 @@ def step(self, loss):
             "should be (1L,), but the current loss.shape is {}. Maybe that " \
             "you should call fluid.layers.mean to process it first.".format(loss.shape)

-        self.epoch += 1
+        self.epoch_num += 1
         if self.cooldown_counter > 0:
             self.cooldown_counter -= 1
         else:
@@ -854,10 +910,11 @@
                                                 self.decay_rate, self.min_lr)
                 if self.learning_rate - new_lr > self.eps:
                     if self.verbose:
+                        old_lr = self.learning_rate.numpy()[0] if isinstance(
+                            self.learning_rate,
+                            Variable) else self.learning_rate
                         print('Epoch {}: reducing learning rate from {} to {}.'.
-                              format(self.epoch,
-                                     self.learning_rate.numpy()[0],
-                                     new_lr.numpy()[0]))
+                              format(self.epoch_num, old_lr, new_lr.numpy()[0]))
                     self.learning_rate = new_lr

     def _is_better(self, current, best):
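
For `ReduceLROnPlateau`, `_state_keys()` widens the saved state to the plateau-tracking counters plus the current learning rate, and the renamed `epoch_num` together with the `old_lr` guard keeps `verbose` logging working whether the rate is still a Python float or already a `Variable`. A small sketch of inspecting that state, again assuming the 1.8-era `fluid.dygraph` API:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    scheduler = fluid.dygraph.ReduceLROnPlateau(
        learning_rate=0.1, decay_rate=0.5, patience=2, verbose=True)

    for _ in range(3):                       # feed a constant (non-improving) loss
        loss = fluid.dygraph.to_variable(np.array([1.0], dtype='float32'))
        scheduler.step(loss)

    state = scheduler.state_dict()
    print(sorted(state.keys()))
    # ['best_loss', 'cooldown_counter', 'epoch_num', 'learning_rate', 'num_bad_epochs']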
@@ -890,22 +947,28 @@ def __init__(self, learning_rate, dtype=None):
             raise TypeError(
                 "The type of 'learning_rate' must be 'float, int', but received %s."
                 % type(learning_rate))
-        if learning_rate >= 1.0:
-            raise ValueError("The initial learning rate")
+        if learning_rate < 0:
+            raise ValueError("Invalid learning rate: {}".format(learning_rate))

         self.base_lr = float(learning_rate)

         self.epoch_num = -1
+        self.dtype = dtype
         if dtype is None:
             self.dtype = "float32"
         self.learning_rate = self.create_lr_var(self.base_lr)

         self.epoch()

+    def _state_keys(self):
+        self.keys = ['epoch_num', 'learning_rate']
+
     def __call__(self):
         """
         Return last computed learning rate on current epoch.
         """
+        if not isinstance(self.learning_rate, Variable):
+            self.learning_rate = self.create_lr_var(self.learning_rate)
         return self.learning_rate

     def epoch(self, epoch=None):
@@ -918,8 +981,6 @@ def epoch(self, epoch=None):
             self.epoch_num = epoch

         self.learning_rate = self.get_lr()
-        if isinstance(self.learning_rate, float):
-            self.learning_rate = self.create_lr_var(self.learning_rate)

     def get_lr(self):
         raise NotImplementedError
@@ -946,7 +1007,7 @@ class StepDecay(_LearningRateEpochDecay):

     Parameters:
         learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
-        step_size (int): Period of learning rate decay..
+        step_size (int): Period of learning rate decay.
         decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
             It should be less than 1.0. Default: 0.1.

@@ -1024,7 +1085,7 @@ class MultiStepDecay(_LearningRateEpochDecay):
         learning_rate = 0.005

     Parameters:
-        learning_rate (float|int): The initial learning rate. It can be set to python float or int number. If it
+        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
         milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
         decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
             It should be less than 1.0. Default: 0.1.
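
The `_LearningRateEpochDecay` family (`StepDecay`, `MultiStepDecay`, ...) now keeps `learning_rate` as a plain float until it is actually consumed and saves only `['epoch_num', 'learning_rate']`; `epoch()` advances the counter and recomputes the rate. A resume sketch under the same 1.8-era API assumptions:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    scheduler = fluid.dygraph.StepDecay(
        learning_rate=0.5, step_size=3, decay_rate=0.1)

    for _ in range(7):
        scheduler.epoch()                # lr: 0.5 for epochs 0-2, 0.05 for 3-5, ...

    state = scheduler.state_dict()       # {'epoch_num': ..., 'learning_rate': ...}

    resumed = fluid.dygraph.StepDecay(
        learning_rate=0.5, step_size=3, decay_rate=0.1)
    resumed.set_dict(state)              # picks up the same epoch count and rate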

python/paddle/fluid/optimizer.py

Lines changed: 34 additions & 33 deletions
@@ -33,7 +33,7 @@
 from .regularizer import append_regularization_ops
 from .dygraph import base as imperative_base
 from .dygraph import no_grad
-from .dygraph.learning_rate_scheduler import LearningRateDecay
+from .dygraph.learning_rate_scheduler import LearningRateDecay, _LearningRateEpochDecay
 from paddle.fluid import core
 from paddle.fluid.layers import tensor
 from functools import reduce
@@ -148,17 +148,17 @@ def state_dict(self):
                 state_dict[var_tmp.name] = var_tmp
         # global step if use lr decay
         if isinstance(self._learning_rate, LearningRateDecay):
-            var_tmp = None
-            if framework.in_dygraph_mode():
+            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
+
+            if not isinstance(self._learning_rate, _LearningRateEpochDecay):
+                var_tmp = None
                 var_temp = framework._varbase_creator(
                     None, name='global_step', dtype='int32')
-            else:
-                var_temp = Variable(None, name='global_step', dtype='int32')

-            tensor.fill_constant(
-                [1], "int32", self._learning_rate.step_num, out=var_temp)
+                tensor.fill_constant(
+                    [1], "int32", self._learning_rate.step_num, out=var_temp)

-            state_dict['global_step'] = var_temp
+                state_dict['global_step'] = var_temp
         return state_dict

     @framework.dygraph_only
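
With this change `Optimizer.state_dict()` always nests the scheduler state under an "LR_Scheduler" key, and only the step-based `LearningRateDecay` schedulers still emit a separate `global_step` variable; epoch-based schedulers carry `epoch_num` inside "LR_Scheduler" instead. A sketch of inspecting the result, assuming the 1.8-era dygraph optimizers:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(10, 10)
    scheduler = fluid.dygraph.StepDecay(
        learning_rate=0.5, step_size=3, decay_rate=0.1)
    adam = fluid.optimizer.Adam(
        learning_rate=scheduler, parameter_list=linear.parameters())

    state = adam.state_dict()
    print(state["LR_Scheduler"])     # e.g. {'epoch_num': 0, 'learning_rate': 0.5}
    print("global_step" in state)    # False: epoch-based scheduler, no global step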
@@ -192,30 +192,28 @@ def set_dict(self, state_dict):
         '''

         if isinstance(self._learning_rate, LearningRateDecay):
-            assert 'global_step' in state_dict, \
-                    'Global step not in state dict, Dygraph use LearningRateDecay, global_step must in state_dict'
-            global_step = state_dict['global_step']
-
-            if isinstance(global_step, core.VarBase):
-                step_np = global_step
-                step_np = np.array(step_np.value().get_tensor())
-                assert step_np.shape == (1,), \
-                        "global step shape is (1,), the shape is {}".format( step_np.shape )
-
-                self._learning_rate.step_num = int(step_np[0])
-            elif isinstance(global_step, Variable):
-                step_np = global_step.numpy()
-                assert step_np.shape == (1,), \
-                    "global step shape is (1,), the shape is {}".format( step_np.shape )
-                self._learning_rate.step_num = step_np[0]
-            elif isinstance(global_step, np.ndarray):
-                assert global_step.shape == (1,), \
-                    "global step shape is (1,), the shape is {}".format( global_step.shape )
-                self._learning_rate.step_num = global_step[0]
-            else:
-                raise RuntimeError(
-                    "Type not supprt, value in state dict must be [VarBase, Variable, numpy], the type is ",
-                    type(global_step))
+            self._learning_rate.set_dict(state_dict["LR_Scheduler"])
+
+            if not isinstance(self._learning_rate, _LearningRateEpochDecay):
+                assert 'global_step' in state_dict, \
+                        'Global step not in state dict, Dygraph use LearningRateDecay, global_step must in state_dict'
+                global_step = state_dict['global_step']
+
+                if isinstance(global_step, Variable):
+                    step_np = global_step
+                    step_np = np.array(step_np.value().get_tensor())
+                    assert step_np.shape == (1,), \
+                            "global step shape is (1,), the shape is {}".format( step_np.shape )
+
+                    self._learning_rate.step_num = int(step_np[0])
+                elif isinstance(global_step, np.ndarray):
+                    assert global_step.shape == (1,), \
+                        "global step shape is (1,), the shape is {}".format( global_step.shape )
+                    self._learning_rate.step_num = global_step[0]
+                else:
+                    raise RuntimeError(
+                        "Type not supprt, value in state dict must be [VarBase, Variable, numpy], the type is ",
+                        type(global_step))

         self._accumulators_holder = state_dict
         for k, v in self._accumulators.items():
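
`set_dict()` is the mirror image: it first restores the scheduler from `state_dict["LR_Scheduler"]` and only then, for step-based schedulers, reads `global_step` back (as a `Variable`/`VarBase` or the numpy array that `save_dygraph` produced). A save/restore round-trip sketch, still assuming the 1.8-era `fluid.dygraph` API and an illustrative "paddle_dy" file prefix:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(10, 1)
    scheduler = fluid.dygraph.ExponentialDecay(
        learning_rate=0.1, decay_steps=100, decay_rate=0.9)
    sgd = fluid.optimizer.SGD(
        learning_rate=scheduler, parameter_list=linear.parameters())

    x = fluid.dygraph.to_variable(np.random.rand(4, 10).astype('float32'))
    loss = fluid.layers.reduce_mean(linear(x))
    loss.backward()
    sgd.minimize(loss)

    # optimizer state now nests the scheduler under "LR_Scheduler"
    fluid.dygraph.save_dygraph(linear.state_dict(), "paddle_dy")
    fluid.dygraph.save_dygraph(sgd.state_dict(), "paddle_dy")

    para_state, opti_state = fluid.dygraph.load_dygraph("paddle_dy")
    linear.set_dict(para_state)
    sgd.set_dict(opti_state)         # restores step_num via "global_step"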
@@ -346,11 +344,14 @@ def current_step_lr(self):

         """
         current_lr = self._global_learning_rate()
-        if current_lr:
+        if isinstance(current_lr, framework.Variable):
             return self._global_learning_rate().numpy()[0]

         if isinstance(self._learning_rate, float):
             return self._learning_rate
+        elif isinstance(self._learning_rate, _LearningRateEpochDecay):
+            step_lr = self._learning_rate()
+            return step_lr.numpy()[0]
         else:
             step_lr = self._learning_rate.step()
             if isinstance(step_lr, (float, int)):
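
`current_step_lr()` now reports the rate an epoch-based scheduler would hand out by calling it (which lazily materializes the stored float into a `Variable`) rather than stepping it. A sketch under the same API assumptions:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(10, 10)
    scheduler = fluid.dygraph.MultiStepDecay(
        learning_rate=0.2, milestones=[2, 4], decay_rate=0.1)
    adam = fluid.optimizer.Adam(
        learning_rate=scheduler, parameter_list=linear.parameters())

    print(adam.current_step_lr())    # 0.2 before the first milestone
    scheduler.epoch(3)               # jump past the first milestone
    print(adam.current_step_lr())    # 0.02 after one decay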
