@@ -71,13 +71,7 @@ def test_offloader_initialization():
     model, optim = create_model_and_optimizer()
     dist_optim = optim.chained_optimizers[0]

-    # Before first step, offloader should be None
-    assert dist_optim._state_offloader is None
-
-    # Run one step to initialize optimizer states
-    run_forward_backward_step(model, optim)
-
-    # After first step, offloader should be initialized
+    # Offloader is created in __init__ when offload_optimizer_states=True
     assert dist_optim._state_offloader is not None
     offloader = dist_optim._state_offloader

@@ -86,11 +80,74 @@ def test_offloader_initialization():
     assert offloader._d2h_stream is not None
     assert offloader._h2d_stream is not None
     assert offloader._offloaded is False
+
+    # Before first step, optimizer states are not initialized yet
+    assert offloader._optimizer_states_initialized is False
+
+    # Run one step to initialize optimizer states
+    run_forward_backward_step(model, optim)
+
+    # After first step, optimizer states should be marked as initialized
+    assert offloader._optimizer_states_initialized is True
+    Utils.destroy_model_parallel()
+
+
+# =============================================================================
+# Test 2: Early Master Weight Offloading Before First Step
+# =============================================================================
+@pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
+def test_early_master_weight_offloading():
+    """Test that master weights can be offloaded before the first optimizer step."""
+    Utils.initialize_model_parallel()
+    model, optim = create_model_and_optimizer()
+    dist_optim = optim.chained_optimizers[0]
+
+    # Offloader is created in __init__
+    assert dist_optim._state_offloader is not None
+    offloader = dist_optim._state_offloader
+
+    # Before first step, optimizer states are not initialized
+    assert offloader._optimizer_states_initialized is False
+
+    # Capture original master weights before offload
+    original_master_weights = []
+    for group in dist_optim.shard_fp32_from_float16_groups:
+        group_weights = [tensor.clone() for tensor in group]
+        original_master_weights.append(group_weights)
+
+    # Offload before first step - should only offload master weights
+    offloader.offload()
+    offloader.release_gpu_memory()
+    torch.cuda.synchronize()
+
+    # Verify master weights were offloaded (storage resized to 0)
+    for group in dist_optim.shard_fp32_from_float16_groups:
+        for tensor in group:
+            assert tensor.untyped_storage().size() == 0, "Master weight should be offloaded"
+
+    # Reload master weights
+    offloader.reload()
+    offloader.sync_before_step()
+
+    # Verify master weights match after reload
+    for group_idx, group in enumerate(dist_optim.shard_fp32_from_float16_groups):
+        for param_idx, tensor in enumerate(group):
+            original = original_master_weights[group_idx][param_idx]
+            torch.testing.assert_close(
+                tensor,
+                original,
+                msg=f"Master weight [{group_idx}][{param_idx}] mismatch after offload/reload",
+            )
+
+    # Now run a step and verify optimizer states can be offloaded after
+    run_forward_backward_step(model, optim)
+    assert offloader._optimizer_states_initialized is True
+
     Utils.destroy_model_parallel()


 # =============================================================================
-# Test 2: Offload and Reload Correctness
+# Test 3: Offload and Reload Correctness
 # =============================================================================
 @pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
 @pytest.mark.parametrize("offload_optimizer_states", [True, False])
@@ -139,13 +196,15 @@ def test_offload_reload_correctness(offload_optimizer_states, offload_master_wei
             reloaded_tensor = state[key]
             assert reloaded_tensor.device.type == 'cuda', f"State {key} should be on GPU"
             torch.testing.assert_close(
-                reloaded_tensor, original_tensor, msg=f"State {key} mismatch after offload/reload"
+                reloaded_tensor,
+                original_tensor,
+                msg=f"State {key} mismatch after offload/reload",
             )
     Utils.destroy_model_parallel()


 # =============================================================================
-# Test 3: GPU Memory Release Verification
+# Test 4: GPU Memory Release Verification
 # =============================================================================
 @pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
 def test_gpu_memory_release():
@@ -181,7 +240,7 @@ def test_gpu_memory_release():


 # =============================================================================
-# Test 4: Multiple Offload/Reload Cycles
+# Test 5: Multiple Offload/Reload Cycles
 # =============================================================================
 @pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
 def test_multiple_offload_reload_cycles():
@@ -216,7 +275,7 @@ def test_multiple_offload_reload_cycles():


 # =============================================================================
-# Test 5: Training Correctness with Offloading
+# Test 6: Training Correctness with Offloading
 # =============================================================================
 @pytest.mark.skipif(not TE_FUSED_ADAM_AVAILABLE, reason="Requires TE FusedAdam")
 def test_training_correctness_with_offloading():
@@ -234,22 +293,27 @@ def test_training_correctness_with_offloading():
     # Train both models
     n_steps = 10
     torch.manual_seed(123)
+    dist_optim1 = optim1.chained_optimizers[0]
+
+    # Offloader is created in __init__ when offload_optimizer_states=True
+    assert dist_optim1._state_offloader is not None
+    offloader = dist_optim1._state_offloader
+
     for step in range(n_steps):
         input_tensor = torch.randn(8, 256, dtype=torch.bfloat16, device='cuda')

         # Model 1 with offloading
-        dist_optim1 = optim1.chained_optimizers[0]
-        if dist_optim1._state_offloader is not None:
-            dist_optim1._state_offloader.offload()
-            dist_optim1._state_offloader.release_gpu_memory()
+        # Offload states (master weights can be offloaded from the start,
+        # optimizer states will be skipped until after first step)
+        offloader.offload()
+        offloader.release_gpu_memory()

         output1 = model1(input_tensor)
         loss1 = output1.sum()
         loss1.backward()

-        if dist_optim1._state_offloader is not None:
-            dist_optim1._state_offloader.reload()
-            dist_optim1._state_offloader.sync_before_step()
+        offloader.reload()
+        offloader.sync_before_step()
         optim1.step()
         optim1.zero_grad()

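
For context, all of the tests above exercise the same integration pattern. The sketch below distills it into one training step, using only the attribute and method names that appear in this diff; the helper name training_step_with_offloading is ours for illustration and is not part of the change.

import torch


def training_step_with_offloading(model, optim, input_tensor):
    """One training step with optimizer-state offloading, per the tests above.

    Assumes `optim` is a ChainedOptimizer whose first chained optimizer was
    constructed with offload_optimizer_states=True, so `_state_offloader` is
    created in __init__.
    """
    offloader = optim.chained_optimizers[0]._state_offloader

    # Copy states to host and free their GPU storage. Before the first step
    # only master weights exist, so only those are offloaded; optimizer states
    # are picked up once the first step has initialized them.
    offloader.offload()
    offloader.release_gpu_memory()

    # Forward/backward run while the states live on the host.
    loss = model(input_tensor).sum()
    loss.backward()

    # Copy states back and wait for the H2D transfers before stepping.
    offloader.reload()
    offloader.sync_before_step()

    optim.step()
    optim.zero_grad()
    return loss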