Skip to content

Commit 8e89325

Browse files
committed
support offloading layers to CPU
1 parent 163bd4d commit 8e89325

File tree

8 files changed

+333
-1
lines changed

8 files changed

+333
-1
lines changed

src/axolotl/core/builders/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,8 @@ def _configure_accelerator_config(self, training_args_kwargs: dict):
484484
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
485485

486486
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
487+
if self.cfg.layer_offloading:
488+
training_args_kwargs["layer_offloading"] = True
487489
if self.cfg.activation_offloading is True:
488490
# don't use the HF gradient checkpointing, manually wrap
489491
training_args_kwargs["gradient_checkpointing"] = False

src/axolotl/core/trainers/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ActivationOffloadingMixin,
3434
CheckpointSaveMixin,
3535
DistributedParallelMixin,
36+
LayerOffloadingMixin,
3637
OptimizerMixin,
3738
PackingMixin,
3839
RngLoaderMixin,
@@ -67,6 +68,7 @@ class AxolotlTrainer(
6768
OptimizerMixin,
6869
RngLoaderMixin,
6970
CheckpointSaveMixin,
71+
LayerOffloadingMixin,
7072
ActivationOffloadingMixin,
7173
DistributedParallelMixin,
7274
Trainer,

src/axolotl/core/trainers/mixins/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .activation_checkpointing import ActivationOffloadingMixin
66
from .checkpoints import CheckpointSaveMixin
7+
from .layer_offloading import LayerOffloadingMixin
78
from .distributed_parallel import DistributedParallelMixin
89
from .optimizer import OptimizerMixin
910
from .packing import PackingMixin
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
"""
2+
Trainer mixin for layer-wise parameter offloading to CPU.
3+
4+
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
5+
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
6+
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
7+
8+
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
9+
transfer stream), post-hook offloads layer N-1's frozen params.
10+
Backward: same in reverse order.
11+
"""
12+
13+
import contextlib
14+
15+
import torch
16+
import torch.nn as nn
17+
from transformers import Trainer
18+
19+
from axolotl.utils.logging import get_logger
20+
21+
LOG = get_logger(__name__)
22+
23+
24+
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
25+
"""Recursively search the model for the decoder layer ModuleList.
26+
27+
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
28+
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
29+
where layers are at model.language_model.layers).
30+
"""
31+
# BFS to find the first ModuleList containing decoder layers
32+
queue = [model]
33+
while queue:
34+
m = queue.pop(0)
35+
for name, child in m.named_children():
36+
if isinstance(child, nn.ModuleList) and len(child) > 0:
37+
first_type = type(child[0]).__name__
38+
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
39+
layer_types = list({type(layer).__name__ for layer in child})
40+
return child, layer_types
41+
else:
42+
queue.append(child)
43+
44+
return None, []
45+
46+
47+
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
48+
"""Get all non-trainable parameters in a layer."""
49+
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
50+
51+
52+
class LayerOffloadManager:
    """Manages offloading frozen decoder layer params to CPU and streaming
    them back during forward/backward with CUDA stream overlap.

    Only frozen (requires_grad=False) parameters are offloaded.
    Trainable parameters (LoRA weights, etc.) remain on GPU at all times.

    Lifecycle: the constructor locates the decoder layers, snapshots their
    frozen parameters, offloads them all to pinned CPU memory, and frees the
    GPU cache. ``setup_hooks()`` then installs per-layer forward/backward
    hooks that stream each layer's frozen params on/off the GPU during a step.
    If no decoder layers or no CUDA parameters are found, ``self.enabled`` is
    False and every public method becomes a no-op.
    """

    def __init__(
        self,
        model: nn.Module,
        num_prefetch: int = 1,
    ):
        # num_prefetch: how many layers ahead (forward) / behind (backward)
        # to start loading on the transfer stream.
        self.model = model
        self.num_prefetch = num_prefetch
        self._hooks: list = []
        self._device = None

        # Find decoder layers
        self.layers, layer_types = _find_decoder_layers(model)
        if self.layers is None:
            LOG.warning("LayerOffloadManager: no decoder layers found, offloading disabled")
            self.enabled = False
            return

        self.enabled = True
        self.n_layers = len(self.layers)
        LOG.info(
            f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
        )

        # Determine GPU device from the first CUDA parameter encountered.
        for p in model.parameters():
            if p.device.type == "cuda":
                self._device = p.device
                break
        if self._device is None:
            LOG.warning("LayerOffloadManager: no CUDA parameters found")
            self.enabled = False
            return

        # Dedicated side stream so prefetch H2D copies can overlap compute
        # running on the default stream.
        self._transfer_stream = torch.cuda.Stream(device=self._device)

        # Track which layers have their frozen params on GPU.
        # All layers start resident; _offload_all() below empties this set.
        self._on_gpu: set[int] = set(range(self.n_layers))

        # Cache: frozen param references per layer (list of (name, param) tuples).
        # Captured once so hooks never re-walk named_parameters() per step.
        self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
            _get_frozen_params(self.layers[i]) for i in range(self.n_layers)
        ]

        # CPU storage: pinned tensors for each layer's frozen params,
        # keyed by param name. Populated lazily on first offload and reused
        # every step thereafter (no per-step CPU allocation).
        self._cpu_data: list[dict[str, torch.Tensor]] = [{} for _ in range(self.n_layers)]

        # Offload all layers upfront
        self._offload_all()

        # Release cached memory blocks back to the driver
        torch.cuda.empty_cache()

    def _offload_all(self):
        """Move all frozen params in all decoder layers to CPU.

        Logs the delta of ``memory_allocated`` as an estimate of freed VRAM.
        """
        mem_before = torch.cuda.memory_allocated(self._device)
        for i in range(self.n_layers):
            self._offload_layer(i)
        mem_after = torch.cuda.memory_allocated(self._device)
        freed = (mem_before - mem_after) / 1e6
        LOG.info(
            f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
            f"freed {freed:.0f} MB GPU memory"
        )

    def _offload_layer(self, idx: int):
        """Move frozen params of layer ``idx`` to CPU pinned memory.

        No-op if the layer is already marked offloaded.
        """
        if idx not in self._on_gpu:
            return
        for name, param in self._frozen_params[idx]:
            # Skip params already on CPU (e.g. never loaded back).
            if param.device.type != "cuda":
                continue
            # Allocate pinned CPU tensor on first offload; reused afterwards.
            if name not in self._cpu_data[idx]:
                self._cpu_data[idx][name] = torch.empty_like(
                    param.data, device="cpu", pin_memory=True
                )
            cpu_buf = self._cpu_data[idx][name]
            # Async D2H copy into pinned memory, enqueued on the CURRENT
            # stream (not the transfer stream).
            # NOTE(review): reading cpu_buf on the host before this copy
            # completes would observe stale data; this appears safe here only
            # because the buffer is next consumed by a later .to() issued
            # after stream ordering is established — confirm.
            cpu_buf.copy_(param.data, non_blocking=True)
            # Rebind the parameter to the pinned CPU buffer; dropping the last
            # reference to the GPU tensor lets the caching allocator reclaim it.
            param.data = cpu_buf
        self._on_gpu.discard(idx)

    def _load_layer(self, idx: int, stream=None):
        """Move frozen params of layer ``idx`` back to GPU.

        When ``stream`` is given, the H2D copies are enqueued on that stream
        (used for prefetch); callers must later synchronize via
        ``_wait_transfer()`` before the default stream consumes the params.
        Out-of-range or already-resident indices are ignored.
        """
        if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
            return
        ctx = torch.cuda.stream(stream) if stream is not None else contextlib.nullcontext()
        with ctx:
            for name, param in self._frozen_params[idx]:
                if param.device.type == "cuda":
                    continue
                # Pinned source makes this copy truly async.
                gpu_data = param.data.to(self._device, non_blocking=True)
                param.data = gpu_data
        self._on_gpu.add(idx)

    def _prefetch_layer(self, idx: int):
        """Async prefetch layer ``idx`` on the transfer stream.

        The transfer stream first waits on the default stream so the copy
        is ordered after any pending work that produced/offloaded params.
        """
        if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
            return
        self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
        self._load_layer(idx, stream=self._transfer_stream)

    def _wait_transfer(self):
        """Make default stream wait for any in-flight transfers."""
        torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)

    def setup_hooks(self):
        """Register forward and backward hooks on each decoder layer.

        Forward: pre-hook ensures layer i is resident and prefetches i+1..;
        post-hook offloads layer i-1 (and i itself for the last layer).
        Backward: mirror image — pre-hook loads i and prefetches i-1..;
        post-hook offloads i+1 (and i itself for the first layer).
        """
        if not self.enabled:
            return

        for idx in range(self.n_layers):
            layer = self.layers[idx]

            # Factory functions bind the loop index early — avoids the
            # classic late-binding closure bug.
            def make_pre_fwd(i):
                def hook(module, args):
                    # Ensure this layer is on GPU
                    if i not in self._on_gpu:
                        self._load_layer(i)
                    # Block default stream on any prefetch still in flight
                    # before the layer's compute reads the params.
                    self._wait_transfer()
                    # Prefetch next layer(s)
                    for offset in range(1, self.num_prefetch + 1):
                        self._prefetch_layer(i + offset)
                return hook

            def make_post_fwd(i):
                def hook(module, args, output):
                    # Offload previous layer (no longer needed in forward)
                    if i > 0:
                        self._offload_layer(i - 1)
                    # Offload last layer after forward
                    if i == self.n_layers - 1:
                        self._offload_layer(i)
                return hook

            def make_pre_bwd(i):
                def hook(module, grad_output):
                    # Load this layer for backward
                    if i not in self._on_gpu:
                        self._load_layer(i)
                    self._wait_transfer()
                    # Prefetch previous layer(s)
                    for offset in range(1, self.num_prefetch + 1):
                        self._prefetch_layer(i - offset)
                return hook

            def make_post_bwd(i):
                def hook(module, grad_input, grad_output):
                    # Offload the layer above
                    if i < self.n_layers - 1:
                        self._offload_layer(i + 1)
                    # Offload first layer after backward
                    if i == 0:
                        self._offload_layer(i)
                return hook

            h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
            h2 = layer.register_forward_hook(make_post_fwd(idx))
            h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
            h4 = layer.register_full_backward_hook(make_post_bwd(idx))
            self._hooks.extend([h1, h2, h3, h4])

    def remove_hooks(self):
        """Remove all hooks and restore layers to GPU.

        Used to tear down offloading (e.g. before saving or evaluation with
        a fully-resident model).
        """
        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        if self.enabled:
            for i in range(self.n_layers):
                if i not in self._on_gpu:
                    self._load_layer(i)

    def pre_step(self):
        """Called before each training step — ensure layers start offloaded."""
        if not self.enabled:
            return
        # Copy the set: _offload_layer mutates self._on_gpu while we iterate.
        for i in list(self._on_gpu):
            self._offload_layer(i)
        # Prefetch layer 0 for forward
        self._prefetch_layer(0)

    def post_step(self):
        """Called after each training step — ensure layers are offloaded."""
        if not self.enabled:
            return
        for i in list(self._on_gpu):
            self._offload_layer(i)
        # Prefetch layer 0 for next step
        self._prefetch_layer(0)
