+import functools
 import math
+import operator
+import re
 import warnings
 from importlib.util import find_spec
-from typing import Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union

 import numpy as np
 import torch
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.utils import clip_grad_norm_

-from pytorch_optimizer.base.types import PARAMETERS
+from pytorch_optimizer.base.types import CLOSURE, LOSS, PARAMETERS

 HAS_TRANSFORMERS: bool = find_spec('transformers') is not None

@@ -36,6 +39,127 @@ def is_deepspeed_zero3_enabled() -> bool:
     return False


+def parse_pytorch_version(version_string: str) -> List[int]:
+    r"""Parse a PyTorch version string into its numeric components."""
+    match = re.match(r'(\d+\.\d+\.\d+)', version_string)
+    if not match:
+        raise ValueError(f'invalid version string format: {version_string}')
+
+    return [int(x) for x in match.group(1).split('.')]
+
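+# parse_pytorch_version keeps only the leading 'major.minor.patch' triple, e.g.
+# '2.4.1+cu121' is parsed as [2, 4, 1] and any local/build suffix is ignored.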
+
+def compare_versions(v1: str, v2: str) -> int:
+    r"""Compare two PyTorch versions: 1, 0 or -1 as `v1` is newer than, equal to or older than `v2`."""
+    v1_parts: List[int] = parse_pytorch_version(v1)
+    v2_parts: List[int] = parse_pytorch_version(v2)
+    return (v1_parts > v2_parts) - (v1_parts < v2_parts)
+
+
+TORCH_VERSION_AT_LEAST_2_4: bool = compare_versions(torch.__version__, '2.4.0') >= 0
+
+
+class CPUOffloadOptimizer:  # pragma: no cover
+    """Offload optimizer to CPU for single-GPU training. This reduces GPU memory by the size of the optimizer state.
+
+    Reference: https://github.com/pytorch/ao/blob/main/torchao/prototype/low_bit_optim/cpu_offload.py
+
+    :param params: PARAMETERS. a list of parameters or parameter groups.
+    :param optimizer_class: Type[torch.optim.Optimizer]. constructor of the base optimizer. Defaults to
+        :class:`torch.optim.AdamW`.
+    :param offload_gradients: bool. free GPU gradients once they are moved to CPU. Not compatible with gradient
+        accumulation.
+    :param kwargs: other keyword arguments to be passed to the base optimizer, e.g. `lr`, `weight_decay`.
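+
+    Example (illustrative usage sketch; a CUDA device is required)::
+
+        model = torch.nn.Linear(8, 2).cuda()
+        optimizer = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-3)
+
+        loss = model(torch.randn(4, 8, device='cuda')).sum()
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()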
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.AdamW,
+        *,
+        offload_gradients: bool = False,
+        **kwargs,
+    ) -> None:
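+        # fused Adam/AdamW CPU kernels are only available in newer PyTorch releases
+        # (2.4+), so default to the fused path for the stock AdamW base optimizer.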
+        if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and 'fused' not in kwargs:
+            kwargs.update(fused=True)
+
+        param_groups = list(params)
+        if len(param_groups) == 0:
+            raise ValueError('optimizer got an empty parameter list')
+        if not isinstance(param_groups[0], dict):
+            param_groups = [{'params': param_groups}]
+
+        self.param_cuda2cpu_map = {}
+        self.optim_dict = {}
+        self.stream = torch.cuda.Stream()
+
+        self.queue = {}
+
+        def backward_hook(p_cuda: torch.Tensor) -> None:
+            if p_cuda.grad is None:
+                return
+
+            p_cpu = self.param_cuda2cpu_map[p_cuda]
+
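+            # copy the freshly accumulated gradient to its pinned CPU mirror on the
+            # side stream and record an event so that step() can wait for the copy.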
+            self.stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(self.stream):
+                p_cpu.grad.copy_(p_cuda.grad, non_blocking=True)
+
+            if p_cuda in self.queue:
+                del self.queue[p_cuda]
+
+            self.queue[p_cuda] = self.stream.record_event()
+
+            if offload_gradients:
+                p_cuda.grad.record_stream(self.stream)
+                p_cuda.grad = None
+
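+        # mirror every CUDA parameter with a pinned CPU copy (parameter and gradient)
+        # and give each parameter its own CPU-resident instance of the base optimizer.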
+        for param_group in param_groups:
+            params = param_group.pop('params')
+
+            for p_cuda in params:
+                p_cpu = torch.empty_like(p_cuda, device='cpu', pin_memory=True)
+                p_cpu.grad = torch.empty_like(p_cpu, pin_memory=True)
+
+                p_cpu.copy_(p_cuda.detach(), non_blocking=True)
+                self.param_cuda2cpu_map[p_cuda] = p_cpu
+
+                p_cuda.register_post_accumulate_grad_hook(backward_hook)
+                self.optim_dict[p_cuda] = optimizer_class([{'params': p_cpu, **param_group}], **kwargs)
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss = None
+        if closure is not None:
+            loss = closure()
+
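+        # wait for each parameter's gradient copy to finish, run its CPU optimizer
+        # step, then push the updated values back to the GPU on the side stream.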
+        for p_cuda, grad_d2h_event in self.queue.items():
+            grad_d2h_event.synchronize()
+            self.optim_dict[p_cuda].step()
+
+            p_cpu = self.param_cuda2cpu_map[p_cuda]
+            with torch.cuda.stream(self.stream):
+                p_cuda.copy_(p_cpu, non_blocking=True)
+
+        self.queue.clear()
+
+        return loss
+
+    def zero_grad(self, _: bool = True) -> None:
+        for p_cuda in self.param_cuda2cpu_map:
+            p_cuda.grad = None
+
+    @property
+    def param_groups(self):
+        return functools.reduce(operator.add, (optim.param_groups for optim in self.optim_dict.values()), [])
+
+    def state_dict(self):
+        return [optim.state_dict() for optim in self.optim_dict.values()]
+
+    def load_state_dict(self, state_dict):
+        for optim, optim_state_dict in zip(self.optim_dict.values(), state_dict):
+            optim.load_state_dict(optim_state_dict)
+
+
 def is_valid_parameters(parameters: PARAMETERS) -> bool:
     r"""Check whether the parameters are valid."""
     return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict)