
support offloading layers to CPU #3512

Merged
winglian merged 4 commits into main from layer-offloading on Mar 22, 2026

Conversation


@winglian (Collaborator) commented Mar 19, 2026

Description

Similar to ZeRO-3 offloading, but it seems to work better with 4-bit quantization. With Qwen3.5 35B and QLoRA, it uses 26 GB reserved VRAM vs. ~29 GB reserved without offloading.

Summary by CodeRabbit

  • New Features

    • Added layer offloading capability for GPU memory optimization during model training, enabling CPU offloading of model parameters with automatic prefetching during backward passes.
    • Added support for Qwen3.5 MoE models.
    • Improved device mapping configuration for quantized MoE expert training.
  • Configuration

    • Added layer_offloading training argument to enable/disable layer offloading functionality.


coderabbitai bot commented Mar 19, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 485d8a45-413b-4c06-9cbf-aa8624d9ebfd

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough


Introduced layer offloading functionality to enable GPU-to-CPU streaming of frozen model parameters during training. Added a new LayerOffloadingMixin trainer class, integrated it into the trainer base and builder configuration, created corresponding training argument fields, and updated MOE kernel support for Qwen models.
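The hook-based streaming described above can be sketched in miniature. The snippet below is an illustrative simulation of the residency bookkeeping only, not the PR's actual code: class and method names are hypothetical, and in the real mixin the transfers would be asynchronous copies on a CUDA stream rather than list appends.

```python
# Illustrative sketch: frozen decoder layers live in CPU memory and are
# streamed to the GPU one at a time, with the next layer prefetched while
# the current layer computes. All names here are hypothetical.

class LayerOffloadManager:
    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.on_gpu = set()      # indices currently resident on the GPU
        self.transfers = []      # log of simulated copies, for inspection

    def load(self, i):
        if i not in self.on_gpu:
            self.transfers.append(("cpu->gpu", i))
            self.on_gpu.add(i)

    def offload(self, i):
        if i in self.on_gpu:
            self.transfers.append(("gpu->cpu", i))
            self.on_gpu.remove(i)

    def pre_forward(self, i):
        """Forward pre-hook for layer i: make it resident, prefetch i+1,
        and release the layer behind us."""
        self.load(i)
        if i + 1 < self.num_layers:
            self.load(i + 1)     # prefetch (async on a CUDA stream in practice)
        if i - 1 >= 0:
            self.offload(i - 1)


mgr = LayerOffloadManager(num_layers=4)
for i in range(4):
    mgr.pre_forward(i)

print(sorted(mgr.on_gpu))  # → [3]
```

At any point at most two layers are "resident", which is where the reserved-memory savings in the PR description come from.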

Changes

  • Core Layer Offloading Implementation: src/axolotl/core/trainers/mixins/layer_offloading.py, src/axolotl/core/trainers/mixins/__init__.py, src/axolotl/core/trainers/base.py
    Added a new LayerOffloadingMixin class with hook-based parameter streaming, a LayerOffloadManager for orchestrating offload/load cycles, and helper functions to identify decoder layers and frozen parameters. Integrated the mixin into the AxolotlTrainer base class and exported it via the package __init__.
  • Configuration and Training Arguments: src/axolotl/core/training_args_base.py, src/axolotl/utils/schemas/config.py
    Added a boolean layer_offloading field to both the training arguments and the input config schema, with descriptions of the CPU offloading behavior during forward/backward passes.
  • Builder Integration: src/axolotl/core/builders/base.py
    Added conditional propagation of the layer_offloading flag from config to training arguments within _configure_gradient_checkpointing.
  • Model Loading and MOE Support: src/axolotl/loaders/model.py, src/axolotl/integrations/kernels/constants.py
    Updated device-mapping logic to force single-GPU placement when MOE expert quantization is enabled. Added a qwen3_5_moe_text mapping to the sparse MOE block constants and fallback import logic for model types ending in _text.
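For reference, enabling the new flag in an axolotl YAML config might look like the following. This is an illustrative snippet: the model and adapter values are placeholders, and only the layer_offloading key comes from this PR.

```yaml
base_model: Qwen/Qwen3.5-35B   # placeholder; use your model
adapter: qlora
load_in_4bit: true

# New flag from this PR: stream frozen layers to CPU between uses
layer_offloading: true
```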

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • #2900: Modifies _configure_gradient_checkpointing in the same builder file to propagate offloading-related flags, paralleling the layer offloading propagation logic.
  • #2718: Refactors and auto-generates documentation for AxolotlInputConfig, the same configuration schema file being extended here.

Suggested reviewers

  • NanoCode012
  • djsaunde
  • SalmanMohammadi
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 44.83%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed: check skipped, since CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the pull request title 'support offloading layers to CPU' directly matches the primary feature implemented across the modified files.




github-actions bot commented Mar 19, 2026

📖 Documentation Preview: https://69bd4734298f2d3840e25b45--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit c0f30cd

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/core/trainers/mixins/layer_offloading.py`:
- Around line 83-95: The current LayerOffloadManager assumes a single CUDA
device by picking the first p.device and setting
self._device/self._transfer_stream, which breaks models sharded across GPUs;
update LayerOffloadManager to detect multi-device models by scanning
model.parameters() for all unique cuda devices and either (a) fail fast with a
clear error when more than one CUDA device is found, or (b) record the original
device per layer/tensor (e.g., map parameter -> device or module name -> device)
and use that mapping when offloading/rehydrating so each tensor is copied back
to its original device; if you choose per-device handling also create/lookup a
torch.cuda.Stream per device instead of a single self._transfer_stream so async
transfers use the correct device stream.
- Around line 244-251: The race occurs because post_step() launches
_prefetch_layer(0) asynchronously but _load_layer() marks the layer as resident
before the async copy finishes, allowing pre_step() to offload that same layer;
fix by changing the offload/prefetch lifecycle so a layer is only marked
resident after the async CPU→GPU transfer completes (or conversely, prefetch
must set a separate "prefetching" state and _load_layer()/pre_step() must wait
for that transfer). Concretely, update _prefetch_layer(), _load_layer(), and the
residency bookkeeping (e.g., _on_gpu, _resident flags) to: 1) create/record a
cuda Event on the transfer stream when launching the copy, 2) store that event
in a per-layer structure (e.g., layer_transfer_event), 3) only set the layer as
resident when the event has been recorded/completed (or have
_offload_layer/_load_layer check event.synchronize() or query event.completed()
before proceeding), and 4) ensure pre_step() checks/waits on that event (or the
"prefetching" flag) before offloading the layer.

In `@src/axolotl/loaders/model.py`:
- Around line 508-520: The current guard for quantize_moe_experts only handles
device_map when it's exactly "auto" or None; extend it to also detect when
device_map is a dict produced by infer_auto_device_map that contains non-GPU
placements (e.g., "cpu" or "disk") and treat those cases the same by forcing
single-GPU placement. Update the condition around quantize_moe_experts to check
if device_map is a dict and any(value in ("cpu", "disk") for value in
device_map.values()) (or equivalent string checks), and if so set
self.model_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", 0))}
just like the existing branch for "auto"/None; keep references to
quantize_moe_experts, device_map, infer_auto_device_map, and
self.model_kwargs["device_map"] so the change is easy to locate.

In `@src/axolotl/utils/schemas/config.py`:
- Around line 436-441: The new layer_offloading Field can be enabled
concurrently with activation_offloading causing both LayerOffloadingMixin and
ActivationOffloadingMixin to wrap training_step; add a validation that rejects
configs where both layer_offloading and activation_offloading are truthy.
Implement this in the config model (the class containing the layer_offloading
and activation_offloading Fields) using a pydantic validator or root_validator
that checks both flags and raises a ValueError with a clear message if both are
enabled so the builder/trainer cannot compose both mixins simultaneously.
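The mutual-exclusion guard suggested above can be sketched as follows. This is a plain-Python stand-in using a dataclass; in the real codebase it would be a pydantic validator on the config model containing both fields, as the comment describes. Only the two flag names come from the review.

```python
# Sketch of the suggested mutual-exclusion check. A dataclass __post_init__
# stands in for a pydantic model_validator on AxolotlInputConfig.
from dataclasses import dataclass


@dataclass
class OffloadFlags:
    layer_offloading: bool = False
    activation_offloading: bool = False

    def __post_init__(self):
        # Reject configs that would compose LayerOffloadingMixin and
        # ActivationOffloadingMixin around the same training_step.
        if self.layer_offloading and self.activation_offloading:
            raise ValueError(
                "layer_offloading and activation_offloading cannot both be enabled"
            )


OffloadFlags(layer_offloading=True)  # valid: only one flag set
```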

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c54f67e3-e04e-44e6-a3db-b16dad7d0beb

📥 Commits

Reviewing files that changed from the base of the PR and between 5ef3f28 and 67e0cef.

📒 Files selected for processing (8)
  • src/axolotl/core/builders/base.py
  • src/axolotl/core/trainers/base.py
  • src/axolotl/core/trainers/mixins/__init__.py
  • src/axolotl/core/trainers/mixins/layer_offloading.py
  • src/axolotl/core/training_args_base.py
  • src/axolotl/integrations/kernels/constants.py
  • src/axolotl/loaders/model.py
  • src/axolotl/utils/schemas/config.py

Comment on lines +83 to +95
        # Determine GPU device
        for p in model.parameters():
            if p.device.type == "cuda":
                self._device = p.device
                break
        if self._device is None:
            LOG.warning("LayerOffloadManager: no CUDA parameters found")
            self.enabled = False
            return

        # Transfer stream for async prefetch
        self._transfer_stream = torch.cuda.Stream(device=self._device)


⚠️ Potential issue | 🟠 Major

Fail fast on multi-device models, or track the original device per layer.

LayerOffloadManager picks the first CUDA device it sees and later reloads every offloaded tensor onto that single device. Any model that was intentionally sharded across GPUs will bring some layers back on the wrong GPU.

💡 Minimal safe guard
-        # Determine GPU device
-        for p in model.parameters():
-            if p.device.type == "cuda":
-                self._device = p.device
-                break
+        # Determine GPU device
+        devices = {p.device for p in model.parameters() if p.device.type == "cuda"}
+        if len(devices) > 1:
+            raise ValueError(
+                "layer_offloading currently supports a single CUDA device per process"
+            )
+        if devices:
+            self._device = next(iter(devices))
         if self._device is None:
             LOG.warning("LayerOffloadManager: no CUDA parameters found")
             self.enabled = False
             return
🧰 Tools
🪛 GitHub Actions: lint

[error] pre-commit hook failed: ruff-format modified files (2 files reformatted).


Comment on lines +244 to +251
    def post_step(self):
        """Called after each training step — ensure layers are offloaded."""
        if not self.enabled:
            return
        for i in list(self._on_gpu):
            self._offload_layer(i)
        # Prefetch layer 0 for next step
        self._prefetch_layer(0)

⚠️ Potential issue | 🔴 Critical

The cross-step prefetch can race with the next pre_step().

post_step() queues _prefetch_layer(0) on the transfer stream, but _load_layer() marks that layer as resident before the async copy completes. On the very next batch, pre_step() iterates _on_gpu and can immediately _offload_layer(0) without waiting for that transfer, so CPU→GPU and GPU→CPU copies can overlap on the same parameter storage.

💡 One safe fix
     def pre_step(self):
         """Called before each training step — ensure layers start offloaded."""
         if not self.enabled:
             return
+        self._wait_transfer()
         for i in list(self._on_gpu):
             self._offload_layer(i)
         # Prefetch layer 0 for forward
         self._prefetch_layer(0)

     def post_step(self):
         """Called after each training step — ensure layers are offloaded."""
         if not self.enabled:
             return
         for i in list(self._on_gpu):
             self._offload_layer(i)
-        # Prefetch layer 0 for next step
-        self._prefetch_layer(0)
🧰 Tools
🪛 GitHub Actions: lint

[error] pre-commit hook failed: ruff-format modified files (2 files reformatted).

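The event-based lifecycle this fix describes can be illustrated with threading primitives standing in for CUDA machinery. This is an analogy, not CUDA code: threading.Event plays the role of torch.cuda.Event recorded on the transfer stream, and all names are hypothetical.

```python
# Analogy for the race fix: a layer may only be offloaded once any in-flight
# prefetch "event" for it has completed. threading.Event stands in for a
# CUDA event recorded on the transfer stream; names are hypothetical.
import threading


class PrefetchState:
    def __init__(self):
        self.transfer_done = {}  # layer index -> Event for in-flight copy

    def prefetch(self, i):
        ev = threading.Event()
        self.transfer_done[i] = ev
        # Real code: launch the async H2D copy on the transfer stream, then
        # record a CUDA event. Here the "copy" completes immediately.
        ev.set()

    def wait_resident(self, i):
        """pre_step()/offload must wait for any in-flight copy before
        touching the layer's storage, eliminating the overlap race."""
        ev = self.transfer_done.get(i)
        if ev is not None:
            ev.wait()


state = PrefetchState()
state.prefetch(0)        # post_step() queues the async copy of layer 0
state.wait_resident(0)   # the next pre_step() synchronizes before offloading
```

The key invariant: offloading a layer and prefetching it can never overlap on the same parameter storage, because the offload path always waits on the prefetch event first.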

Comment on lines +508 to +520
            # quantize_moe_experts quantizes expert weights on-the-fly during loading,
            # so the actual VRAM usage is much less than bf16 estimates.
            # When device_map is "auto", accelerate's infer_auto_device_map computes
            # the device map at bf16 size (before quantization), causing it to offload
            # layers to CPU, which BnB then rejects. Force single-GPU placement to
            # prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
            if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
                "auto",
                None,
            ):
                self.model_kwargs["device_map"] = {
                    "": int(os.environ.get("LOCAL_RANK", 0))
                }

⚠️ Potential issue | 🟠 Major

Handle inferred dict device maps here as well.

When max_memory/gpu_memory_limit is set, the code above already turns device_map into a dict via infer_auto_device_map(). That dict can still contain "cpu"/"disk" placements, but this guard only catches the literal "auto"/None cases, so quantize_moe_experts still falls into the same BnB CPU-offload failure path for memory-limited configs.

💡 Suggested fix
-            if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
-                "auto",
-                None,
-            ):
+            if getattr(self.cfg, "quantize_moe_experts", False) and (
+                device_map in ("auto", None)
+                or (
+                    isinstance(device_map, dict)
+                    and any(
+                        str(dst).startswith(("cpu", "disk"))
+                        for dst in device_map.values()
+                    )
+                )
+            ):
                 self.model_kwargs["device_map"] = {
                     "": int(os.environ.get("LOCAL_RANK", 0))
                 }
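The extended condition in the suggested fix can be factored into a small standalone predicate. The helper name below is hypothetical; it just mirrors the check from the diff above so it can be exercised in isolation.

```python
# Hypothetical helper mirroring the suggested condition: force single-GPU
# placement when device_map is "auto"/None, or when an inferred dict would
# place anything on CPU or disk (which BnB quantization rejects).
def needs_single_gpu(device_map):
    if device_map in ("auto", None):
        return True
    if isinstance(device_map, dict):
        return any(
            str(dst).startswith(("cpu", "disk")) for dst in device_map.values()
        )
    return False


print(needs_single_gpu("auto"))                              # → True
print(needs_single_gpu({"model.layers.0": 0,
                        "model.layers.1": "cpu"}))           # → True
print(needs_single_gpu({"": 0}))                             # → False
```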

@codecov

codecov bot commented Mar 19, 2026

Codecov Report

❌ Patch coverage is 19.88304% with 137 lines in your changes missing coverage. Please review.

Files with missing lines:
  • ...c/axolotl/core/trainers/mixins/layer_offloading.py: 18.07% patch coverage, 136 lines missing ⚠️
  • src/axolotl/core/builders/base.py: 50.00% patch coverage, 1 line missing ⚠️


@NanoCode012 (Collaborator)

Could we have more docs on this? I'm not sure how a user would use it, or whether it's limited to single GPU or also works with FSDP/DeepSpeed/DDP.

I also saw some possibly unrelated changes for quantize_moe_experts and device_map.

@winglian winglian added the scheduled_release This PR is slated for the upcoming release label Mar 21, 2026
@winglian winglian merged commit c9df6ef into main Mar 22, 2026
14 of 16 checks passed
@winglian winglian deleted the layer-offloading branch March 22, 2026 02:47
@winglian winglian removed the scheduled_release This PR is slated for the upcoming release label Mar 22, 2026