|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | + |
| 3 | +"""Optimizer state offloading class.""" |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING, Dict, List, Tuple |
| 6 | + |
| 7 | +import torch |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer |
| 11 | + |
| 12 | + |
| 13 | +class OptimizerStateOffloader: |
| 14 | + """ |
| 15 | + Manages offloading of optimizer states and master weights to CPU. |
| 16 | + Used with DistributedOptimizer to reduce GPU memory usage. |
| 17 | +
|
| 18 | + Supports overlapped D2H/H2D transfers using CUDA streams. |
| 19 | +
|
| 20 | + Master weights can be stored in two locations: |
| 21 | + - In adam optimizer state (when use_precision_aware_optimizer_no_fp8_or_ds_fp8 is True) |
| 22 | + - In mcore's shard_fp32_from_float16_groups |
| 23 | + """ |
| 24 | + |
| 25 | + OPTIMIZER_STATE_KEYS = ('exp_avg', 'exp_avg_sq') |
| 26 | + MASTER_WEIGHT_KEY = 'master_param' |
| 27 | + |
| 28 | + def __init__(self, distrib_optimizer: "DistributedOptimizer"): |
| 29 | + """ |
| 30 | + Args: |
| 31 | + distrib_optimizer: The DistributedOptimizer to offload states and master weights from. |
| 32 | + """ |
| 33 | + self.dist_optimizer = distrib_optimizer |
| 34 | + self.adam_optimizer = distrib_optimizer.optimizer |
| 35 | + |
| 36 | + # Only support TE FusedAdam optimizer for now. |
| 37 | + try: |
| 38 | + from transformer_engine.pytorch.optimizers import FusedAdam |
| 39 | + |
| 40 | + assert isinstance(self.adam_optimizer, FusedAdam), ( |
| 41 | + f"OptimizerStateOffloader requires TE FusedAdam optimizer, " |
| 42 | + f"but got {type(self.adam_optimizer).__name__}" |
| 43 | + ) |
| 44 | + except ImportError: |
| 45 | + raise ImportError( |
| 46 | + "OptimizerStateOffloader requires transformer_engine.pytorch.optimizers.FusedAdam" |
| 47 | + ) |
| 48 | + |
| 49 | + # Check if master weights are stored in adam optimizer state |
| 50 | + self.optimizer_contains_master_weights = self.adam_optimizer.master_weights |
| 51 | + |
| 52 | + # CUDA streams for async transfers |
| 53 | + self._d2h_stream = torch.cuda.Stream() |
| 54 | + self._h2d_stream = torch.cuda.Stream() |
| 55 | + |
| 56 | + # CPU buffers for optimizer states: {param: {key: cpu_tensor}} |
| 57 | + self._opt_state_cpu_buffers: Dict[torch.Tensor, Dict[str, torch.Tensor]] = {} |
| 58 | + |
| 59 | + # CPU buffers for mcore master weights, matching the structure of source groups |
| 60 | + # List[List[cpu_tensor]] |
| 61 | + self._shard_fp32_from_float16_cpu_buffers: List[List[torch.Tensor]] = [] |
| 62 | + |
| 63 | + # State tracking |
| 64 | + self._offloaded = False |
| 65 | + self._offloaded_state_keys: Tuple[str, ...] = () |
| 66 | + self._offloaded_mcore_master_weights = False |
| 67 | + |
| 68 | + # Track whether optimizer states (exp_avg, exp_avg_sq) have been initialized. |
| 69 | + # These are lazily initialized by FusedAdam during the first optimizer.step(). |
| 70 | + # Master weights (shard_fp32_from_float16_groups) are available from the start. |
| 71 | + self._optimizer_states_initialized = False |
| 72 | + |
| 73 | + def mark_optimizer_states_initialized(self): |
| 74 | + """ |
| 75 | + Mark that optimizer states (exp_avg, exp_avg_sq) are now available. |
| 76 | + Should be called after the first optimizer.step() completes. |
| 77 | + """ |
| 78 | + self._optimizer_states_initialized = True |
| 79 | + |
| 80 | + def _get_state_keys_to_offload( |
| 81 | + self, offload_optimizer_states: bool, offload_master_weights: bool |
| 82 | + ) -> Tuple[str, ...]: |
| 83 | + """Get the state keys in FusedAdam to offload based on configuration.""" |
| 84 | + keys = [] |
| 85 | + # Skip optimizer states offloading if they haven't been initialized yet. |
| 86 | + # Optimizer states are lazily initialized by FusedAdam during the first optimizer.step(). |
| 87 | + if self._optimizer_states_initialized: |
| 88 | + if offload_optimizer_states: |
| 89 | + keys.extend(self.OPTIMIZER_STATE_KEYS) |
| 90 | + if offload_master_weights and self.optimizer_contains_master_weights: |
| 91 | + keys.append(self.MASTER_WEIGHT_KEY) |
| 92 | + return tuple(keys) |
| 93 | + |
| 94 | + def _ensure_state_cpu_buffer( |
| 95 | + self, param: torch.Tensor, state_key: str, gpu_tensor: torch.Tensor, pin_memory: bool = True |
| 96 | + ) -> torch.Tensor: |
| 97 | + """Get or create a CPU buffer for a state tensor.""" |
| 98 | + if param not in self._opt_state_cpu_buffers: |
| 99 | + self._opt_state_cpu_buffers[param] = {} |
| 100 | + |
| 101 | + if state_key not in self._opt_state_cpu_buffers[param]: |
| 102 | + cpu_buffer = torch.empty( |
| 103 | + gpu_tensor.size(), |
| 104 | + dtype=gpu_tensor.dtype, |
| 105 | + layout=gpu_tensor.layout, |
| 106 | + device='cpu', |
| 107 | + pin_memory=pin_memory, |
| 108 | + ) |
| 109 | + self._opt_state_cpu_buffers[param][state_key] = cpu_buffer |
| 110 | + |
| 111 | + return self._opt_state_cpu_buffers[param][state_key] |
| 112 | + |
| 113 | + def _offload_shard_groups( |
| 114 | + self, |
| 115 | + shard_groups: List[List[torch.Tensor]], |
| 116 | + cpu_buffers: List[List[torch.Tensor]], |
| 117 | + pin_memory: bool = True, |
| 118 | + ): |
| 119 | + """Offload a shard group to CPU buffers.""" |
| 120 | + # Initialize CPU buffers on first call |
| 121 | + if len(cpu_buffers) == 0: |
| 122 | + for group in shard_groups: |
| 123 | + group_buffers = [] |
| 124 | + for gpu_tensor in group: |
| 125 | + cpu_buffer = torch.empty( |
| 126 | + gpu_tensor.size(), |
| 127 | + dtype=gpu_tensor.dtype, |
| 128 | + layout=gpu_tensor.layout, |
| 129 | + device='cpu', |
| 130 | + pin_memory=pin_memory, |
| 131 | + ) |
| 132 | + group_buffers.append(cpu_buffer) |
| 133 | + cpu_buffers.append(group_buffers) |
| 134 | + |
| 135 | + # Copy D2H |
| 136 | + for group_idx, group in enumerate(shard_groups): |
| 137 | + for param_idx, gpu_tensor in enumerate(group): |
| 138 | + cpu_buffer = cpu_buffers[group_idx][param_idx] |
| 139 | + cpu_buffer.copy_(gpu_tensor, non_blocking=pin_memory) |
| 140 | + gpu_tensor.record_stream(self._d2h_stream) |
| 141 | + |
| 142 | + def _offload_states( |
| 143 | + self, |
| 144 | + offload_optimizer_states: bool, |
| 145 | + offload_master_weights: bool, |
| 146 | + use_pin_memory: bool = True, |
| 147 | + ): |
| 148 | + """Offload optimizer states and/or master weights to CPU.""" |
| 149 | + # Offload states from adam optimizer |
| 150 | + self._offloaded_state_keys = self._get_state_keys_to_offload( |
| 151 | + offload_optimizer_states, offload_master_weights |
| 152 | + ) |
| 153 | + states = self.adam_optimizer.state |
| 154 | + |
| 155 | + for param, param_state in states.items(): |
| 156 | + for state_key in self._offloaded_state_keys: |
| 157 | + if state_key not in param_state: |
| 158 | + continue |
| 159 | + |
| 160 | + gpu_tensor = param_state[state_key] |
| 161 | + if not isinstance(gpu_tensor, torch.Tensor) or not gpu_tensor.is_cuda: |
| 162 | + continue |
| 163 | + |
| 164 | + cpu_buffer = self._ensure_state_cpu_buffer( |
| 165 | + param, state_key, gpu_tensor, use_pin_memory |
| 166 | + ) |
| 167 | + cpu_buffer.copy_(gpu_tensor, non_blocking=use_pin_memory) |
| 168 | + gpu_tensor.record_stream(self._d2h_stream) |
| 169 | + |
| 170 | + # Offload mcore master weights if not in optimizer state |
| 171 | + if offload_master_weights and not self.optimizer_contains_master_weights: |
| 172 | + self._offload_shard_groups( |
| 173 | + self.dist_optimizer.shard_fp32_from_float16_groups, |
| 174 | + self._shard_fp32_from_float16_cpu_buffers, |
| 175 | + use_pin_memory, |
| 176 | + ) |
| 177 | + self._offloaded_mcore_master_weights = True |
| 178 | + |
| 179 | + def _release_states(self): |
| 180 | + """Replace optimizer state GPU tensors with CPU tensors to free GPU memory.""" |
| 181 | + states = self.adam_optimizer.state |
| 182 | + |
| 183 | + for param, param_state in states.items(): |
| 184 | + if param not in self._opt_state_cpu_buffers: |
| 185 | + continue |
| 186 | + |
| 187 | + for state_key in self._offloaded_state_keys: |
| 188 | + if state_key not in self._opt_state_cpu_buffers[param]: |
| 189 | + continue |
| 190 | + |
| 191 | + param_state[state_key].untyped_storage().resize_(0) |
| 192 | + |
| 193 | + if self._offloaded_mcore_master_weights: |
| 194 | + for group in self.dist_optimizer.shard_fp32_from_float16_groups: |
| 195 | + for gpu_tensor in group: |
| 196 | + gpu_tensor.untyped_storage().resize_(0) |
| 197 | + |
| 198 | + def _reload_shard_groups( |
| 199 | + self, |
| 200 | + shard_groups: List[List[torch.Tensor]], |
| 201 | + cpu_buffers: List[List[torch.Tensor]], |
| 202 | + is_allocate_stage: bool, |
| 203 | + ): |
| 204 | + """Reload shard groups from CPU to GPU.""" |
| 205 | + for group_idx, group in enumerate(shard_groups): |
| 206 | + for param_idx, _ in enumerate(group): |
| 207 | + cpu_buffer = cpu_buffers[group_idx][param_idx] |
| 208 | + if is_allocate_stage: |
| 209 | + shard_groups[group_idx][param_idx].untyped_storage().resize_( |
| 210 | + cpu_buffer.untyped_storage().size() |
| 211 | + ) |
| 212 | + else: |
| 213 | + shard_groups[group_idx][param_idx].copy_( |
| 214 | + cpu_buffer, non_blocking=cpu_buffer.is_pinned() |
| 215 | + ) |
| 216 | + |
| 217 | + def _reload_states(self, is_allocate_stage: bool): |
| 218 | + """ |
| 219 | + Reload optimizer states and/or master weights from CPU to GPU. |
| 220 | +
|
| 221 | + If is_allocate_stage is True, only allocate GPU memory for the states and master weights, |
| 222 | + but do not copy the data from CPU to GPU. Otherwise, copy the data from CPU to GPU. |
| 223 | + The two processes are separated to make sure that the GPU memory is allocated on the |
| 224 | + default stream to avoid fragmentation. |
| 225 | + """ |
| 226 | + # Reload states to adam optimizer |
| 227 | + states = self.adam_optimizer.state |
| 228 | + |
| 229 | + for param, param_state in states.items(): |
| 230 | + if param not in self._opt_state_cpu_buffers: |
| 231 | + continue |
| 232 | + |
| 233 | + for state_key in self._offloaded_state_keys: |
| 234 | + if state_key not in self._opt_state_cpu_buffers[param]: |
| 235 | + continue |
| 236 | + |
| 237 | + cpu_buffer = self._opt_state_cpu_buffers[param][state_key] |
| 238 | + if is_allocate_stage: |
| 239 | + param_state[state_key].untyped_storage().resize_( |
| 240 | + cpu_buffer.untyped_storage().size() |
| 241 | + ) |
| 242 | + else: |
| 243 | + param_state[state_key].copy_(cpu_buffer, non_blocking=cpu_buffer.is_pinned()) |
| 244 | + |
| 245 | + # Reload mcore master weights if not in optimizer state |
| 246 | + if self._offloaded_mcore_master_weights: |
| 247 | + self._reload_shard_groups( |
| 248 | + self.dist_optimizer.shard_fp32_from_float16_groups, |
| 249 | + self._shard_fp32_from_float16_cpu_buffers, |
| 250 | + is_allocate_stage, |
| 251 | + ) |
| 252 | + |
| 253 | + def offload(self, offload_optimizer_states: bool = True, offload_master_weights: bool = True): |
| 254 | + """ |
| 255 | + Offload optimizer states and/or master weights to CPU. |
| 256 | + Starts async D2H transfer that can overlap with other operations. |
| 257 | +
|
| 258 | + Args: |
| 259 | + offload_optimizer_states: Whether to offload exp_avg, exp_avg_sq. |
| 260 | + offload_master_weights: Whether to offload master weights. |
| 261 | + """ |
| 262 | + if not offload_optimizer_states and not offload_master_weights: |
| 263 | + return |
| 264 | + |
| 265 | + # Wait for current stream finishing updating the optimizer states. |
| 266 | + self._d2h_stream.wait_stream(torch.cuda.current_stream()) |
| 267 | + |
| 268 | + with torch.cuda.stream(self._d2h_stream): |
| 269 | + self._offload_states(offload_optimizer_states, offload_master_weights) |
| 270 | + |
| 271 | + self._offloaded = True |
| 272 | + |
| 273 | + def release_gpu_memory(self): |
| 274 | + """ |
| 275 | + Release GPU memory for optimizer states and master weights after D2H copy completes. |
| 276 | +
|
| 277 | + This is separated from offload() to allow delayed GPU memory release, |
| 278 | + which is needed for mxfp8 + overlap_param_gather case where master weights |
| 279 | + must remain on GPU until after _copy_main_params_to_param_buffer() is called. |
| 280 | + """ |
| 281 | + if not self._offloaded: |
| 282 | + return |
| 283 | + |
| 284 | + self._release_states() |
| 285 | + |
| 286 | + def reload(self): |
| 287 | + """ |
| 288 | + Reload optimizer states and/or master weights from CPU to GPU. |
| 289 | + Call before optimizer.step() to ensure states are on GPU. |
| 290 | + """ |
| 291 | + if not self._offloaded: |
| 292 | + return |
| 293 | + |
| 294 | + # Allocate GPU memory on the current stream to avoid fragmentation. |
| 295 | + self._reload_states(is_allocate_stage=True) |
| 296 | + |
| 297 | + self._h2d_stream.wait_stream(self._d2h_stream) |
| 298 | + self._h2d_stream.wait_stream(torch.cuda.current_stream()) |
| 299 | + |
| 300 | + # Reload states on the h2d stream to overlap with other operations. |
| 301 | + with torch.cuda.stream(self._h2d_stream): |
| 302 | + self._reload_states(is_allocate_stage=False) |
| 303 | + |
| 304 | + self._offloaded_state_keys = () |
| 305 | + self._offloaded_mcore_master_weights = False |
| 306 | + self._offloaded = False |
| 307 | + |
| 308 | + def sync_before_step(self): |
| 309 | + """ |
| 310 | + Wait for H2D reload to complete before optimizer.step(). |
| 311 | + Must be called to ensure states are on GPU before optimizer uses them. |
| 312 | +
|
| 313 | + This is separated from reload() to make it possible to move the reload ahead of time. |
| 314 | + """ |
| 315 | + torch.cuda.current_stream().wait_stream(self._h2d_stream) |
0 commit comments