python/paddle/amp/grad_scaler.py
+92 −7 (92 additions, 7 deletions)
@@ -13,18 +13,28 @@
 # limitations under the License.
 
 from paddle.fluid.dygraph.amp import AmpScaler
+from paddle.fluid.dygraph.amp import OptimizerState
+from collections import defaultdict
 
 __all__ = []
 
 
+def _refresh_optimizer_state():
+    return {"state": OptimizerState.INIT}
+
+
 class GradScaler(AmpScaler):
     """
     GradScaler is used for Auto-Mixed-Precision training in dynamic graph mode.
     It controls the scaling of the loss and helps avoid numerical overflow.
-    The object of this class has two methods `scale()`, `minimize()`.
+    The object of this class has nineteen methods: `scale()`, `unscale_()`, `minimize()`, `step()`, `update()` and the `get`/`set` APIs of its parameters.
 
     `scale()` is used to multiply the loss by a scale ratio.
-    `minimize()` is similar as `optimizer.minimize()`, performs parameters updating.
+    `unscale_()` is used to unscale the gradients of parameters, i.e. it multiplies them by 1/(scale ratio).
+    `minimize()` is similar to `optimizer.minimize()`: it updates the parameters and the loss scaling, and is equivalent to `step()` + `update()`.
+    `step()` is similar to `optimizer.step()`: it only updates the parameters.
+    `update()` is used to update the loss scaling.
+
 
     Commonly, it is used together with `paddle.amp.auto_cast` to achieve Auto-Mixed-Precision in
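To make the new docstring concrete, here is a minimal sketch of the AMP training loop it describes, using the public `paddle.amp.GradScaler` and `paddle.amp.auto_cast` APIs that this file backs; the model, data, and hyperparameters are placeholders:

```python
import paddle

model = paddle.nn.Linear(10, 1)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([4, 10])
with paddle.amp.auto_cast():
    loss = model(data).mean()

scaled = scaler.scale(loss)  # multiply the loss by the scale ratio
scaled.backward()            # gradients are scaled by the same ratio
scaler.step(optimizer)       # update the parameters
scaler.update()              # adjust the loss scaling for the next iteration
optimizer.clear_grad()
```

As the docstring notes, `scaler.minimize(optimizer, scaled)` would collapse the `step()`/`update()` pair into a single call.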
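Back at the top of the hunk, the `OptimizerState`/`defaultdict` imports and the `_refresh_optimizer_state()` helper suggest per-optimizer bookkeeping, presumably so the scaler can tell whether `unscale_()` has already been applied before `step()`. A minimal sketch of that pattern, keying states by optimizer identity; the `UNSCALED` state and the `unscale_`/`step` bodies here are illustrative assumptions, not the file's actual implementation:

```python
from collections import defaultdict
from enum import Enum


class OptimizerState(Enum):
    # The diff only shows INIT; UNSCALED is assumed for illustration.
    INIT = 0
    UNSCALED = 1


def _refresh_optimizer_state():
    return {"state": OptimizerState.INIT}


# One state dict per optimizer, created lazily on first access.
optimizer_states = defaultdict(_refresh_optimizer_state)


def unscale_(optimizer, grads, loss_scaling):
    """Divide gradients by loss_scaling, at most once between two step() calls."""
    state = optimizer_states[id(optimizer)]
    if state["state"] is OptimizerState.UNSCALED:
        raise RuntimeError("unscale_() already called since the last step().")
    grads[:] = [g / loss_scaling for g in grads]
    state["state"] = OptimizerState.UNSCALED


def step(optimizer):
    """After the parameter update, reset the optimizer's bookkeeping."""
    # ... the actual parameter update would happen here ...
    optimizer_states[id(optimizer)] = _refresh_optimizer_state()


# Usage: a stand-in optimizer and plain-float gradients.
opt = object()
grads = [512.0, 2048.0]
unscale_(opt, grads, loss_scaling=1024.0)
step(opt)
```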