252+
253+
254+
class _LayerOffloadContext:
    """Context manager that brackets a training step with the offload
    manager's ``pre_step`` / ``post_step`` calls.

    Exceptions raised inside the step are not suppressed; ``post_step``
    still runs on the way out.
    """

    def __init__(self, manager: LayerOffloadManager):
        self.manager = manager

    def __enter__(self):
        # Offload any resident layers and kick off the first prefetch.
        self.manager.pre_step()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always re-offload, even when the step raised.
        self.manager.post_step()
266+
267+
268+
class LayerOffloadingMixin(Trainer):
    """
    Trainer mixin class for layer-wise parameter offloading to CPU.

    Offloads frozen decoder layer params to CPU at init, then streams them
    on/off GPU one layer at a time during each training step.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Feature flag comes from training args; default off when absent.
        if not getattr(self.args, "layer_offloading", False):
            self._layer_offload_manager = None
            # nullcontext is reusable, so a single instance covers all steps.
            self._layer_offload_ctx = contextlib.nullcontext()
            return
        LOG.info("Layer parameter offloading enabled")
        manager = LayerOffloadManager(
            model=self.model,
            num_prefetch=1,
        )
        manager.setup_hooks()
        self._layer_offload_manager = manager
        self._layer_offload_ctx = _LayerOffloadContext(manager)

    def training_step(self, *args, **kwargs):
        # Bracket every step so frozen layers start and end offloaded.
        with self._layer_offload_ctx:
            return super().training_step(*args, **kwargs)

src/axolotl/core/training_args_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ class AxolotlTrainingMixins:
235235
metadata={"help": "Use activation offloading with CUDA streams for training."},
236236
)
237237

238+
layer_offloading: bool | None = field(
239+
default=None,
240+
metadata={"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."},
241+
)
242+
238243
# multi-modal section
239244

240245
image_size: int | tuple[int, int] | None = field(

src/axolotl/integrations/kernels/constants.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
1616
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
1717
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
18+
"qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock",
1819
"qwen3_next": "Qwen3NextSparseMoeBlock",
1920
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
2021
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
@@ -58,7 +59,16 @@ def resolve_moe_block_classes(model_type: str):
5859

5960
cls_names = entry if isinstance(entry, list) else [entry]
6061
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
61-
module = importlib.import_module(module_path)
62+
try:
63+
module = importlib.import_module(module_path)
64+
except ModuleNotFoundError:
65+
# Text sub-model types (e.g. qwen3_5_moe_text) share the parent module
66+
if model_type.endswith("_text"):
67+
parent_type = model_type.removesuffix("_text")
68+
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
69+
module = importlib.import_module(module_path)
70+
else:
71+
raise
6272

6373
classes = []
6474
for cls_name in cls_names:

src/axolotl/loaders/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,20 @@ def _set_device_map_config(self):
505505
elif not is_ds_zero3:
506506
self.model_kwargs["device_map"] = device_map
507507

508+
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
509+
# so the actual VRAM usage is much less than bf16 estimates.
510+
# When device_map is "auto", accelerate's infer_auto_device_map computes
511+
# the device map at bf16 size (before quantization), causing it to offload
512+
# layers to CPU, which BnB then rejects. Force single-GPU placement to
513+
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
514+
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
515+
"auto",
516+
None,
517+
):
518+
self.model_kwargs["device_map"] = {
519+
"": int(os.environ.get("LOCAL_RANK", 0))
520+
}
521+
508522
cur_device = get_device_type()
509523
if "mps" in str(cur_device):
510524
self.model_kwargs["device_map"] = "mps:0"

src/axolotl/utils/schemas/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,12 @@ class AxolotlInputConfig(
433433
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
434434
},
435435
)
436+
layer_offloading: bool | None = Field(
437+
default=False,
438+
json_schema_extra={
439+
"description": "Offload model layer parameters to CPU during forward, prefetch back during backward."
440+
},
441+
)
436442

437443
unfrozen_parameters: list[str] | None = Field(
438444
default=None,

0 commit comments

Comments
 (0)