Commit a27aa49

fix master weight offloading

1 parent b3f0ab3, commit a27aa49

4 files changed (+111, -43 lines)


megatron/core/optimizer/cpu_offloading/optimizer_state_offloader.py (23 additions, 20 deletions)

@@ -26,10 +26,7 @@ class OptimizerStateOffloader:
     MASTER_WEIGHT_KEY = 'master_param'


-    def __init__(
-        self,
-        distrib_optimizer: "DistributedOptimizer",
-    ):
+    def __init__(self, distrib_optimizer: "DistributedOptimizer"):
         """
         Args:
             distrib_optimizer: The DistributedOptimizer to offload states and master weights from.

@@ -69,23 +66,34 @@ def __init__(
         self._offloaded_state_keys: Tuple[str, ...] = ()
         self._offloaded_mcore_master_weights = False

+        # Track whether optimizer states (exp_avg, exp_avg_sq) have been initialized.
+        # These are lazily initialized by FusedAdam during the first optimizer.step().
+        # Master weights (shard_fp32_from_float16_groups) are available from the start.
+        self._optimizer_states_initialized = False
+
+    def mark_optimizer_states_initialized(self):
+        """
+        Mark that optimizer states (exp_avg, exp_avg_sq) are now available.
+        Should be called after the first optimizer.step() completes.
+        """
+        self._optimizer_states_initialized = True
+
     def _get_state_keys_to_offload(
         self, offload_optimizer_states: bool, offload_master_weights: bool
     ) -> Tuple[str, ...]:
         """Get the state keys in FusedAdam to offload based on configuration."""
         keys = []
-        if offload_optimizer_states:
-            keys.extend(self.OPTIMIZER_STATE_KEYS)
-        if offload_master_weights and self.optimizer_contains_master_weights:
-            keys.append(self.MASTER_WEIGHT_KEY)
+        # Skip optimizer states offloading if they haven't been initialized yet.
+        # Optimizer states are lazily initialized by FusedAdam during the first optimizer.step().
+        if self._optimizer_states_initialized:
+            if offload_optimizer_states:
+                keys.extend(self.OPTIMIZER_STATE_KEYS)
+        if offload_master_weights and self.optimizer_contains_master_weights:
+            keys.append(self.MASTER_WEIGHT_KEY)
         return tuple(keys)

     def _ensure_state_cpu_buffer(
-        self,
-        param: torch.Tensor,
-        state_key: str,
-        gpu_tensor: torch.Tensor,
-        pin_memory: bool = True,
+        self, param: torch.Tensor, state_key: str, gpu_tensor: torch.Tensor, pin_memory: bool = True
     ) -> torch.Tensor:
         """Get or create a CPU buffer for a state tensor."""
         if param not in self._opt_state_cpu_buffers:

@@ -155,10 +163,7 @@ def _offload_states(
                 continue

             cpu_buffer = self._ensure_state_cpu_buffer(
-                param,
-                state_key,
-                gpu_tensor,
-                use_pin_memory,
+                param, state_key, gpu_tensor, use_pin_memory
             )
             cpu_buffer.copy_(gpu_tensor, non_blocking=use_pin_memory)
             gpu_tensor.record_stream(self._d2h_stream)

@@ -246,9 +251,7 @@ def _reload_states(self, is_allocate_stage: bool):
                 is_allocate_stage,
             )

-    def offload(
-        self, offload_optimizer_states: bool = True, offload_master_weights: bool = True
-    ):
+    def offload(self, offload_optimizer_states: bool = True, offload_master_weights: bool = True):
         """
         Offload optimizer states and/or master weights to CPU.
         Starts async D2H transfer that can overlap with other operations.
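
In short, _get_state_keys_to_offload now returns only the master-weight key until mark_optimizer_states_initialized() has been called, so offload() is safe to call before the first optimizer step. The standalone function below restates that selection logic outside the class to make the before/after behavior concrete; it is a sketch for illustration only, and it assumes (as the comments above suggest) that OPTIMIZER_STATE_KEYS names exp_avg and exp_avg_sq. The function name and its boolean arguments are hypothetical, not part of the commit.

from typing import Tuple

def keys_to_offload_after_fix(
    states_initialized: bool,
    offload_optimizer_states: bool = True,
    offload_master_weights: bool = True,
    has_master_weights: bool = True,
) -> Tuple[str, ...]:
    # Standalone restatement of the new _get_state_keys_to_offload() logic, assuming
    # OPTIMIZER_STATE_KEYS == ('exp_avg', 'exp_avg_sq') as the class comments indicate.
    keys = []
    if states_initialized and offload_optimizer_states:
        keys.extend(('exp_avg', 'exp_avg_sq'))
    if offload_master_weights and has_master_weights:
        keys.append('master_param')
    return tuple(keys)

# Before the first optimizer.step(): only the master weights are offloaded.
assert keys_to_offload_after_fix(states_initialized=False) == ('master_param',)
# After mark_optimizer_states_initialized(): the Adam states are offloaded as well.
assert keys_to_offload_after_fix(states_initialized=True) == ('exp_avg', 'exp_avg_sq', 'master_param')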

megatron/core/optimizer/distrib_optimizer.py (4 additions, 4 deletions)

@@ -606,6 +606,8 @@ def __init__(
         self.optimizer.load_state_dict(self.optimizer.state_dict())

         self._state_offloader: Optional[OptimizerStateOffloader] = None
+        if self.config.offload_optimizer_states or self.config.offload_master_weights:
+            self._state_offloader = OptimizerStateOffloader(self)

     def _get_model_param_range_map(self, param: torch.nn.Parameter):
         """

@@ -2605,10 +2607,8 @@ def step_with_ready_grads(self) -> bool:
         if timers is not None:
             timers('params-all-gather').stop()

-        # The states are initialized after the first optimizer.step() call.
-        # So we initialize the offloader here to make sure the states are initialized.
-        if self.config.offload_optimizer_states and self._state_offloader is None:
-            self._state_offloader = OptimizerStateOffloader(self)
+        if self._state_offloader is not None:
+            self._state_offloader.mark_optimizer_states_initialized()

         return update_successful
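
With the offloader now constructed in __init__ and step_with_ready_grads() flagging the Adam states as initialized after the first successful step, a caller can drive offloading from the very first iteration. The loop below is a minimal sketch of that call pattern, modeled on the unit test later in this commit; the function name, the model and optimizer arguments, and the input shape are placeholders, and only the offloader and optimizer calls themselves come from the Megatron code shown here.

import torch

def train_with_offloading(model, chained_optim, n_steps=10):
    # Sketch only: model is assumed to be a bf16 module and chained_optim a Megatron
    # ChainedOptimizer whose DistributedOptimizer was built with offloading enabled.
    dist_optim = chained_optim.chained_optimizers[0]
    offloader = dist_optim._state_offloader  # created in __init__ after this commit

    for _ in range(n_steps):
        # Before the first step this offloads only the master weights; exp_avg/exp_avg_sq
        # are skipped until mark_optimizer_states_initialized() has been called.
        offloader.offload()
        offloader.release_gpu_memory()

        loss = model(torch.randn(8, 256, dtype=torch.bfloat16, device='cuda')).sum()
        loss.backward()

        # Bring the offloaded tensors back and synchronize before stepping.
        offloader.reload()
        offloader.sync_before_step()
        chained_optim.step()   # step_with_ready_grads() marks the Adam states initialized
        chained_optim.zero_grad()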

megatron/training/arguments.py (1 addition, 0 deletions)

@@ -2333,6 +2333,7 @@ def _add_training_args(parser):
                        dest='offload_optimizer_states',
                        help='Offload optimizer states to CPU after each optimizer step and '
                             'reload them before the next optimizer step. '
+                            'Only supports the TE FusedAdam optimizer. '
                             'Note that this still uses pure GPU optimizer instead of '
                             'HybridDeviceOptimizer for --optimizer-cpu-offload.')
     group.add_argument('--dataloader-type', type=str, default=None,

tests/unit_tests/test_optimizer_state_offloading.py (83 additions, 19 deletions)

@@ -71,13 +71,7 @@ def test_offloader_initialization():
     model, optim = create_model_and_optimizer()
     dist_optim = optim.chained_optimizers[0]

-    # Before first step, offloader should be None
-    assert dist_optim._state_offloader is None
-
-    # Run one step to initialize optimizer states
-    run_forward_backward_step(model, optim)
-
-    # After first step, offloader should be initialized
+    # Offloader is created in __init__ when offload_optimizer_states=True
     assert dist_optim._state_offloader is not None
     offloader = dist_optim._state_offloader

@@ -86,11 +80,74 @@ def test_offloader_initialization():
     assert offloader._d2h_stream is not None
     assert offloader._h2d_stream is not None
     assert offloader._offloaded is False
+
+    # Before first step, optimizer states are not initialized yet
+    assert offloader._optimizer_states_initialized is False
+
+    # Run one step to initialize optimizer states
+    run_forward_backward_step(model, optim)
+
+    # After first step, optimizer states should be marked as initialized
+    assert offloader._optimizer_states_initialized is True
+    Utils.destroy_model_parallel()
+
+
+# =============================================================================
+# Test 2: Early Master Weight Offloading Before First Step
+# =============================================================================
+@pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
+def test_early_master_weight_offloading():
+    """Test that master weights can be offloaded before the first optimizer step."""
+    Utils.initialize_model_parallel()
+    model, optim = create_model_and_optimizer()
+    dist_optim = optim.chained_optimizers[0]
+
+    # Offloader is created in __init__
+    assert dist_optim._state_offloader is not None
+    offloader = dist_optim._state_offloader
+
+    # Before first step, optimizer states are not initialized
+    assert offloader._optimizer_states_initialized is False
+
+    # Capture original master weights before offload
+    original_master_weights = []
+    for group in dist_optim.shard_fp32_from_float16_groups:
+        group_weights = [tensor.clone() for tensor in group]
+        original_master_weights.append(group_weights)
+
+    # Offload before first step - should only offload master weights
+    offloader.offload()
+    offloader.release_gpu_memory()
+    torch.cuda.synchronize()
+
+    # Verify master weights were offloaded (storage resized to 0)
+    for group in dist_optim.shard_fp32_from_float16_groups:
+        for tensor in group:
+            assert tensor.untyped_storage().size() == 0, "Master weight should be offloaded"
+
+    # Reload master weights
+    offloader.reload()
+    offloader.sync_before_step()
+
+    # Verify master weights match after reload
+    for group_idx, group in enumerate(dist_optim.shard_fp32_from_float16_groups):
+        for param_idx, tensor in enumerate(group):
+            original = original_master_weights[group_idx][param_idx]
+            torch.testing.assert_close(
+                tensor,
+                original,
+                msg=f"Master weight [{group_idx}][{param_idx}] mismatch after offload/reload",
+            )
+
+    # Now run a step and verify optimizer states can be offloaded after
+    run_forward_backward_step(model, optim)
+    assert offloader._optimizer_states_initialized is True
     Utils.destroy_model_parallel()


 # =============================================================================
-# Test 2: Offload and Reload Correctness
+# Test 3: Offload and Reload Correctness
 # =============================================================================
 @pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
 @pytest.mark.parametrize("offload_optimizer_states", [True, False])

@@ -139,13 +196,15 @@ def test_offload_reload_correctness(offload_optimizer_states, offload_master_weights):
             reloaded_tensor = state[key]
             assert reloaded_tensor.device.type == 'cuda', f"State {key} should be on GPU"
             torch.testing.assert_close(
-                reloaded_tensor, original_tensor, msg=f"State {key} mismatch after offload/reload"
+                reloaded_tensor,
+                original_tensor,
+                msg=f"State {key} mismatch after offload/reload",
             )
     Utils.destroy_model_parallel()


 # =============================================================================
-# Test 3: GPU Memory Release Verification
+# Test 4: GPU Memory Release Verification
 # =============================================================================
 @pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
 def test_gpu_memory_release():

@@ -181,7 +240,7 @@ def test_gpu_memory_release():


 # =============================================================================
-# Test 4: Multiple Offload/Reload Cycles
+# Test 5: Multiple Offload/Reload Cycles
 # =============================================================================
 @pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
 def test_multiple_offload_reload_cycles():

@@ -216,7 +275,7 @@ def test_multiple_offload_reload_cycles():


 # =============================================================================
-# Test 5: Training Correctness with Offloading
+# Test 6: Training Correctness with Offloading
 # =============================================================================
 @pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
 def test_training_correctness_with_offloading():

@@ -234,22 +293,27 @@ def test_training_correctness_with_offloading():
     # Train both models
     n_steps = 10
     torch.manual_seed(123)
+    dist_optim1 = optim1.chained_optimizers[0]
+
+    # Offloader is created in __init__ when offload_optimizer_states=True
+    assert dist_optim1._state_offloader is not None
+    offloader = dist_optim1._state_offloader
+
     for step in range(n_steps):
         input_tensor = torch.randn(8, 256, dtype=torch.bfloat16, device='cuda')

         # Model 1 with offloading
-        dist_optim1 = optim1.chained_optimizers[0]
-        if dist_optim1._state_offloader is not None:
-            dist_optim1._state_offloader.offload()
-            dist_optim1._state_offloader.release_gpu_memory()
+        # Offload states (master weights can be offloaded from the start,
+        # optimizer states will be skipped until after first step)
+        offloader.offload()
+        offloader.release_gpu_memory()

         output1 = model1(input_tensor)
         loss1 = output1.sum()
         loss1.backward()

-        if dist_optim1._state_offloader is not None:
-            dist_optim1._state_offloader.reload()
-            dist_optim1._state_offloader.sync_before_step()
+        offloader.reload()
+        offloader.sync_before_step()
         optim1.step()
         optim1.zero_grad()
