
Commit e262125

[cherry pick]split minimize and add unscale_ for GradScaler (#35927)
1. Split GradScaler::minimize() into GradScaler::step() + GradScaler::update()
2. Add GradScaler::unscale_(optimizer)
1 parent 085eae2 commit e262125
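
In practice, a dygraph AMP training loop no longer has to end with a single `scaler.minimize(optimizer, scaled)` call; it can unscale, step, and update explicitly. The following is a minimal sketch of the new sequence, mirroring the docstring examples added below (it assumes a GPU build of Paddle that already contains this commit):

    # required: gpu
    import paddle

    model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    data = paddle.rand([10, 3, 32, 32])
    with paddle.amp.auto_cast():
        loss = paddle.mean(model(data))

    scaled = scaler.scale(loss)   # scale the loss
    scaled.backward()             # do backward
    scaler.unscale_(optimizer)    # optional: make gradients available in their true scale
    scaler.step(optimizer)        # update parameters (skipped if NAN/INF is found)
    scaler.update()               # update the loss scaling ratio
    optimizer.clear_grad()

    # scaler.minimize(optimizer, scaled) is still supported and is equivalent to
    # scaler.step(optimizer) followed by scaler.update().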

5 files changed: +260, -45 lines changed


python/paddle/amp/grad_scaler.py

Lines changed: 92 additions & 7 deletions
@@ -13,18 +13,28 @@
 # limitations under the License.

 from paddle.fluid.dygraph.amp import AmpScaler
+from paddle.fluid.dygraph.amp import OptimizerState
+from collections import defaultdict

 __all__ = []


+def _refresh_optimizer_state():
+    return {"state": OptimizerState.INIT}
+
+
 class GradScaler(AmpScaler):
     """
     GradScaler is used for Auto-Mixed-Precision training in dynamic graph mode.
     It controls the scaling of loss, helps avoiding numerical overflow.
-    The object of this class has two methods `scale()`, `minimize()`.
+    The object of this class has nineteen methods `scale()`, `unscale_()`, `minimize()`, `step()`, `update()` and `get`/`set` api of parameters.

     `scale()` is used to multiply the loss by a scale ratio.
-    `minimize()` is similar as `optimizer.minimize()`, performs parameters updating.
+    `unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio)
+    `minimize()` is similar as `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling, it equal to `step()` + `update()`.
+    `step()` is similar as `optimizer.step()`, which performs parameters updating.
+    `update` is used to update the loss_scaling.
+

     Commonly, it is used together with `paddle.amp.auto_cast` to achieve Auto-Mixed-Precision in
     dynamic graph mode.
@@ -115,7 +125,7 @@ def minimize(self, optimizer, *args, **kwargs):
         This function is similar as `optimizer.minimize()`, which performs parameters updating.

         If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped.
-        Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
+        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters.

         Finally, the loss scaling ratio is updated.
@@ -151,16 +161,18 @@ def step(self, optimizer):
         This function is similar as `optimizer.step()`, which performs parameters updating.

         If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped.
-        Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
+        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters.

         Args:
             optimizer(Optimizer): The optimizer used to update parameters.

         Examples:
+
             .. code-block:: python

                 # required: gpu
                 import paddle
+
                 model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                 optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                 scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
@@ -170,24 +182,97 @@ def step(self, optimizer):
                     loss = paddle.mean(conv)
                 scaled = scaler.scale(loss)  # scale the loss
                 scaled.backward()            # do backward
-                scaler.step(optimizer)
+                scaler.step(optimizer)       # update parameters
+                scaler.update()              # update the loss scaling ratio
                 optimizer.clear_grad()
         """
         if not self._enable:
             return optimizer.step()

+        optimizer_state = self._optimizer_states[id(optimizer)]
+        if optimizer_state["state"] is OptimizerState.STEPPED:
+            raise RuntimeError(
+                "step() has already been called since the last update().")
+
         # unscale the grad
-        self._unscale(optimizer)
+        if optimizer_state["state"] is OptimizerState.INIT:
+            self._unscale(optimizer)

         if self._found_inf:
             self._cache_founf_inf = True
         else:
             optimizer.step()
             self._cache_founf_inf = False

+        optimizer_state["state"] = OptimizerState.STEPPED
+
+        if not self._use_dynamic_loss_scaling:
+            self._optimizer_states = defaultdict(_refresh_optimizer_state)
+
+    def update(self):
+        """
+        Updates the loss_scaling.
+
+        Examples:
+
+            .. code-block:: python
+
+                # required: gpu
+                import paddle
+
+                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
+                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+                data = paddle.rand([10, 3, 32, 32])
+                with paddle.amp.auto_cast():
+                    conv = model(data)
+                    loss = paddle.mean(conv)
+                scaled = scaler.scale(loss)  # scale the loss
+                scaled.backward()            # do backward
+                scaler.step(optimizer)       # update parameters
+                scaler.update()              # update the loss scaling ratio
+                optimizer.clear_grad()
+        """
+        if not self._enable:
+            return
         if self._use_dynamic_loss_scaling:
-            # uopdate the scale
             self._update()
+            self._optimizer_states = defaultdict(_refresh_optimizer_state)
+        return
+
+    def unscale_(self, optimizer):
+        """
+        Unscale the gradients of parameters, multiplies the gradients of parameters by 1/(loss scaling ratio).
+        If this instance of :class:`GradScaler` is not enabled, output are returned unmodified.
+
+        Args:
+            optimizer(Optimizer): The optimizer used to update parameters.
+
+        Returns:
+            The unscaled parameters or original parameters.
+
+        Examples:
+
+            .. code-block:: python
+
+                # required: gpu
+                import paddle
+
+                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
+                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+                data = paddle.rand([10, 3, 32, 32])
+                with paddle.amp.auto_cast():
+                    conv = model(data)
+                    loss = paddle.mean(conv)
+                scaled = scaler.scale(loss)  # scale the loss
+                scaled.backward()            # do backward
+                scaler.unscale_(optimizer)   # unscale the parameter
+                scaler.step(optimizer)
+                scaler.update()
+                optimizer.clear_grad()
+        """
+        return super(GradScaler, self)._unscale(optimizer)

     def is_enable(self):
         """

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 1 addition & 0 deletions
@@ -329,6 +329,7 @@ def _broadcast_final_loss(self):
     def _optimizer_step(self):
         if self.scaler:
             self.scaler.step(self.optimizer)
+            self.scaler.update()
         else:
             self.optimizer.step()

python/paddle/fluid/dygraph/amp/loss_scaler.py

Lines changed: 77 additions & 37 deletions
@@ -21,8 +21,20 @@
 import warnings
 import numpy as np
 from paddle import _C_ops
+from collections import defaultdict
+from enum import Enum

-__all__ = ['AmpScaler']
+__all__ = ['AmpScaler', 'OptimizerState']
+
+
+class OptimizerState(Enum):
+    INIT = 0
+    UNSCALED = 1
+    STEPPED = 2
+
+
+def _refresh_optimizer_state():
+    return {"state": OptimizerState.INIT}


 class AmpScaler(object):
@@ -31,10 +43,11 @@ class AmpScaler(object):

     AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative
     mode. It controls the scaling of loss, helps avoiding numerical overflow.
-    The object of this class has two methods `scale()`, `minimize()`.
+    The object of this class has seventeen methods `scale()`, `unscale_()`, `minimize()` and `get`/`set` api of parameters.

     `scale()` is used to multiply the loss by a scale ratio.
-    `minimize()` is similar as `Optimizer.minimize()`, performs parameters updating.
+    `unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio)
+    `minimize()` is similar as `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling.

     Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in
     imperative mode.
@@ -117,6 +130,7 @@ def __init__(self,
             self._scale = to_variable(
                 np.array([self._init_loss_scaling]).astype(np.float32))
             self._cache_founf_inf = None
+            self._optimizer_states = defaultdict(_refresh_optimizer_state)

     def scale(self, var):
         """
@@ -129,24 +143,25 @@ def scale(self, var):
             The scaled variable or original variable.

         Examples:
+
             .. code-block:: python

-            import numpy as np
-            import paddle.fluid as fluid
-
-            data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
-            with fluid.dygraph.guard():
-                model = fluid.dygraph.Conv2D(3, 2, 3)
-                optimizer = fluid.optimizer.SGDOptimizer(
-                    learning_rate=0.01, parameter_list=model.parameters())
-                scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
-                data = fluid.dygraph.to_variable(data)
-                with fluid.dygraph.amp_guard():
-                    conv = model(data)
-                    loss = fluid.layers.reduce_mean(conv)
-                    scaled = scaler.scale(loss)
-                    scaled.backward()
-                    scaler.minimize(optimizer, scaled)
+                import numpy as np
+                import paddle.fluid as fluid
+
+                data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
+                with fluid.dygraph.guard():
+                    model = fluid.dygraph.Conv2D(3, 2, 3)
+                    optimizer = fluid.optimizer.SGDOptimizer(
+                        learning_rate=0.01, parameter_list=model.parameters())
+                    scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
+                    data = fluid.dygraph.to_variable(data)
+                    with fluid.dygraph.amp_guard():
+                        conv = model(data)
+                        loss = fluid.layers.reduce_mean(conv)
+                        scaled = scaler.scale(loss)
+                        scaled.backward()
+                        scaler.minimize(optimizer, scaled)
         """
         check_type(var, "var", core.VarBase, 'AmpScaler.scale()')

@@ -160,7 +175,7 @@ def minimize(self, optimizer, *args, **kwargs):
         This function is similar as `Optimizer.minimize()`, which performs parameters updating.

         If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped.
-        Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
+        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters.

         Finally, the loss scaling ratio is updated.
@@ -170,30 +185,34 @@ def minimize(self, optimizer, *args, **kwargs):
             kwargs: Keyword arguments, which will be forward to `Optimizer.minimize()`.

         Examples:
+
             .. code-block:: python

-            import numpy as np
-            import paddle.fluid as fluid
-
-            data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
-            with fluid.dygraph.guard():
-                model = fluid.dygraph.Conv2D(3, 2, 3)
-                optimizer = fluid.optimizer.SGDOptimizer(
-                    learning_rate=0.01, parameter_list=model.parameters())
-                scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
-                data = fluid.dygraph.to_variable(data)
-                with fluid.dygraph.amp_guard():
-                    conv = model(data)
-                    loss = fluid.layers.reduce_mean(conv)
-                    scaled = scaler.scale(loss)
-                    scaled.backward()
-                    scaler.minimize(optimizer, scaled)
+                import numpy as np
+                import paddle.fluid as fluid
+
+                data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
+                with fluid.dygraph.guard():
+                    model = fluid.dygraph.Conv2D(3, 2, 3)
+                    optimizer = fluid.optimizer.SGDOptimizer(
+                        learning_rate=0.01, parameter_list=model.parameters())
+                    scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
+                    data = fluid.dygraph.to_variable(data)
+                    with fluid.dygraph.amp_guard():
+                        conv = model(data)
+                        loss = fluid.layers.reduce_mean(conv)
+                        scaled = scaler.scale(loss)
+                        scaled.backward()
+                        scaler.minimize(optimizer, scaled)
         """
         if not self._enable:
             return optimizer.minimize(*args, **kwargs)

+        optimizer_state = self._optimizer_states[id(optimizer)]
+
         # unscale the grad
-        self._unscale(optimizer)
+        if optimizer_state["state"] is OptimizerState.INIT:
+            self._unscale(optimizer)

         optimize_ops, params_grads = (None, None)

@@ -207,12 +226,31 @@ def minimize(self, optimizer, *args, **kwargs):
             # uopdate the scale
             self._update()

+        self._optimizer_states = defaultdict(_refresh_optimizer_state)
+
         return optimize_ops, params_grads

     def _unscale(self, optimizer):
+        """
+        Unscale the gradients of parameters, multiplies the gradients of parameters by 1/(loss scaling ratio).
+        If this instance of :class:`GradScaler` is not enabled, output are returned unmodified.
+        Args:
+            optimizer(Optimizer): The optimizer used to update parameters.
+        Returns:
+            The unscaled parameters or original parameters.
+        """
         if not self._enable:
             return

+        optimizer_state = self._optimizer_states[id(optimizer)]
+
+        if optimizer_state["state"] is OptimizerState.UNSCALED:
+            raise RuntimeError(
+                "unscale_() has already been called on this optimizer since the last update()."
+            )
+        elif optimizer_state["state"] is OptimizerState.STEPPED:
+            raise RuntimeError("unscale_() is being called after step().")
+
         if getattr(optimizer, '_param_groups', None) and isinstance(
                 optimizer._param_groups[0], dict):
             param_grads = []
@@ -256,6 +294,8 @@ def _unscale(self, optimizer):
                     temp_found_inf_fp32)
             self._found_inf = temp_found_inf_fp16 or temp_found_inf_fp32

+        optimizer_state["state"] = OptimizerState.UNSCALED
+
     def _update(self):
         """
         Updates the loss_scaling.
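
The bookkeeping that loss_scaler.py introduces above is independent of any GPU state: a defaultdict keyed by id(optimizer) whose entries move INIT -> UNSCALED -> STEPPED and are thrown away when the scale is updated. The stripped-down class below is a standalone illustration of just that state machine, not part of the commit (the name TinyScaler is made up), so the allowed call orders can be checked without Paddle:

    from collections import defaultdict
    from enum import Enum


    class OptimizerState(Enum):
        INIT = 0
        UNSCALED = 1
        STEPPED = 2


    def _refresh_optimizer_state():
        return {"state": OptimizerState.INIT}


    class TinyScaler:
        """Only the per-optimizer state tracking, none of the scaling math."""

        def __init__(self):
            self._optimizer_states = defaultdict(_refresh_optimizer_state)

        def unscale_(self, optimizer):
            state = self._optimizer_states[id(optimizer)]
            if state["state"] is OptimizerState.UNSCALED:
                raise RuntimeError("unscale_() has already been called since the last update().")
            if state["state"] is OptimizerState.STEPPED:
                raise RuntimeError("unscale_() is being called after step().")
            state["state"] = OptimizerState.UNSCALED

        def step(self, optimizer):
            state = self._optimizer_states[id(optimizer)]
            if state["state"] is OptimizerState.STEPPED:
                raise RuntimeError("step() has already been called since the last update().")
            state["state"] = OptimizerState.STEPPED

        def update(self):
            # Forget every per-optimizer state; the next iteration starts from INIT again.
            self._optimizer_states = defaultdict(_refresh_optimizer_state)


    opt = object()            # stands in for a real optimizer
    scaler = TinyScaler()
    scaler.unscale_(opt)      # INIT -> UNSCALED
    scaler.step(opt)          # UNSCALED -> STEPPED
    scaler.update()           # back to INIT
    scaler.step(opt)          # allowed again after update()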

python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ def train_batch(self, batch, model, optimizer, is_mp):
         scaled.backward()  # do backward

         scaler.step(optimizer)  # update parameters
+        scaler.update()
         optimizer.clear_grad()
         return scaled
