Commit 8b68f23

Merge pull request #37 from kozistr/feature/fp16
[Feature] Support FP16 for all optimizers by utilizing wrapper class
2 parents 2d4c3c6 + 333608f commit 8b68f23

5 files changed (+363, -9 lines)

README.rst

Lines changed: 13 additions & 7 deletions
@@ -27,18 +27,24 @@ Simple Usage

 ::

-    from pytorch_optimizer import Ranger21
+    from pytorch_optimizer import AdamP

     ...
     model = YourModel()
-    optimizer = Ranger21(model.parameters())
+    optimizer = AdamP(model.parameters())
     ...

-    for input, output in data:
-        optimizer.zero_grad()
-        loss = loss_function(output, model(input))
-        loss.backward()
-        optimizer.step()
+or you can use the optimizer loader, simply passing the name of the optimizer.
+
+::
+
+    from pytorch_optimizer import load_optimizers
+
+    ...
+    model = YourModel()
+    opt = load_optimizers(optimizer='adamp', use_fp16=True)
+    optimizer = opt(model.parameters())
+    ...

 Supported Optimizers
 --------------------

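For orientation, below is a minimal sketch of a full FP16 training step driven through the wrapper added in this pull request; it is not part of the diff. It assumes a model already converted to half precision, and YourModel, data, loss_function and the max_norm value of 5.0 are placeholders in the same spirit as the README example.

    from pytorch_optimizer import AdamP, SafeFP16Optimizer

    model = YourModel().half()                      # FP16 weights; FP32 master copies live in the wrapper
    optimizer = SafeFP16Optimizer(AdamP(model.parameters()))

    for input, target in data:
        optimizer.zero_grad()
        loss = loss_function(model(input), target)
        optimizer.backward(loss)                    # scales the loss before calling backward()
        optimizer.clip_main_grads(5.0)              # syncs FP16 grads to FP32 and updates the loss scale
        optimizer.step()                            # steps on FP32 params, then copies them back into the FP16 model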
pytorch_optimizer/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -7,15 +7,17 @@
 from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
 from pytorch_optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.diffrgrad import DiffRGrad
+from pytorch_optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.lookahead import Lookahead
 from pytorch_optimizer.madgrad import MADGRAD
+from pytorch_optimizer.optimizers import load_optimizers
 from pytorch_optimizer.pcgrad import PCGrad
 from pytorch_optimizer.radam import RAdam
 from pytorch_optimizer.ranger import Ranger
 from pytorch_optimizer.ranger21 import Ranger21
 from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP
-from pytorch_optimizer.utils import get_optimizer_parameters, normalize_gradient, unit_norm
+from pytorch_optimizer.utils import clip_grad_norm, get_optimizer_parameters, normalize_gradient, unit_norm

-__VERSION__ = '0.1.1'
+__VERSION__ = '0.2.0'

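Since the new classes and the loader are re-exported from the package root, a quick sanity check after upgrading could look like this (a sketch, assuming version 0.2.0 of the package is installed):

    import pytorch_optimizer
    from pytorch_optimizer import DynamicLossScaler, SafeFP16Optimizer, load_optimizers

    print(pytorch_optimizer.__VERSION__)  # expected: '0.2.0'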
pytorch_optimizer/fp16.py

Lines changed: 257 additions & 0 deletions
from typing import Dict, Optional

import torch
from torch.optim import Optimizer

from pytorch_optimizer.types import CLOSURE
from pytorch_optimizer.utils import clip_grad_norm, has_overflow

__AUTHOR__ = 'Facebook'
__REFERENCE__ = 'https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/fp16.py'


class DynamicLossScaler:
    """Dynamically adjusts the loss scaling factor.
    Dynamic loss scalers are important in mixed-precision training.
    They help us avoid underflows and overflows in low-precision gradients.

    See here for information:
    <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling>

    Shamelessly stolen and adapted from FairSeq.
    <https://github.com/pytorch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py>
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 15,
        scale_factor: float = 2.0,
        scale_window: int = 2000,
        tolerance: float = 0.00,
        threshold: Optional[float] = None,
    ):
        """
        :param init_scale: Initial loss scale.
        :param scale_factor: Factor by which to increase or decrease the loss scale.
        :param scale_window: If we do not experience overflow in scale_window iterations,
            the loss scale will increase by scale_factor.
        :param tolerance: Pct of iterations that have overflowed, after which we must decrease the loss scale.
        :param threshold: If not None, the loss scale will not decrease below this threshold.
        """
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold

        self.iter: int = 0
        self.last_overflow_iter: int = -1
        self.last_rescale_iter: int = -1
        self.overflows_since_rescale: int = 0

    def update_scale(self, overflow: bool):
        """Update the loss scale.
        If overflow exceeds our tolerance, we decrease the loss scale. If the number of
        iterations since the last overflow exceeds the scale window, we increase the loss scale.
        """
        iter_since_rescale: int = self.iter - self.last_rescale_iter

        if overflow:
            # calculate how often we overflowed already
            self.last_overflow_iter = self.iter
            self.overflows_since_rescale += 1

            pct_overflow: float = self.overflows_since_rescale / float(iter_since_rescale)
            if pct_overflow >= self.tolerance:
                # decrease loss scale by the scale factor
                self.decrease_loss_scale()

                # reset iterations
                self.last_rescale_iter = self.iter
                self.overflows_since_rescale = 0
        elif (self.iter - self.last_overflow_iter) % self.scale_window == 0:
            # increase the loss scale by the scale factor
            self.loss_scale *= self.scale_factor
            self.last_rescale_iter = self.iter

        self.iter += 1

    def decrease_loss_scale(self):
        """Decrease the loss scale by self.scale_factor.
        NOTE: the loss_scale will not go below self.threshold.
        """
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            self.loss_scale = max(self.loss_scale, self.threshold)


class SafeFP16Optimizer(Optimizer):
    def __init__(self, optimizer, aggregate_gnorms: bool = False):
        self.optimizer = optimizer
        self.aggregate_gnorms = aggregate_gnorms

        self.fp16_params = self.get_parameters(optimizer)
        self.fp32_params = self.build_fp32_params(self.fp16_params, flatten=False)

        # we want the optimizer to be tracking the fp32 parameters
        if len(optimizer.param_groups) != 1:
            # future implementers: this should hopefully be a matter of just
            # iterating through the param groups and keeping track of the pointer
            # through the fp32_params
            raise NotImplementedError('[-] Need to implement the parameter group transfer.')

        optimizer.param_groups[0]['params'] = self.fp32_params

        self.scaler: DynamicLossScaler = DynamicLossScaler(2.0 ** 15)
        self.min_loss_scale: float = 2 ** -5
        self.needs_sync: bool = True

    @classmethod
    def get_parameters(cls, optimizer: Optimizer):
        params = []
        for pg in optimizer.param_groups:
            params += list(pg['params'])
        return params

    @classmethod
    def build_fp32_params(cls, parameters, flatten: bool = True):
        # create an FP32 copy of parameters and grads
        if flatten:
            total_param_size = sum(p.data.numel() for p in parameters)
            fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=parameters[0].device)

            offset: int = 0
            for p in parameters:
                numel = p.data.numel()
                fp32_params[offset : offset + numel].copy_(p.data.view(-1))
                offset += numel

            fp32_params = torch.nn.Parameter(fp32_params)
            fp32_params.grad = fp32_params.data.new(total_param_size)
            return fp32_params

        fp32_params = []
        for p in parameters:
            p32 = torch.nn.Parameter(p.data.float())
            p32.grad = torch.zeros_like(p32.data)
            fp32_params.append(p32)

        return fp32_params

    def state_dict(self) -> Dict:
        """Return the optimizer's state dict."""
        state_dict = self.optimizer.state_dict()
        if self.scaler is not None:
            state_dict['loss_scaler'] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict: Dict):
        """Load an optimizer state dict.
        In general we should prefer the configuration of the existing optimizer instance
        (e.g., learning rate) over that found in the state_dict. This allows us to
        resume training from a checkpoint using a new set of optimizer args.
        """
        if 'loss_scaler' in state_dict and self.scaler is not None and isinstance(state_dict['loss_scaler'], float):
            self.scaler.loss_scale = state_dict['loss_scaler']
        self.optimizer.load_state_dict(state_dict)

    def backward(self, loss, update_main_grads: bool = False):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
        additionally dynamically scales the loss to avoid gradient underflow.
        """
        if self.scaler is not None:
            loss = loss * self.scaler.loss_scale

        loss.backward()

        self.needs_sync = True
        if update_main_grads:
            self.update_main_grads()

    def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
        if self.needs_sync:
            if self.scaler is not None:
                # correct for dynamic loss scaler
                multiply_grads /= self.scaler.loss_scale

            # copy FP16 grads to FP32
            for p, p32 in zip(self.fp16_params, self.fp32_params):
                if not p.requires_grad:
                    continue

                if p.grad is not None:
                    p32.grad.data.copy_(p.grad.data)
                    p32.grad.data.mul_(multiply_grads)
                else:
                    p32.grad = torch.zeros_like(p.data, dtype=torch.float)

            self.needs_sync = False

    def multiply_grads(self, c):
        """Multiplies grads by a constant c."""
        if self.needs_sync:
            self.sync_fp16_grads_to_fp32(c)
        else:
            for p32 in self.fp32_params:
                p32.grad.data.mul_(c)

    def update_main_grads(self):
        self.sync_fp16_grads_to_fp32()

    def clip_main_grads(self, max_norm):
        """Clips gradient norm and updates dynamic loss scaler."""
        self.sync_fp16_grads_to_fp32()

        grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_gnorms)

        # detect overflow and adjust loss scale
        if self.scaler is not None:
            overflow: bool = has_overflow(grad_norm)
            prev_scale = self.scaler.loss_scale
            self.scaler.update_scale(overflow)
            if overflow:
                self.zero_grad()
                if self.scaler.loss_scale <= self.min_loss_scale:
                    # Use FloatingPointError as an uncommon error that parent
                    # functions can safely catch to stop training.
                    self.scaler.loss_scale = prev_scale

                    raise FloatingPointError(
                        f'Minimum loss scale reached ({self.min_loss_scale}). Your loss is probably exploding. '
                        'Try lowering the learning rate, using gradient clipping or '
                        'increasing the batch size.\n'
                        f'Overflow: setting loss scale to {self.scaler.loss_scale}'
                    )

        return grad_norm

    def step(self, closure: CLOSURE = None):
        """Performs a single optimization step."""
        self.sync_fp16_grads_to_fp32()
        self.optimizer.step(closure)

        # copy FP32 params back into the FP16 model
        for p, p32 in zip(self.fp16_params, self.fp32_params):
            if not p.requires_grad:
                continue
            p.data.copy_(p32.data)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad = None
        for p32 in self.fp32_params:
            p32.grad.zero_()
        self.needs_sync = False

    def get_lr(self) -> float:
        return self.optimizer.get_lr()

    def set_lr(self, lr: float):
        self.optimizer.set_lr(lr)

    @property
    def loss_scale(self) -> float:
        """Convenience function which TorchAgent calls to get current scale value."""
        return self.scaler.loss_scale

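A short, hedged illustration of the scaler's behaviour as implemented above: with the default tolerance of 0.0 the scale is halved as soon as an overflow is reported, and it is doubled again once scale_window consecutive iterations pass without overflow. The scale_window value below is shortened purely for the example.

    from pytorch_optimizer import DynamicLossScaler

    scaler = DynamicLossScaler(init_scale=2.0 ** 15, scale_window=4)

    scaler.update_scale(overflow=True)      # overflow: 32768.0 -> 16384.0
    print(scaler.loss_scale)

    for _ in range(4):                      # four clean iterations complete one scale window
        scaler.update_scale(overflow=False)
    print(scaler.loss_scale)                # back to 32768.0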
pytorch_optimizer/optimizers.py

Lines changed: 48 additions & 0 deletions
from pytorch_optimizer.adabelief import AdaBelief
from pytorch_optimizer.adabound import AdaBound
from pytorch_optimizer.adahessian import AdaHessian
from pytorch_optimizer.adamp import AdamP
from pytorch_optimizer.diffgrad import DiffGrad
from pytorch_optimizer.diffrgrad import DiffRGrad
from pytorch_optimizer.fp16 import SafeFP16Optimizer
from pytorch_optimizer.madgrad import MADGRAD
from pytorch_optimizer.radam import RAdam
from pytorch_optimizer.ranger import Ranger
from pytorch_optimizer.ranger21 import Ranger21
from pytorch_optimizer.sgdp import SGDP


def load_optimizers(optimizer: str, use_fp16: bool = False):
    optimizer: str = optimizer.lower()

    if optimizer == 'adamp':
        opt = AdamP
    elif optimizer == 'ranger':
        opt = Ranger
    elif optimizer == 'ranger21':
        opt = Ranger21
    elif optimizer == 'sgdp':
        opt = SGDP
    elif optimizer == 'radam':
        opt = RAdam
    elif optimizer == 'adabelief':
        opt = AdaBelief
    elif optimizer == 'adabound':
        opt = AdaBound
    elif optimizer == 'madgrad':
        opt = MADGRAD
    elif optimizer == 'diffrgrad':
        opt = DiffRGrad
    elif optimizer == 'diffgrad':
        opt = DiffGrad
    elif optimizer == 'adahessian':
        opt = AdaHessian
    else:
        raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')

    if use_fp16:
        opt = SafeFP16Optimizer(opt)

    return opt

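As the updated README shows, load_optimizers resolves an optimizer class from its lower-cased name; with use_fp16=False it simply returns the class, which is then instantiated as usual, and with use_fp16=True the result is wrapped in SafeFP16Optimizer. A minimal sketch of the plain path (model is a placeholder):

    from pytorch_optimizer import load_optimizers

    opt_class = load_optimizers(optimizer='radam')  # returns the RAdam class
    optimizer = opt_class(model.parameters())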