Skip to content

Commit 23ceb55

Browse files
authored
Merge pull request #284 from kozistr/feature/cpu-offloading
[Feature] Implement CPUOffloadOptimizer
2 parents ed1d3e1 + df3c4ea commit 23ceb55

File tree

6 files changed

+163
-6
lines changed

6 files changed

+163
-6
lines changed

README.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ please check [the bnb requirements](https://github.com/TimDettmers/bitsandbytes?
3333

3434
From `v3.0.0`, `Python 3.7` support is dropped. However, you can still use this package with `Python 3.7` by installing with the `--ignore-requires-python` option.
3535

36-
```bash
37-
$ pip install "pytorch-optimizer[bitsandbytes]"
38-
```
39-
4036
### Simple Usage
4137

4238
```python

docs/changelogs/v3.2.0.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
* Support 8/4bit, fp8 optimizers. (#208, #281)
1010
* `torchao_adamw8bit`, `torchao_adamw4bit`, `torchao_adamwfp8`.
1111
* Support a module-name-level (e.g. `LayerNorm`) weight decay exclusion for `get_optimizer_parameters`. (#282, #283)
12+
* Implement `CPUOffloadOptimizer`, which offloads optimizer to CPU for single-GPU training. (#284)
1213

1314
### Bug
1415

docs/util.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Utilization
22

3+
::: pytorch_optimizer.optimizer.utils.CPUOffloadOptimizer
4+
:docstring:
5+
:members:
6+
37
::: pytorch_optimizer.optimizer.utils.get_optimizer_parameters
48
:docstring:
59
:members:

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
from pytorch_optimizer.optimizer.tiger import Tiger
122122
from pytorch_optimizer.optimizer.trac import TRAC
123123
from pytorch_optimizer.optimizer.utils import (
124+
CPUOffloadOptimizer,
124125
clip_grad_norm,
125126
disable_running_stats,
126127
enable_running_stats,

pytorch_optimizer/optimizer/utils.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import functools
12
import math
3+
import operator
4+
import re
25
import warnings
36
from importlib.util import find_spec
4-
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
7+
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
58

69
import numpy as np
710
import torch
@@ -11,7 +14,7 @@
1114
from torch.nn.modules.batchnorm import _BatchNorm
1215
from torch.nn.utils import clip_grad_norm_
1316

14-
from pytorch_optimizer.base.types import PARAMETERS
17+
from pytorch_optimizer.base.types import CLOSURE, LOSS, PARAMETERS
1518

1619
HAS_TRANSFORMERS: bool = find_spec('transformers') is not None
1720

@@ -36,6 +39,127 @@ def is_deepspeed_zero3_enabled() -> bool:
3639
return False
3740

3841

42+
def parse_pytorch_version(version_string: str) -> List[int]:
    r"""Parse a Pytorch version string into its numeric components.

    Accepts strings with a trailing local suffix (e.g. `2.5.1+cu121`); only the leading
    `major.minor.patch` triple is parsed.

    :param version_string: str. version string such as `torch.__version__`.
    :return: List[int]. `[major, minor, patch]`.
    :raises ValueError: when the string does not start with `major.minor.patch`.
    """
    matched = re.match(r'(\d+\.\d+\.\d+)', version_string)
    if matched is None:
        raise ValueError(f'invalid version string format: {version_string}')

    major, minor, patch = matched.group(1).split('.')
    return [int(major), int(minor), int(patch)]
49+
50+
51+
def compare_versions(v1: str, v2: str) -> int:
    r"""Compare two Pytorch version strings.

    Fixed return annotation: the expression `(a > b) - (a < b)` yields an int in {-1, 0, 1},
    not a bool, and -1 is truthy — callers must compare the result against 0, never use it
    directly as a boolean.

    :param v1: str. left-hand version string (e.g. `2.5.0`).
    :param v2: str. right-hand version string.
    :return: int. 1 if `v1` is newer than `v2`, -1 if older, 0 if equal.
    """
    v1_parts: List[int] = parse_pytorch_version(v1)
    v2_parts: List[int] = parse_pytorch_version(v2)

    # lists compare lexicographically, which matches semantic-version ordering here.
    return (v1_parts > v2_parts) - (v1_parts < v2_parts)
56+
57+
58+
# True when the installed Pytorch is 2.4.0 or newer. `compare_versions` returns -1/0/1 (an int,
# not a bool), so the result must be compared against 0 explicitly: the bare value is truthy even
# for older versions (-1), which would wrongly enable fused AdamW on Pytorch < 2.4.
TORCH_VERSION_AT_LEAST_2_4: bool = compare_versions(torch.__version__, '2.4.0') >= 0
59+
60+
61+
class CPUOffloadOptimizer: # pragma: no cover
    """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state.

    Reference: https://github.com/pytorch/ao/blob/main/torchao/prototype/low_bit_optim/cpu_offload.py

    :param params: PARAMETERS. a list of parameters or parameter groups.
    :param optimizer_class: Type[torch.optim.Optimizer]. constructor of the base optimizer. Defaults to
        :class:`torch.optim.AdamW`.
    :param offload_gradients: bool. free GPU gradients once they are moved to CPU. Not compatible with gradient
        accumulation.
    :param kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
    """

    def __init__(
        self,
        params: PARAMETERS,
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.AdamW,
        *,
        offload_gradients: bool = False,
        **kwargs,
    ) -> None:
        # default to fused AdamW on Pytorch >= 2.4 unless the caller chose explicitly
        # (presumably because fused AdamW gained CPU-tensor support in 2.4 — TODO confirm).
        if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and 'fused' not in kwargs:
            kwargs.update(fused=True)

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError('optimizer got an empty parameter list')
        if not isinstance(param_groups[0], dict):
            # a flat iterable of tensors was given; wrap it into a single default group.
            param_groups = [{'params': param_groups}]

        # CUDA parameter -> pinned CPU mirror (the mirror also owns the pinned CPU gradient buffer).
        self.param_cuda2cpu_map = {}
        # one base-optimizer instance per parameter, so each parameter can step independently.
        self.optim_dict = {}
        # dedicated stream for D2H gradient copies and H2D parameter copies.
        self.stream = torch.cuda.Stream()

        # CUDA parameter -> event recorded after its gradient's D2H copy; insertion order is step order.
        self.queue = {}

        def backward_hook(p_cuda: torch.Tensor) -> None:
            # fires once the gradient of `p_cuda` has been fully accumulated.
            if p_cuda.grad is None:
                return

            p_cpu = self.param_cuda2cpu_map[p_cuda]

            # the copy stream must wait for the compute stream to finish producing the gradient.
            self.stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.stream):
                p_cpu.grad.copy_(p_cuda.grad, non_blocking=True)

            # delete-then-insert moves the entry to the end, keeping the queue ordered by
            # most-recent copy submission.
            if p_cuda in self.queue:
                del self.queue[p_cuda]

            self.queue[p_cuda] = self.stream.record_event()

            if offload_gradients:
                # inform the caching allocator that the copy stream still reads this memory,
                # then release the GPU-side gradient to save memory.
                p_cuda.grad.record_stream(self.stream)
                p_cuda.grad = None

        for param_group in param_groups:
            params = param_group.pop('params')

            for p_cuda in params:
                # pinned host mirrors are required for truly asynchronous (non_blocking) transfers.
                p_cpu = torch.empty_like(p_cuda, device='cpu', pin_memory=True)
                p_cpu.grad = torch.empty_like(p_cpu, pin_memory=True)

                p_cpu.copy_(p_cuda.detach(), non_blocking=True)
                self.param_cuda2cpu_map[p_cuda] = p_cpu

                p_cuda.register_post_accumulate_grad_hook(backward_hook)
                # the base optimizer only ever sees the CPU mirror of this single parameter.
                self.optim_dict[p_cuda] = optimizer_class([{'params': p_cpu, **param_group}], **kwargs)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Step every per-parameter optimizer on CPU, then copy updated parameters back to GPU.

        :param closure: CLOSURE. a closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for p_cuda, grad_d2h_event in self.queue.items():
            # block the host until this parameter's gradient has landed in pinned CPU memory.
            grad_d2h_event.synchronize()
            self.optim_dict[p_cuda].step()

            p_cpu = self.param_cuda2cpu_map[p_cuda]
            # push the updated CPU parameter back to the GPU asynchronously on the copy stream.
            with torch.cuda.stream(self.stream):
                p_cuda.copy_(p_cpu, non_blocking=True)

        self.queue.clear()

        return loss

    def zero_grad(self, _: bool = True) -> None:
        r"""Clear GPU gradients. CPU gradient buffers are overwritten on the next backward, not cleared."""
        for p_cuda in self.param_cuda2cpu_map:
            p_cuda.grad = None

    @property
    def param_groups(self):
        # concatenate the (single-parameter) groups of every per-parameter base optimizer.
        return functools.reduce(operator.add, (optim.param_groups for optim in self.optim_dict.values()), [])

    def state_dict(self):
        # one state dict per base optimizer, in the insertion order of `self.optim_dict`.
        return [optim.state_dict() for optim in self.optim_dict.values()]

    def load_state_dict(self, state_dict):
        # NOTE(review): assumes `state_dict` is ordered exactly as `self.optim_dict` was built —
        # a checkpoint from a differently-constructed instance would be mis-assigned.
        for optim, optim_state_dict in zip(self.optim_dict.values(), state_dict):
            optim.load_state_dict(optim_state_dict)
162+
39163
def is_valid_parameters(parameters: PARAMETERS) -> bool:
    r"""Check whether the parameters are valid (a non-empty list/tuple of parameter-group dicts)."""
    return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict)

tests/test_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
merge_small_dims,
1313
)
1414
from pytorch_optimizer.optimizer.utils import (
15+
CPUOffloadOptimizer,
1516
clip_grad_norm,
17+
compare_versions,
1618
disable_running_stats,
1719
enable_running_stats,
1820
get_optimizer_parameters,
@@ -21,6 +23,7 @@
2123
neuron_mean,
2224
neuron_norm,
2325
normalize_gradient,
26+
parse_pytorch_version,
2427
reduce_max_except_dim,
2528
reg_noise,
2629
to_real,
@@ -231,3 +234,31 @@ def test_emcmc():
231234

232235
loss = reg_noise(network1, network2, int(5e4), 1e-1).detach().numpy()
233236
np.testing.assert_almost_equal(loss, 0.0011383)
237+
238+
239+
def test_version_utils():
    """Validate version parsing and comparison with version-independent assertions.

    The original asserted `parse_pytorch_version(torch.__version__) == [2, 5, 0]`, which
    breaks on every torch release other than 2.5.0. Parse fixed strings instead, and only
    check the shape of the real torch version.
    """
    with pytest.raises(ValueError):
        parse_pytorch_version('a.s.d.f')

    # fixed input: the expected value no longer depends on the installed torch version.
    assert parse_pytorch_version('2.5.0') == [2, 5, 0]

    # the installed version must still parse into a [major, minor, patch] triple.
    assert len(parse_pytorch_version(torch.__version__)) == 3

    # exercise all three comparison outcomes with strict assertions (`>= 0` was vacuous for 1).
    assert compare_versions('2.5.0', '2.4.0') > 0
    assert compare_versions('2.4.0', '2.5.0') < 0
    assert compare_versions('2.4.0', '2.4.0') == 0
246+
247+
248+
def test_cpu_offload_optimizer():
    """Smoke-test CPUOffloadOptimizer construction and its public surface on a GPU host."""
    if not torch.cuda.is_available():
        pytest.skip('need GPU to run a test')

    # an empty parameter list must be rejected at construction time.
    with pytest.raises(ValueError):
        CPUOffloadOptimizer([], torch.optim.AdamW)

    optimizer = CPUOffloadOptimizer(
        Example().parameters(), torch.optim.AdamW, fused=False, offload_gradients=True
    )

    optimizer.zero_grad()

    _ = optimizer.param_groups

    # state-dict round trip through the per-parameter base optimizers.
    optimizer.load_state_dict(optimizer.state_dict())

0 commit comments

Comments
 (0)