Skip to content

Commit c9df6ef

Browse files
authored
support offloading layers to CPU (#3512) [skip ci]
* support offloading layers to CPU * chore: lint * revert change * update docs
1 parent 0ee98a0 commit c9df6ef

File tree

8 files changed

+360
-1
lines changed

8 files changed

+360
-1
lines changed

docs/gradient_checkpointing.qmd

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
---
2-
title: Gradient Checkpointing and Activation Offloading
2+
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
33
---
44

55
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
@@ -27,3 +27,33 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
2727

2828
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
2929
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
30+
31+
### Enabling Layer Offloading
32+
33+
```yaml
34+
layer_offloading: true
35+
```
36+
37+
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
38+
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
39+
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
40+
trainable adapter weights stay on GPU permanently.
41+
42+
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
43+
44+
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
45+
prefetched asynchronously on a separate CUDA stream for overlap.
46+
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
47+
previous layer is prefetched.
48+
49+
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
50+
51+
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
52+
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
53+
that is kept on GPU at any given time.
54+
55+
**Requirements:**
56+
57+
- CUDA GPU (CPU-only training is not supported for this feature)
58+
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
59+
- Best combined with LoRA/QLoRA where most parameters are frozen

docs/optimizations.qmd

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ These techniques save VRAM by changing how activations are handled.
5454
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
5555
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
5656

57+
### Layer Offloading
58+
59+
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
60+
61+
- **Config:** `layer_offloading: true`
62+
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
63+
5764
### Cut Cross Entropy (CCE)
5865

5966
Reduces VRAM usage by using an optimized cross-entropy loss calculation.

src/axolotl/core/builders/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,8 @@ def _configure_accelerator_config(self, training_args_kwargs: dict):
508508
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
509509

510510
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
511+
if self.cfg.layer_offloading:
512+
training_args_kwargs["layer_offloading"] = True
511513
if self.cfg.activation_offloading is True:
512514
# don't use the HF gradient checkpointing, manually wrap
513515
training_args_kwargs["gradient_checkpointing"] = False

src/axolotl/core/trainers/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
ActivationOffloadingMixin,
3535
CheckpointSaveMixin,
3636
DistributedParallelMixin,
37+
LayerOffloadingMixin,
3738
OptimizerMixin,
3839
PackingMixin,
3940
RngLoaderMixin,
@@ -66,6 +67,7 @@ class AxolotlTrainer(
6667
OptimizerMixin,
6768
RngLoaderMixin,
6869
CheckpointSaveMixin,
70+
LayerOffloadingMixin,
6971
ActivationOffloadingMixin,
7072
DistributedParallelMixin,
7173
Trainer,

src/axolotl/core/trainers/mixins/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .activation_checkpointing import ActivationOffloadingMixin
66
from .checkpoints import CheckpointSaveMixin
7+
from .layer_offloading import LayerOffloadingMixin
78
from .distributed_parallel import DistributedParallelMixin
89
from .optimizer import OptimizerMixin
910
from .packing import PackingMixin
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
"""
2+
Trainer mixin for layer-wise parameter offloading to CPU.
3+
4+
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
5+
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
6+
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
7+
8+
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
9+
transfer stream), post-hook offloads layer N-1's frozen params.
10+
Backward: same in reverse order.
11+
"""
12+
13+
import contextlib
14+
15+
import torch
16+
import torch.nn as nn
17+
from transformers import Trainer
18+
19+
from axolotl.utils.logging import get_logger
20+
21+
LOG = get_logger(__name__)
22+
23+
24+
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
25+
"""Recursively search the model for the decoder layer ModuleList.
26+
27+
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
28+
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
29+
where layers are at model.language_model.layers).
30+
"""
31+
# BFS to find the first ModuleList containing decoder layers
32+
queue = [model]
33+
while queue:
34+
m = queue.pop(0)
35+
for _name, child in m.named_children():
36+
if isinstance(child, nn.ModuleList) and len(child) > 0:
37+
first_type = type(child[0]).__name__
38+
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
39+
layer_types = list({type(layer).__name__ for layer in child})
40+
return child, layer_types
41+
else:
42+
queue.append(child)
43+
44+
return None, []
45+
46+
47+
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
48+
"""Get all non-trainable parameters in a layer."""
49+
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
50+
51+
52+
class LayerOffloadManager:
53+
"""Manages offloading frozen decoder layer params to CPU and streaming
54+
them back during forward/backward with CUDA stream overlap.
55+
56+
Only frozen (requires_grad=False) parameters are offloaded.
57+
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
58+
"""
59+
60+
def __init__(
61+
self,
62+
model: nn.Module,
63+
num_prefetch: int = 1,
64+
):
65+
self.model = model
66+
self.num_prefetch = num_prefetch
67+
self._hooks: list = []
68+
self._device = None
69+
70+
# Find decoder layers
71+
self.layers, layer_types = _find_decoder_layers(model)
72+
if self.layers is None:
73+
LOG.warning(
74+
"LayerOffloadManager: no decoder layers found, offloading disabled"
75+
)
76+
self.enabled = False
77+
return
78+
79+
self.enabled = True
80+
self.n_layers = len(self.layers)
81+
LOG.info(
82+
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
83+
)
84+
85+
# Determine GPU device
86+
for p in model.parameters():
87+
if p.device.type == "cuda":
88+
self._device = p.device
89+
break
90+
if self._device is None:
91+
LOG.warning("LayerOffloadManager: no CUDA parameters found")
92+
self.enabled = False
93+
return
94+
95+
# Transfer stream for async prefetch
96+
self._transfer_stream = torch.cuda.Stream(device=self._device)
97+
98+
# Track which layers have their frozen params on GPU
99+
self._on_gpu: set[int] = set(range(self.n_layers))
100+
101+
# Cache: frozen param references per layer (list of (name, param) tuples)
102+
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
103+
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
104+
]
105+
106+
# CPU storage: pinned tensors for each layer's frozen params
107+
# Populated on first offload
108+
self._cpu_data: list[dict[str, torch.Tensor]] = [
109+
{} for _ in range(self.n_layers)
110+
]
111+
112+
# Offload all layers upfront
113+
self._offload_all()
114+
115+
# Release cached memory blocks back to the driver
116+
torch.cuda.empty_cache()
117+
118+
def _offload_all(self):
119+
"""Move all frozen params in all decoder layers to CPU."""
120+
mem_before = torch.cuda.memory_allocated(self._device)
121+
for i in range(self.n_layers):
122+
self._offload_layer(i)
123+
mem_after = torch.cuda.memory_allocated(self._device)
124+
freed = (mem_before - mem_after) / 1e6
125+
LOG.info(
126+
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
127+
f"freed {freed:.0f} MB GPU memory"
128+
)
129+
130+
def _offload_layer(self, idx: int):
131+
"""Move frozen params of layer idx to CPU pinned memory."""
132+
if idx not in self._on_gpu:
133+
return
134+
for name, param in self._frozen_params[idx]:
135+
if param.device.type != "cuda":
136+
continue
137+
# Allocate pinned CPU tensor on first offload
138+
if name not in self._cpu_data[idx]:
139+
self._cpu_data[idx][name] = torch.empty_like(
140+
param.data, device="cpu", pin_memory=True
141+
)
142+
cpu_buf = self._cpu_data[idx][name]
143+
# Async copy GPU -> CPU (on transfer stream for overlap)
144+
cpu_buf.copy_(param.data, non_blocking=True)
145+
# Point parameter at a dummy CPU tensor to free GPU memory
146+
param.data = cpu_buf
147+
self._on_gpu.discard(idx)
148+
149+
def _load_layer(self, idx: int, stream=None):
150+
"""Move frozen params of layer idx back to GPU."""
151+
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
152+
return
153+
ctx = (
154+
torch.cuda.stream(stream)
155+
if stream is not None
156+
else contextlib.nullcontext()
157+
)
158+
with ctx:
159+
for _name, param in self._frozen_params[idx]:
160+
if param.device.type == "cuda":
161+
continue
162+
gpu_data = param.data.to(self._device, non_blocking=True)
163+
param.data = gpu_data
164+
self._on_gpu.add(idx)
165+
166+
def _prefetch_layer(self, idx: int):
167+
"""Async prefetch layer idx on the transfer stream."""
168+
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
169+
return
170+
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
171+
self._load_layer(idx, stream=self._transfer_stream)
172+
173+
def _wait_transfer(self):
174+
"""Make default stream wait for any in-flight transfers."""
175+
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
176+
177+
def setup_hooks(self):
178+
"""Register forward and backward hooks on each decoder layer."""
179+
if not self.enabled:
180+
return
181+
182+
for idx in range(self.n_layers):
183+
layer = self.layers[idx]
184+
185+
def make_pre_fwd(i):
186+
def hook(module, args):
187+
# Ensure this layer is on GPU
188+
if i not in self._on_gpu:
189+
self._load_layer(i)
190+
self._wait_transfer()
191+
# Prefetch next layer(s)
192+
for offset in range(1, self.num_prefetch + 1):
193+
self._prefetch_layer(i + offset)
194+
195+
return hook
196+
197+
def make_post_fwd(i):
198+
def hook(module, args, output):
199+
# Offload previous layer (no longer needed in forward)
200+
if i > 0:
201+
self._offload_layer(i - 1)
202+
# Offload last layer after forward
203+
if i == self.n_layers - 1:
204+
self._offload_layer(i)
205+
206+
return hook
207+
208+
def make_pre_bwd(i):
209+
def hook(module, grad_output):
210+
# Load this layer for backward
211+
if i not in self._on_gpu:
212+
self._load_layer(i)
213+
self._wait_transfer()
214+
# Prefetch previous layer(s)
215+
for offset in range(1, self.num_prefetch + 1):
216+
self._prefetch_layer(i - offset)
217+
218+
return hook
219+
220+
def make_post_bwd(i):
221+
def hook(module, grad_input, grad_output):
222+
# Offload the layer above
223+
if i < self.n_layers - 1:
224+
self._offload_layer(i + 1)
225+
# Offload first layer after backward
226+
if i == 0:
227+
self._offload_layer(i)
228+
229+
return hook
230+
231+
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
232+
h2 = layer.register_forward_hook(make_post_fwd(idx))
233+
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
234+
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
235+
self._hooks.extend([h1, h2, h3, h4])
236+
237+
def remove_hooks(self):
238+
"""Remove all hooks and restore layers to GPU."""
239+
for h in self._hooks:
240+
h.remove()
241+
self._hooks.clear()
242+
if self.enabled:
243+
for i in range(self.n_layers):
244+
if i not in self._on_gpu:
245+
self._load_layer(i)
246+
247+
def pre_step(self):
248+
"""Called before each training step — ensure layers start offloaded."""
249+
if not self.enabled:
250+
return
251+
for i in list(self._on_gpu):
252+
self._offload_layer(i)
253+
# Prefetch layer 0 for forward
254+
self._prefetch_layer(0)
255+
256+
def post_step(self):
257+
"""Called after each training step — ensure layers are offloaded."""
258+
if not self.enabled:
259+
return
260+
for i in list(self._on_gpu):
261+
self._offload_layer(i)
262+
# Prefetch layer 0 for next step
263+
self._prefetch_layer(0)
264+
265+
266+
class _LayerOffloadContext:
267+
"""Context manager wrapping pre_step / post_step around a training step."""
268+
269+
def __init__(self, manager: LayerOffloadManager):
270+
self.manager = manager
271+
272+
def __enter__(self):
273+
self.manager.pre_step()
274+
return self
275+
276+
def __exit__(self, *args):
277+
self.manager.post_step()
278+
279+
280+
class LayerOffloadingMixin(Trainer):
281+
"""
282+
Trainer mixin class for layer-wise parameter offloading to CPU.
283+
284+
Offloads frozen decoder layer params to CPU at init, then streams them
285+
on/off GPU one layer at a time during each training step.
286+
"""
287+
288+
def __init__(self, *args, **kwargs):
289+
super().__init__(*args, **kwargs)
290+
if getattr(self.args, "layer_offloading", False):
291+
LOG.info("Layer parameter offloading enabled")
292+
self._layer_offload_manager = LayerOffloadManager(
293+
model=self.model,
294+
num_prefetch=1,
295+
)
296+
self._layer_offload_manager.setup_hooks()
297+
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
298+
else:
299+
self._layer_offload_manager = None
300+
self._layer_offload_ctx = contextlib.nullcontext()
301+
302+
def training_step(self, *args, **kwargs):
303+
with self._layer_offload_ctx:
304+
return super().training_step(*args, **kwargs)

0 commit comments

Comments
 (0)