Commit 6b157e0 (parent b927e1f)

[Dev] Optimizer State and Master Weight Offloading (NVIDIA#2760)

Authored by hxbaiyaox12
Co-authored-by: Xin Yao <xiny@nvidia.com>

6 files changed: +725 −1 lines changed
megatron/core/optimizer/cpu_offloading/optimizer_state_offloader.py

Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Optimizer state offloading class."""

from typing import TYPE_CHECKING, Dict, List, Tuple

import torch

if TYPE_CHECKING:
    from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer


class OptimizerStateOffloader:
    """
    Manages offloading of optimizer states and master weights to CPU.
    Used with DistributedOptimizer to reduce GPU memory usage.

    Supports overlapped D2H/H2D transfers using CUDA streams.

    Master weights can be stored in two locations:
    - In the Adam optimizer state (when use_precision_aware_optimizer_no_fp8_or_ds_fp8 is True)
    - In mcore's shard_fp32_from_float16_groups
    """

    OPTIMIZER_STATE_KEYS = ('exp_avg', 'exp_avg_sq')
    MASTER_WEIGHT_KEY = 'master_param'

    def __init__(self, distrib_optimizer: "DistributedOptimizer"):
        """
        Args:
            distrib_optimizer: The DistributedOptimizer to offload states and master weights from.
        """
        self.dist_optimizer = distrib_optimizer
        self.adam_optimizer = distrib_optimizer.optimizer

        # Only the TE FusedAdam optimizer is supported for now.
        try:
            from transformer_engine.pytorch.optimizers import FusedAdam

            assert isinstance(self.adam_optimizer, FusedAdam), (
                f"OptimizerStateOffloader requires TE FusedAdam optimizer, "
                f"but got {type(self.adam_optimizer).__name__}"
            )
        except ImportError:
            raise ImportError(
                "OptimizerStateOffloader requires transformer_engine.pytorch.optimizers.FusedAdam"
            )

        # Check whether master weights are stored in the Adam optimizer state.
        self.optimizer_contains_master_weights = self.adam_optimizer.master_weights

        # CUDA streams for async transfers.
        self._d2h_stream = torch.cuda.Stream()
        self._h2d_stream = torch.cuda.Stream()

        # CPU buffers for optimizer states: {param: {key: cpu_tensor}}
        self._opt_state_cpu_buffers: Dict[torch.Tensor, Dict[str, torch.Tensor]] = {}

        # CPU buffers for mcore master weights, matching the structure of the source groups:
        # List[List[cpu_tensor]]
        self._shard_fp32_from_float16_cpu_buffers: List[List[torch.Tensor]] = []

        # State tracking
        self._offloaded = False
        self._offloaded_state_keys: Tuple[str, ...] = ()
        self._offloaded_mcore_master_weights = False

        # Track whether optimizer states (exp_avg, exp_avg_sq) have been initialized.
        # These are lazily initialized by FusedAdam during the first optimizer.step().
        # Master weights (shard_fp32_from_float16_groups) are available from the start.
        self._optimizer_states_initialized = False

    def mark_optimizer_states_initialized(self):
        """
        Mark that optimizer states (exp_avg, exp_avg_sq) are now available.
        Should be called after the first optimizer.step() completes.
        """
        self._optimizer_states_initialized = True

    def _get_state_keys_to_offload(
        self, offload_optimizer_states: bool, offload_master_weights: bool
    ) -> Tuple[str, ...]:
        """Get the state keys in FusedAdam to offload based on configuration."""
        keys = []
        # Skip offloading of optimizer states if they haven't been initialized yet.
        # Optimizer states are lazily initialized by FusedAdam during the first optimizer.step().
        if self._optimizer_states_initialized:
            if offload_optimizer_states:
                keys.extend(self.OPTIMIZER_STATE_KEYS)
        if offload_master_weights and self.optimizer_contains_master_weights:
            keys.append(self.MASTER_WEIGHT_KEY)
        return tuple(keys)

    def _ensure_state_cpu_buffer(
        self, param: torch.Tensor, state_key: str, gpu_tensor: torch.Tensor, pin_memory: bool = True
    ) -> torch.Tensor:
        """Get or create a CPU buffer for a state tensor."""
        if param not in self._opt_state_cpu_buffers:
            self._opt_state_cpu_buffers[param] = {}

        if state_key not in self._opt_state_cpu_buffers[param]:
            cpu_buffer = torch.empty(
                gpu_tensor.size(),
                dtype=gpu_tensor.dtype,
                layout=gpu_tensor.layout,
                device='cpu',
                pin_memory=pin_memory,
            )
            self._opt_state_cpu_buffers[param][state_key] = cpu_buffer

        return self._opt_state_cpu_buffers[param][state_key]

    def _offload_shard_groups(
        self,
        shard_groups: List[List[torch.Tensor]],
        cpu_buffers: List[List[torch.Tensor]],
        pin_memory: bool = True,
    ):
        """Offload shard groups to CPU buffers."""
        # Initialize CPU buffers on the first call.
        if len(cpu_buffers) == 0:
            for group in shard_groups:
                group_buffers = []
                for gpu_tensor in group:
                    cpu_buffer = torch.empty(
                        gpu_tensor.size(),
                        dtype=gpu_tensor.dtype,
                        layout=gpu_tensor.layout,
                        device='cpu',
                        pin_memory=pin_memory,
                    )
                    group_buffers.append(cpu_buffer)
                cpu_buffers.append(group_buffers)

        # Copy D2H
        for group_idx, group in enumerate(shard_groups):
            for param_idx, gpu_tensor in enumerate(group):
                cpu_buffer = cpu_buffers[group_idx][param_idx]
                cpu_buffer.copy_(gpu_tensor, non_blocking=pin_memory)
                gpu_tensor.record_stream(self._d2h_stream)

    def _offload_states(
        self,
        offload_optimizer_states: bool,
        offload_master_weights: bool,
        use_pin_memory: bool = True,
    ):
        """Offload optimizer states and/or master weights to CPU."""
        # Offload states from the Adam optimizer.
        self._offloaded_state_keys = self._get_state_keys_to_offload(
            offload_optimizer_states, offload_master_weights
        )
        states = self.adam_optimizer.state

        for param, param_state in states.items():
            for state_key in self._offloaded_state_keys:
                if state_key not in param_state:
                    continue

                gpu_tensor = param_state[state_key]
                if not isinstance(gpu_tensor, torch.Tensor) or not gpu_tensor.is_cuda:
                    continue

                cpu_buffer = self._ensure_state_cpu_buffer(
                    param, state_key, gpu_tensor, use_pin_memory
                )
                cpu_buffer.copy_(gpu_tensor, non_blocking=use_pin_memory)
                gpu_tensor.record_stream(self._d2h_stream)

        # Offload mcore master weights if they are not in the optimizer state.
        if offload_master_weights and not self.optimizer_contains_master_weights:
            self._offload_shard_groups(
                self.dist_optimizer.shard_fp32_from_float16_groups,
                self._shard_fp32_from_float16_cpu_buffers,
                use_pin_memory,
            )
            self._offloaded_mcore_master_weights = True

    def _release_states(self):
        """Release the GPU storage of offloaded tensors to free GPU memory."""
        states = self.adam_optimizer.state

        for param, param_state in states.items():
            if param not in self._opt_state_cpu_buffers:
                continue

            for state_key in self._offloaded_state_keys:
                if state_key not in self._opt_state_cpu_buffers[param]:
                    continue

                param_state[state_key].untyped_storage().resize_(0)

        if self._offloaded_mcore_master_weights:
            for group in self.dist_optimizer.shard_fp32_from_float16_groups:
                for gpu_tensor in group:
                    gpu_tensor.untyped_storage().resize_(0)

    def _reload_shard_groups(
        self,
        shard_groups: List[List[torch.Tensor]],
        cpu_buffers: List[List[torch.Tensor]],
        is_allocate_stage: bool,
    ):
        """Reload shard groups from CPU to GPU."""
        for group_idx, group in enumerate(shard_groups):
            for param_idx, _ in enumerate(group):
                cpu_buffer = cpu_buffers[group_idx][param_idx]
                if is_allocate_stage:
                    shard_groups[group_idx][param_idx].untyped_storage().resize_(
                        cpu_buffer.untyped_storage().size()
                    )
                else:
                    shard_groups[group_idx][param_idx].copy_(
                        cpu_buffer, non_blocking=cpu_buffer.is_pinned()
                    )

    def _reload_states(self, is_allocate_stage: bool):
        """
        Reload optimizer states and/or master weights from CPU to GPU.

        If is_allocate_stage is True, only allocate GPU memory for the states and master weights,
        but do not copy the data from CPU to GPU. Otherwise, copy the data from CPU to GPU.
        The two stages are separated to make sure that the GPU memory is allocated on the
        default stream, which avoids fragmentation.
        """
        # Reload states to the Adam optimizer.
        states = self.adam_optimizer.state

        for param, param_state in states.items():
            if param not in self._opt_state_cpu_buffers:
                continue

            for state_key in self._offloaded_state_keys:
                if state_key not in self._opt_state_cpu_buffers[param]:
                    continue

                cpu_buffer = self._opt_state_cpu_buffers[param][state_key]
                if is_allocate_stage:
                    param_state[state_key].untyped_storage().resize_(
                        cpu_buffer.untyped_storage().size()
                    )
                else:
                    param_state[state_key].copy_(cpu_buffer, non_blocking=cpu_buffer.is_pinned())

        # Reload mcore master weights if they are not in the optimizer state.
        if self._offloaded_mcore_master_weights:
            self._reload_shard_groups(
                self.dist_optimizer.shard_fp32_from_float16_groups,
                self._shard_fp32_from_float16_cpu_buffers,
                is_allocate_stage,
            )

    def offload(self, offload_optimizer_states: bool = True, offload_master_weights: bool = True):
        """
        Offload optimizer states and/or master weights to CPU.
        Starts an async D2H transfer that can overlap with other operations.

        Args:
            offload_optimizer_states: Whether to offload exp_avg and exp_avg_sq.
            offload_master_weights: Whether to offload master weights.
        """
        if not offload_optimizer_states and not offload_master_weights:
            return

        # Wait for the current stream to finish updating the optimizer states.
        self._d2h_stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self._d2h_stream):
            self._offload_states(offload_optimizer_states, offload_master_weights)

        self._offloaded = True

    def release_gpu_memory(self):
        """
        Release GPU memory for optimizer states and master weights after the D2H copy completes.

        This is separated from offload() to allow delayed GPU memory release,
        which is needed for the mxfp8 + overlap_param_gather case where master weights
        must remain on GPU until after _copy_main_params_to_param_buffer() is called.
        """
        if not self._offloaded:
            return

        self._release_states()

    def reload(self):
        """
        Reload optimizer states and/or master weights from CPU to GPU.
        Call before optimizer.step() to ensure states are on GPU.
        """
        if not self._offloaded:
            return

        # Allocate GPU memory on the current stream to avoid fragmentation.
        self._reload_states(is_allocate_stage=True)

        self._h2d_stream.wait_stream(self._d2h_stream)
        self._h2d_stream.wait_stream(torch.cuda.current_stream())

        # Reload states on the H2D stream to overlap with other operations.
        with torch.cuda.stream(self._h2d_stream):
            self._reload_states(is_allocate_stage=False)

        self._offloaded_state_keys = ()
        self._offloaded_mcore_master_weights = False
        self._offloaded = False

    def sync_before_step(self):
        """
        Wait for the H2D reload to complete before optimizer.step().
        Must be called to ensure states are on GPU before the optimizer uses them.

        This is separated from reload() to make it possible to move the reload ahead of time.
        """
        torch.cuda.current_stream().wait_stream(self._h2d_stream)
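
Note how _release_states() and _reload_states() free and re-allocate GPU memory by resizing each state tensor's underlying storage rather than replacing the tensor objects, so existing references (for example the entries in FusedAdam's state dict) stay valid. A minimal standalone sketch of that technique, not part of this commit (the tensor names are made up):

import torch

# Illustrative only: the storage-resize trick used by _release_states()/_reload_states().
state = torch.randn(1024, 1024, device='cuda')        # stand-in for an optimizer state tensor
cpu_buf = torch.empty(state.size(), dtype=state.dtype,
                      device='cpu', pin_memory=True)  # pinned CPU buffer

cpu_buf.copy_(state, non_blocking=True)               # async D2H copy
torch.cuda.current_stream().synchronize()             # make sure the copy finished

nbytes = state.untyped_storage().size()
state.untyped_storage().resize_(0)                    # free GPU memory, keep the tensor object

state.untyped_storage().resize_(nbytes)               # re-allocate GPU memory later ...
state.copy_(cpu_buf, non_blocking=True)               # ... and copy the data back (H2D)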

megatron/core/optimizer/distrib_optimizer.py

Lines changed: 25 additions & 0 deletions

@@ -49,6 +49,7 @@
 from ..fp8_utils import dequantize_fp8_tensor, is_float8tensor, quantize_param_shard
 from ..transformer.fsdp_dtensor_checkpoint import handle_experts_in_state_dict
 from ..transformer.module import MegatronModule
+from .cpu_offloading.optimizer_state_offloader import OptimizerStateOffloader
 from .grad_scaler import MegatronGradScaler
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper, param_group_identifier_keys
 from .optimizer_config import OptimizerConfig

@@ -604,6 +605,10 @@ def __init__(
         self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
         self.optimizer.load_state_dict(self.optimizer.state_dict())

+        self._state_offloader: Optional[OptimizerStateOffloader] = None
+        if self.config.offload_optimizer_states:
+            self._state_offloader = OptimizerStateOffloader(self)
+
     def _get_model_param_range_map(self, param: torch.nn.Parameter):
         """
         Given a model param, get the index sub-range of the param that this

@@ -2580,6 +2585,8 @@ def step_with_ready_grads(self) -> bool:
         Under the hood, either launch synchronous param all-gathers or get ready to launch
         asynchronous all-gathers that get overlapped with the next forward pass.
         """
+        if self._state_offloader is not None:
+            self._state_offloader.sync_before_step()
         update_successful = super().step_with_ready_grads()

         timers = self.config.timers

@@ -2600,4 +2607,22 @@ def step_with_ready_grads(self) -> bool:
         if timers is not None:
             timers('params-all-gather').stop()

+        if self._state_offloader is not None:
+            self._state_offloader.mark_optimizer_states_initialized()
+
         return update_successful
+
+    def offload_states(self):
+        """Offload states to CPU."""
+        if self._state_offloader is not None:
+            self._state_offloader.offload()
+
+    def reload_offloaded_states(self):
+        """Start async reload of offloaded states."""
+        if self._state_offloader is not None:
+            self._state_offloader.reload()
+
+    def release_offloaded_gpu_states(self):
+        """Release GPU memory after D2H completes. For the delayed-release case."""
+        if self._state_offloader is not None:
+            self._state_offloader.release_gpu_memory()
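
The new DistributedOptimizer methods are thin wrappers around the offloader, so the surrounding training loop decides when transfers happen. A hedged sketch of the intended call ordering (the loop and the forward_backward helper below are placeholders, not APIs from this commit):

# Hypothetical wiring, assuming `optimizer` is a DistributedOptimizer built with
# OptimizerConfig(offload_optimizer_states=True). `train_iterator` and
# `forward_backward()` stand in for the surrounding training code.
for batch in train_iterator:
    # Start the async H2D reload early so it overlaps with forward/backward compute.
    optimizer.reload_offloaded_states()

    loss = forward_backward(batch)

    # step_with_ready_grads() calls sync_before_step() internally, so the optimizer
    # blocks on the H2D stream only right before it touches the reloaded states.
    optimizer.step()

    # Kick off the async D2H copy of exp_avg / exp_avg_sq / master weights ...
    optimizer.offload_states()
    # ... and free the GPU copies (this can be delayed, e.g. for the
    # mxfp8 + overlap_param_gather case described in release_gpu_memory()).
    optimizer.release_offloaded_gpu_states()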

megatron/core/optimizer/optimizer_config.py

Lines changed: 6 additions & 0 deletions
@@ -266,6 +266,12 @@ class OptimizerConfig:
     pin_cpu_params: bool = True
     """If True, pin the optimizer parameters to CPU memory."""

+    offload_optimizer_states: bool = False
+    """
+    If True, offload optimizer states to CPU after each optimizer step and
+    reload them before the next optimizer step.
+    """
+
     ################
     # Miscellaneous
     ################