"""
Trainer mixin for layer-wise parameter offloading to CPU.

Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
forward/backward hooks to stream them on/off the GPU one layer at a time with
CUDA stream prefetching. Trainable parameters (e.g. LoRA weights) always stay
on the GPU.

Forward: the pre-hook loads layer N's frozen params to GPU (and prefetches
    layer N+1 on the transfer stream); the post-hook offloads layer N-1.
Backward: the same pattern, in reverse layer order.
"""
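
# Overlap schedule (illustrative): while the default stream computes layer i,
# the transfer stream uploads layer i+1's frozen params:
#
#   default stream:   [fwd layer i  ][wait][fwd layer i+1 ][wait][fwd layer i+2 ...
#   transfer stream:  [H2D layer i+1]      [H2D layer i+2 ]
#
# The "wait" is `_wait_transfer()` in each pre-hook; the backward pass mirrors
# this with prefetches walking toward layer 0.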

import contextlib

import torch
import torch.nn as nn
from transformers import Trainer

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
    """Breadth-first search of the model for the decoder layer ModuleList.

    Finds the first ModuleList whose children have 'DecoderLayer' or
    'TransformerBlock' in their class name. Handles common HF architectures,
    including VLM wrappers (e.g. Qwen3.5-MoE, where the layers live at
    model.language_model.layers).
    """
    # BFS to find the first ModuleList containing decoder layers
    queue = [model]
    while queue:
        m = queue.pop(0)
        for _name, child in m.named_children():
            if isinstance(child, nn.ModuleList) and len(child) > 0:
                first_type = type(child[0]).__name__
                if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
                    layer_types = list({type(layer).__name__ for layer in child})
                    return child, layer_types
            # Keep searching inside any child that didn't match, including
            # non-matching ModuleLists (the layers may be nested deeper)
            queue.append(child)

    return None, []
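

# Example (illustrative): for a Llama-style causal LM the search resolves to
# `model.model.layers`:
#
#   layers, types = _find_decoder_layers(model)
#   # layers -> nn.ModuleList of LlamaDecoderLayer modules
#   # types  -> ["LlamaDecoderLayer"]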


def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
    """Get all non-trainable parameters in a layer."""
    return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]


class LayerOffloadManager:
    """Manages offloading frozen decoder layer params to CPU and streaming
    them back during forward/backward with CUDA stream overlap.

    Only frozen (requires_grad=False) parameters are offloaded. Trainable
    parameters (LoRA weights, etc.) remain on GPU at all times.
    """
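
    # Typical lifecycle (sketch; assumes `model` already has its trainable
    # params, e.g. LoRA adapters, on the GPU):
    #
    #   manager = LayerOffloadManager(model, num_prefetch=1)  # offloads frozen params
    #   manager.setup_hooks()    # stream params during forward/backward
    #   ...                      # training steps, wrapped in pre_step()/post_step()
    #   manager.remove_hooks()   # detach hooks, restore params to GPU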

    def __init__(
        self,
        model: nn.Module,
        num_prefetch: int = 1,
    ):
        self.model = model
        self.num_prefetch = num_prefetch
        self._hooks: list = []
        self._device = None

        # Find decoder layers
        self.layers, layer_types = _find_decoder_layers(model)
        if self.layers is None:
            LOG.warning(
                "LayerOffloadManager: no decoder layers found, offloading disabled"
            )
            self.enabled = False
            return

        self.enabled = True
        self.n_layers = len(self.layers)
        LOG.info(
            f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
        )

        # Determine the GPU device
        for p in model.parameters():
            if p.device.type == "cuda":
                self._device = p.device
                break
        if self._device is None:
            LOG.warning("LayerOffloadManager: no CUDA parameters found")
            self.enabled = False
            return

        # Dedicated transfer stream for async prefetch
        self._transfer_stream = torch.cuda.Stream(device=self._device)

        # Track which layers currently have their frozen params on GPU
        self._on_gpu: set[int] = set(range(self.n_layers))

        # Cache frozen param references per layer (list of (name, param) tuples)
        self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
            _get_frozen_params(self.layers[i]) for i in range(self.n_layers)
        ]

        # CPU storage: pinned tensors for each layer's frozen params,
        # populated on first offload
        self._cpu_data: list[dict[str, torch.Tensor]] = [
            {} for _ in range(self.n_layers)
        ]

        # Offload all layers upfront
        self._offload_all()

        # Release cached memory blocks back to the driver
        torch.cuda.empty_cache()

    def _offload_all(self):
        """Move all frozen params in all decoder layers to CPU."""
        mem_before = torch.cuda.memory_allocated(self._device)
        for i in range(self.n_layers):
            self._offload_layer(i)
        mem_after = torch.cuda.memory_allocated(self._device)
        freed = (mem_before - mem_after) / 1e6
        LOG.info(
            f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
            f"freed {freed:.0f} MB GPU memory"
        )

    def _offload_layer(self, idx: int):
        """Move frozen params of layer idx to CPU pinned memory."""
        if idx not in self._on_gpu:
            return
        for name, param in self._frozen_params[idx]:
            if param.device.type != "cuda":
                continue
            # Allocate the pinned CPU buffer on first offload
            if name not in self._cpu_data[idx]:
                self._cpu_data[idx][name] = torch.empty_like(
                    param.data, device="cpu", pin_memory=True
                )
            cpu_buf = self._cpu_data[idx][name]
            # Async D2H copy into pinned memory (non_blocking on the current
            # stream; later transfers are stream-ordered behind it)
            cpu_buf.copy_(param.data, non_blocking=True)
            # Repoint the parameter at the pinned CPU buffer; dropping the
            # last reference lets the allocator reclaim the GPU tensor
            param.data = cpu_buf
        self._on_gpu.discard(idx)

    def _load_layer(self, idx: int, stream=None):
        """Move frozen params of layer idx back to GPU."""
        if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
            return
        ctx = (
            torch.cuda.stream(stream)
            if stream is not None
            else contextlib.nullcontext()
        )
        with ctx:
            for _name, param in self._frozen_params[idx]:
                if param.device.type == "cuda":
                    continue
                gpu_data = param.data.to(self._device, non_blocking=True)
                if stream is not None:
                    # The tensor was allocated on the transfer stream but is
                    # consumed on the default stream; record that usage so the
                    # caching allocator doesn't reuse the block too early
                    gpu_data.record_stream(torch.cuda.default_stream(self._device))
                param.data = gpu_data
        self._on_gpu.add(idx)

    def _prefetch_layer(self, idx: int):
        """Async prefetch layer idx on the transfer stream."""
        if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
            return
        # Order the transfer stream after the default stream so any in-flight
        # D2H offload of this layer's pinned buffers completes first
        self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
        self._load_layer(idx, stream=self._transfer_stream)

    def _wait_transfer(self):
        """Make the default stream wait for any in-flight transfers."""
        torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)

    def setup_hooks(self):
        """Register forward and backward hooks on each decoder layer."""
        if not self.enabled:
            return

        for idx in range(self.n_layers):
            layer = self.layers[idx]

            def make_pre_fwd(i):
                def hook(module, args):
                    # Ensure this layer is on GPU
                    if i not in self._on_gpu:
                        self._load_layer(i)
                    self._wait_transfer()
                    # Prefetch next layer(s)
                    for offset in range(1, self.num_prefetch + 1):
                        self._prefetch_layer(i + offset)

                return hook

            def make_post_fwd(i):
                def hook(module, args, output):
                    # Offload the previous layer (no longer needed in forward)
                    if i > 0:
                        self._offload_layer(i - 1)
                    # Offload the last layer once forward completes
                    if i == self.n_layers - 1:
                        self._offload_layer(i)

                return hook

            def make_pre_bwd(i):
                def hook(module, grad_output):
                    # Load this layer for backward
                    if i not in self._on_gpu:
                        self._load_layer(i)
                    self._wait_transfer()
                    # Prefetch previous layer(s)
                    for offset in range(1, self.num_prefetch + 1):
                        self._prefetch_layer(i - offset)

                return hook

            def make_post_bwd(i):
                def hook(module, grad_input, grad_output):
                    # Offload the layer above (its backward is done)
                    if i < self.n_layers - 1:
                        self._offload_layer(i + 1)
                    # Offload the first layer once backward completes
                    if i == 0:
                        self._offload_layer(i)

                return hook

            h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
            h2 = layer.register_forward_hook(make_post_fwd(idx))
            h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
            h4 = layer.register_full_backward_hook(make_post_bwd(idx))
            self._hooks.extend([h1, h2, h3, h4])
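
    # Per-layer hook sequence in the forward pass (sketch):
    #   pre_fwd(i):  load layer i if needed -> wait for transfers -> prefetch i+1
    #   <layer i forward compute>
    #   post_fwd(i): offload layer i-1 (and layer i itself when i is the last)
    # The backward hooks mirror this, walking from the last layer toward 0.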

    def remove_hooks(self):
        """Remove all hooks and restore layers to GPU."""
        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        if self.enabled:
            for i in range(self.n_layers):
                if i not in self._on_gpu:
                    self._load_layer(i)

    def pre_step(self):
        """Called before each training step; ensures layers start offloaded."""
        if not self.enabled:
            return
        for i in list(self._on_gpu):
            self._offload_layer(i)
        # Prefetch layer 0 for the forward pass
        self._prefetch_layer(0)

    def post_step(self):
        """Called after each training step; ensures layers end up offloaded."""
        if not self.enabled:
            return
        for i in list(self._on_gpu):
            self._offload_layer(i)
        # Prefetch layer 0 for the next step
        self._prefetch_layer(0)


class _LayerOffloadContext:
    """Context manager wrapping pre_step / post_step around a training step."""

    def __init__(self, manager: LayerOffloadManager):
        self.manager = manager

    def __enter__(self):
        self.manager.pre_step()
        return self

    def __exit__(self, *args):
        self.manager.post_step()


class LayerOffloadingMixin(Trainer):
    """
    Trainer mixin class for layer-wise parameter offloading to CPU.

    Offloads frozen decoder layer params to CPU at init, then streams them
    on/off GPU one layer at a time during each training step.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if getattr(self.args, "layer_offloading", False):
            LOG.info("Layer parameter offloading enabled")
            self._layer_offload_manager = LayerOffloadManager(
                model=self.model,
                num_prefetch=1,
            )
            self._layer_offload_manager.setup_hooks()
            self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
        else:
            self._layer_offload_manager = None
            self._layer_offload_ctx = contextlib.nullcontext()

    def training_step(self, *args, **kwargs):
        with self._layer_offload_ctx:
            return super().training_step(*args, **kwargs)