 from ..fp8_utils import dequantize_fp8_tensor, is_float8tensor, quantize_param_shard
 from ..transformer.fsdp_dtensor_checkpoint import handle_experts_in_state_dict
 from ..transformer.module import MegatronModule
-from .cpu_offloading.optimizer_state_offloader import OptimizerStateOffloader
 from .grad_scaler import MegatronGradScaler
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper, param_group_identifier_keys
 from .optimizer_config import OptimizerConfig
@@ -605,10 +604,6 @@ def __init__(
         self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-        self._state_offloader: Optional[OptimizerStateOffloader] = None
-        if self.config.offload_optimizer_states:
-            self._state_offloader = OptimizerStateOffloader(self)
-
     def _get_model_param_range_map(self, param: torch.nn.Parameter):
         """
         Given a model param, get the index sub-range of the param that this
@@ -2585,8 +2580,6 @@ def step_with_ready_grads(self) -> bool:
         Under the hood, either launch synchronous param all-gathers or get ready to launch
         asynchronous all-gathers that get overlapped with the next forward pass.
         """
-        if self._state_offloader is not None:
-            self._state_offloader.sync_before_step()
         update_successful = super().step_with_ready_grads()

         timers = self.config.timers
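The docstring above describes overlapping the parameter all-gather with the next forward pass. As a point of reference only, here is a minimal sketch of that overlap pattern using plain torch.distributed; it is not Megatron's implementation, and the function name and flat per-rank shard layout are assumptions.

import torch
import torch.distributed as dist

def async_param_all_gather(local_shard: torch.Tensor, world_size: int):
    """Launch a non-blocking all-gather of a flat param shard.

    Returns the output buffer and a work handle; the caller runs independent
    work (e.g. the start of the next forward pass) before calling wait().
    """
    full_params = torch.empty(
        local_shard.numel() * world_size,
        dtype=local_shard.dtype,
        device=local_shard.device,
    )
    handle = dist.all_gather_into_tensor(full_params, local_shard, async_op=True)
    return full_params, handle

# Usage (assumes torch.distributed is already initialized with a CUDA backend):
#   full, handle = async_param_all_gather(shard, dist.get_world_size())
#   ... overlapping compute ...
#   handle.wait()   # block only when the gathered params are actually needed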
@@ -2607,22 +2600,4 @@ def step_with_ready_grads(self) -> bool:
         if timers is not None:
             timers('params-all-gather').stop()

-        if self._state_offloader is not None:
-            self._state_offloader.mark_optimizer_states_initialized()
-
         return update_successful
-
-    def offload_states(self):
-        """Offload states to CPU."""
-        if self._state_offloader is not None:
-            self._state_offloader.offload()
-
-    def reload_offloaded_states(self):
-        """Start async reload of offloaded states."""
-        if self._state_offloader is not None:
-            self._state_offloader.reload()
-
-    def release_offloaded_gpu_states(self):
-        """Release GPU memory after D2H completes. For delayed release case."""
-        if self._state_offloader is not None:
-            self._state_offloader.release_gpu_memory()
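The diff only shows the call sites of the removed offloader (offload, reload, release_gpu_memory, sync_before_step, mark_optimizer_states_initialized). For orientation, here is a hedged sketch of a CPU offloader exposing that interface; the class body, pinned-memory buffers, and side-stream handling are assumptions, not the removed OptimizerStateOffloader implementation.

import torch

class OffloaderSketch:
    """Illustrative only: mirrors the method names seen in this diff."""

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self._stream = torch.cuda.Stream()  # side stream for async copies
        self._cpu_copies = {}                # param -> {state_key: pinned CPU tensor}
        self._initialized = False

    def mark_optimizer_states_initialized(self):
        # State tensors only exist after the first step(); offloading earlier is a no-op.
        self._initialized = True

    def offload(self):
        """Start async D2H copies of optimizer state into pinned host buffers."""
        if not self._initialized:
            return
        with torch.cuda.stream(self._stream):
            for param, state in self.optimizer.state.items():
                copies = {}
                for key, value in state.items():
                    if torch.is_tensor(value) and value.is_cuda:
                        cpu_t = torch.empty(value.shape, dtype=value.dtype, pin_memory=True)
                        cpu_t.copy_(value, non_blocking=True)
                        copies[key] = cpu_t
                self._cpu_copies[param] = copies

    def release_gpu_memory(self):
        """After the D2H copies finish, point the state at the CPU tensors."""
        self._stream.synchronize()
        for param, copies in self._cpu_copies.items():
            for key, cpu_t in copies.items():
                self.optimizer.state[param][key] = cpu_t

    def reload(self):
        """Start async H2D reload of the offloaded state tensors."""
        device = torch.device("cuda", torch.cuda.current_device())
        with torch.cuda.stream(self._stream):
            for param, copies in self._cpu_copies.items():
                for key, cpu_t in copies.items():
                    self.optimizer.state[param][key] = cpu_t.to(device, non_blocking=True)

    def sync_before_step(self):
        """Make the compute stream wait for the reload before optimizer.step()."""
        torch.cuda.current_stream().wait_stream(self._stream)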