From 1439513e8d59a2aa98ec3ef98111b3903696e255 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 19 Dec 2025 14:45:37 +0800 Subject: [PATCH 1/3] [feat] support activation cpu offload in fsdp and fsdp2 [feat] support activation cpu offload in fsdp and fsdp2 lint fix feat(examples): update FSDP2 config and training script for activation CPU offload - Refactor FSDP2 JSON config to include detailed documentation and recommended settings - Add activation_cpu_offload parameter to FSDP config for memory optimization - Update training script to use new Swift CLI syntax and adjusted hyperparameters feat(plugin): update FSDP version key and add gradient requirement for checkpointing - Change key from 'fsdp_version' to 'version' in fsdp_config for consistency - Add call to model.enable_input_require_grads() when activation checkpointing is enabled to ensure proper gradient computation during CPU offloading feat(plugin): fix whitespace in activation CPU offload callback docs: remove activation_cpu_offload parameter documentation --- .../train/activation_cpu_offload/fsdp2.json | 26 + .../train/activation_cpu_offload/train.sh | 27 + swift/llm/train/sft.py | 5 + swift/plugin/__init__.py | 2 + swift/plugin/activation_cpu_offload.py | 612 ++++++++++++++++++ 5 files changed, 672 insertions(+) create mode 100644 examples/train/activation_cpu_offload/fsdp2.json create mode 100644 examples/train/activation_cpu_offload/train.sh create mode 100644 swift/plugin/activation_cpu_offload.py diff --git a/examples/train/activation_cpu_offload/fsdp2.json b/examples/train/activation_cpu_offload/fsdp2.json new file mode 100644 index 0000000000..73d856389a --- /dev/null +++ b/examples/train/activation_cpu_offload/fsdp2.json @@ -0,0 +1,26 @@ +{ + "_description": "FSDP2 configuration for distributed training (PyTorch native FSDP v2)", + "_requires": "torch>=2.4.0", + "_note": "This is the recommended configuration for multi-GPU training without CPU offloading. NOTE: When using FSDP2, do NOT use --gradient_checkpointing, use activation_checkpointing in fsdp_config instead.", + + "_param_docs": { + "fsdp": "FSDP strategy string. Options: 'full_shard' (ZeRO-3 style, shards params+grads+optimizer), 'shard_grad_op' (ZeRO-2 style, shards grads+optimizer only). Add 'auto_wrap' to enable automatic layer wrapping. Add 'offload' to enable CPU offloading.", + "fsdp_version": "FSDP version. Use 2 for PyTorch native FSDP2 (recommended). FSDP2 uses DTensor for per-parameter sharding, supports LoRA/QLoRA natively.", + "auto_wrap_policy": "How to wrap model layers. 'TRANSFORMER_BASED_WRAP' wraps transformer decoder layers (from model._no_split_modules). 'SIZE_BASED_WRAP' wraps modules exceeding min_num_params.", + "cpu_ram_efficient_loading": "If true, only rank 0 loads full model weights, then broadcasts to other ranks. Reduces CPU RAM usage during initialization.", + "state_dict_type": "'SHARDED_STATE_DICT' (recommended): each rank saves its own shard without extra communication. 'FULL_STATE_DICT': gathers full model on rank 0 (higher memory, slower).", + "reshard_after_forward": "true = FULL_SHARD (ZeRO-3), reshards params after forward pass. false = SHARD_GRAD_OP (ZeRO-2), keeps params gathered during forward/backward.", + "activation_checkpointing": "Use FSDP's native activation checkpointing instead of gradient_checkpointing. This is the correct way to save memory with FSDP.", + "activation_cpu_offload": "true = offload activations to CPU. false = keep activations on GPU,can enable when using activation_checkpointing." + }, + "fsdp": "full_shard auto_wrap", + "fsdp_config": { + "fsdp_version": 2, + "reshard_after_forward": true, + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "cpu_ram_efficient_loading": true, + "state_dict_type": "SHARDED_STATE_DICT", + "activation_checkpointing": false, + "activation_cpu_offload": true + } +} diff --git a/examples/train/activation_cpu_offload/train.sh b/examples/train/activation_cpu_offload/train.sh new file mode 100644 index 0000000000..e5fee8e54c --- /dev/null +++ b/examples/train/activation_cpu_offload/train.sh @@ -0,0 +1,27 @@ +#!/bin/bash +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model 'Qwen/Qwen3-0.6B' \ + --dataset 'swift/self-cognition#1000' \ \ + --load_from_cache_file true \ + --split_dataset_ratio 0.01 \ + --train_type lora \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --freeze_vit true \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --fsdp './examples/train/activation_cpu_offload/fsdp2.json' diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 841bdb9ffa..9f6a2990a1 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -265,6 +265,7 @@ def train(self, trainer): @RayHelper.function(group='default') def _prepare_callbacks(self): from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback + from swift.plugin import ActivationCpuOffloadCallBack args = self.args callbacks = [] if args.lisa_activated_layers > 0: @@ -275,6 +276,10 @@ def _prepare_callbacks(self): model=self.model) lisa_callback.switch_active_layers() # Make trainable parameters printing a correct value callbacks.append(lisa_callback) + # Check activation_cpu_offload from fsdp_config + fsdp_config = getattr(self.args, 'fsdp_config', {}) + if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False): + callbacks.append(ActivationCpuOffloadCallBack()) if args.is_adapter and args.train_type == 'adalora': callbacks.append(TrainerAdapterCallback(args)) diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index 870ece61cd..63cb2fc0e7 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -17,6 +17,7 @@ from .rm_plugin import rm_plugins from .env import envs, Env from .context_manager import context_managers, ContextManager + from swift.plugin.activation_cpu_offload import ActivationCpuOffloadCallBack else: _import_structure = { @@ -34,6 +35,7 @@ 'rm_plugin': ['rm_plugins'], 'env': ['envs', 'Env'], 'context_manager': ['context_managers', 'ContextManager'], + 'activation_cpu_offload': ['ActivationCpuOffloadCallBack'], } import sys diff --git a/swift/plugin/activation_cpu_offload.py b/swift/plugin/activation_cpu_offload.py new file mode 100644 index 0000000000..dc348805dc --- /dev/null +++ b/swift/plugin/activation_cpu_offload.py @@ -0,0 +1,612 @@ +"""Functionality for CPU offloading of tensors saved for backward pass.""" +from __future__ import annotations +import functools +import logging +import os +from contextlib import nullcontext +from typing import Any, Dict, Optional + +import torch +from torch.distributed.fsdp import FSDPModule as FSDP2 +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import TrainerCallback +from transformers.trainer_callback import TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from swift.utils import get_logger + +logger = get_logger() +logger.setLevel(logging.WARNING) + + +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + if hasattr(torch, 'npu') and callable(getattr(torch.npu, 'is_available', None)): + return torch.npu.is_available() + return False + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def _get_unique_tensor_key(tensor): + key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype) + return key + + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. + This currently only supports CPU, CUDA, NPU. + Returns: + device + """ + if is_cuda_available: + device = 'cuda' + elif is_npu_available: + device = 'npu' + else: + device = 'cpu' + return device + + +class FSDPParameterFilter: + + def __init__(self): + self.model_parameters_storage = set() + + def __call__(self, tensor): + return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage + + def update_model_parameters(self, model): + new_storage = set() + for p in model.parameters(): + new_storage.add(p.data.untyped_storage().data_ptr()) + self.model_parameters_storage = new_storage + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. + """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + return torch.cuda + + +class CpuOffloadHookWithOffloadHandler: + """Context-manager that offloads/recovers tensors through an offload hander. + + The hook just offloads/recovers the tensor object to the handler through `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. + """ + + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[dict[str, Any]] = None, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs + self.inside_context = False + + def __enter__(self): + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor) + + def __exit__(self, *args: Any): + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError( + '`tensor_push is not implented in OffloadHandler class. Inherit this class and implement your ' + 'custom tensor_push.') + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError( + '`tensor_pop is not implented in OffloadHandler class. Inherit this class and implement your ' + 'custom tensor_pop.') + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. + # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device='cpu', + pin_memory=pin_memory, + ) + cpu_backup.copy_(src_tensor, non_blocking=True) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + self.group_offload_mapping = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + # allocate streams and events for synchronization + self.d2h_stream = get_torch_device().Stream() + self.h2d_stream = get_torch_device().Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + torch_stray_tensor = isinstance( + tensor, + torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor, + ) + need_offload = not torch_stray_tensor + need_offload = need_offload and self.tensor_need_offloading_checker(tensor) + + if need_offload: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + + assert tensor_tag not in self.tensor_tag_to_state + self.tensor_tag_to_state[tensor_tag] = tensor + + if self.current_group < self.num_offload_group: + self.tensor_tag_to_buf[tensor_tag] = tensor + else: + tensor_tag = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + if isinstance(tensor_tag, torch.Tensor): + return tensor_tag + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. + assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + offload_mapping = {} + offload_size = 0 + with get_torch_device().stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + key = _get_unique_tensor_key(state) + if key not in offload_mapping: + offload_mapping[key] = state + # if offload, return the reference to cpu copy + self.tensor_tag_to_state[tensor_tag] = (key, state.shape) + for key, tensor in offload_mapping.items(): + state = SynchronizedGroupOffloadHandler.offload(tensor) + offload_size += tensor.numel() * tensor.element_size() + offload_mapping[key] = state + + self.group_offload_mapping[group_to_offload] = offload_mapping + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + # Stream synchronization both ways + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + @torch.no_grad + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + with get_torch_device().stream(self.h2d_stream): + # move back tensors + offload_mapping = self.group_offload_mapping.pop(group_to_reload) + assert offload_mapping is not None + for key, state in offload_mapping.items(): + offload_mapping[key] = SynchronizedGroupOffloadHandler.reload(state) + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload and not isinstance(state, torch.Tensor): + assert isinstance(state, tuple), f'{group_id} {state}' + key, shape = state + recovered_tensor = offload_mapping[key].view(shape) + self.tensor_tag_to_state[tensor_label] = recovered_tensor + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. + self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + # Stream synchronization both ways + self.h2d_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + get_torch_device().current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_activation_offload_context(num_layers: int = 1, + model_layers: int = 1, + tensor_need_offloading_checker=(lambda t: True)): + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + + +class ActivationHandler: + + def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt): + self._offload_ctx = offload_ctx + self._sync_func = sync_func + self._enable_ckpt = enable_ckpt + self._tensor_filter = tensor_filter + if enable_ckpt: + self.checkpoint_fn = functools.partial( + torch.utils.checkpoint.checkpoint, + use_reentrant=True, + ) + + def pre_forward(self, module): + if module.training: + self._offload_ctx.__enter__() + self._tensor_filter.update_model_parameters(module) + + def post_forward(self, module): + if module.training: + self._offload_ctx.__exit__(None, None, None) + + def _pack_kwargs(self, *args, **kwargs): + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + + return tuple(flat_args), tuple(kwarg_keys) + + def _unpack_kwargs(self, flat_args, kwarg_keys): + assert len(kwarg_keys) <= len(flat_args), f'too many keys {len(kwarg_keys)} vs. {len(flat_args)}' + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[:-len(kwarg_keys)] + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):], strict=True)) + return args, kwargs + + def _ckpt_forward(self, forward_method, *args, **kwargs): + flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs) + + def my_function(*inputs): + # unpack back into args and kwargs + nonlocal forward_method, kwarg_keys + unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys) + # run original module + return forward_method(*unpacked_args, **unpacked_kwargs) + + return self.checkpoint_fn( + my_function, + *flat_args, + ) + + def forward(self, module, forward_method, *args, **kwargs): + if not module.training: + return forward_method(*args, **kwargs) + if not self._enable_ckpt: + ret = forward_method(*args, **kwargs) + else: + ret = self._ckpt_forward(forward_method, *args, **kwargs) + binded_tensor = ret + if isinstance(ret, tuple): + binded_tensor = ret[0] + binded_tensor = self._sync_func(binded_tensor) + final_ret = binded_tensor + if isinstance(ret, tuple): + final_ret = (final_ret, ) + ret[1:] + return final_ret + + def wrap_module_forward_method(self, module): + orig_method = module.forward + handler = self + + @functools.wraps(orig_method) + def wrapped_method(model_self, *args, **kwargs): + nonlocal handler + handler.pre_forward(model_self) + out = handler.forward(model_self, orig_method, *args, **kwargs) + handler.post_forward(model_self) + return out + + module.forward = wrapped_method.__get__(module, type(module)) + + +def enable_activation_offloading(model, strategy, enable_ckpt=False): + """ + Enable activation offloading for the model. It groups activations by TransformerLayer and offloads activation + groups asynchronously. This means that the offloading of the i-th activation group and the computation of the i+1-th + activation group happen at the same time, and there are at most two activation groups in GPU memory. + + Args: + model: the model to enable activation offloading + strategy: the training strategy of the model, such as "fsdp" + enable_ckpt: whether activation checkpointing(also called gradient checkpointing) has been enabled for the model + + Note: + For best efficiency, activation offloading is usually combined with activation checkpointing. However, this + implementation of activation offloading is conflicted with the implementation of activation checkpointing in + some training strategies. This function resolves this conflict, and therefore requires the "strategy" and + "enable_ckpt" arguments. + + Returns: + + """ + + assert strategy == 'fsdp' or strategy == 'fsdp2', 'activation offloading only supports fsdp strategy' + layers = [] + + def get_layers(module): + for name, child in module.named_children(): + if not isinstance(child, FSDP | FSDP2): + get_layers(child) + else: + wrapped_module = child + if isinstance(child, FSDP): + wrapped_module = child._fsdp_wrapped_module + # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation + # size of torch.nn.Embedding is small, so it's not necessary to offload it. + if not isinstance(wrapped_module, torch.nn.Embedding): + layers.append(child) + + get_layers(model) + if len(layers) < 3: + logger.warning(f'Find only {len(layers)} fsdp layers, not neccessary to enable async activation offloading') + return + + tensor_filter = FSDPParameterFilter() + context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter) + if enable_ckpt: + # The implementation of activation checkpointing in transformers library is incompatible with + # activation offloading, + # so it will be disabled, but this implementation supports another version of activation checkpointing, so that + # these two features can be enabled at the same time. + for module in model.modules(): + if hasattr(module, 'gradient_checkpointing_disable'): + module.gradient_checkpointing_disable() + + handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt) + for layer in layers: + module = layer + if isinstance(layer, FSDP): + module = module._fsdp_wrapped_module + handler.wrap_module_forward_method(module) + + +class ActivationCpuOffloadCallBack(TrainerCallback): + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of training. + """ + model = kwargs['model'] + + # Check if model is wrapped with FSDP + if isinstance(model, FSDP) or isinstance(model, FSDP2): + if args is not None and hasattr(args, 'fsdp_config'): + fsdp_config = args.fsdp_config + # Check if fsdp_config is a dictionary and has activation_cpu_offload enabled + if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False): + # Get FSDP version from fsdp_config + strategy = fsdp_config.get('version', None) + if strategy is not None: + fsdp_version = 'fsdp' if strategy == 1 else 'fsdp2' + # Get activation checkpointing setting from fsdp_config + enable_ckpt = fsdp_config.get('activation_checkpointing', False) + if enable_ckpt and hasattr(model, 'enable_input_require_grads'): + model.enable_input_require_grads() + enable_activation_offloading(model, strategy=fsdp_version, enable_ckpt=enable_ckpt) From e2481c0d7e958e90965b168237a7bdb2c3c36a8a Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 30 Dec 2025 17:25:36 +0800 Subject: [PATCH 2/3] feat: update activation CPU offload example and logging - Change model from Qwen3-0.6B to Qwen3-8B in training script - Remove logger level setting to use default logging configuration - Add training logs demonstrating memory savings with activation offload - Show OOM error when activation offload is disabled for comparison The update demonstrates the effectiveness of activation CPU offload for larger models, showing successful training with Qwen3-8B where it previously would have OOM'd without offloading. --- examples/train/activation_cpu_offload/train.sh | 14 +++++++++++++- swift/plugin/activation_cpu_offload.py | 1 - 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/train/activation_cpu_offload/train.sh b/examples/train/activation_cpu_offload/train.sh index e5fee8e54c..9306c74032 100644 --- a/examples/train/activation_cpu_offload/train.sh +++ b/examples/train/activation_cpu_offload/train.sh @@ -1,7 +1,7 @@ #!/bin/bash CUDA_VISIBLE_DEVICES=0,1 \ swift sft \ - --model 'Qwen/Qwen3-0.6B' \ + --model 'Qwen/Qwen3-8B' \ --dataset 'swift/self-cognition#1000' \ \ --load_from_cache_file true \ --split_dataset_ratio 0.01 \ @@ -25,3 +25,15 @@ swift sft \ --warmup_ratio 0.05 \ --dataloader_num_workers 4 \ --fsdp './examples/train/activation_cpu_offload/fsdp2.json' + +# activation_cpu_offload=true +# {'loss': 1.13790035, 'grad_norm': 1.41501045, 'learning_rate': 5e-05, 'token_acc': 0.83174487, 'epoch': 0.04, 'global_step/max_steps': '1/27', 'percentage': '3.70%', 'elapsed_time': '3m 36s', 'remaining_time': '1h 33m 43s', 'memory(GiB)': 32.54, 'train_speed(iter/s)': 0.004623} +# {'loss': 0.94536996, 'grad_norm': 0.85681218, 'learning_rate': 9.649e-05, 'token_acc': 0.84959215, 'epoch': 0.19, 'global_step/max_steps': '5/27', 'percentage': '18.52%', 'elapsed_time': '17m 16s', 'remaining_time': '1h 16m 1s', 'memory(GiB)': 39.92, 'train_speed(iter/s)': 0.004823} +# {'loss': 0.68646059, 'grad_norm': 0.25970718, 'learning_rate': 7.679e-05, 'token_acc': 0.85168261, 'epoch': 0.37, 'global_step/max_steps': '10/27', 'percentage': '37.04%', 'elapsed_time': '34m 34s', 'remaining_time': '58m 46s', 'memory(GiB)': 39.92, 'train_speed(iter/s)': 0.00482} + +# activation_cpu_offload=false +# OOM +# {'loss': 1.13790035, 'grad_norm': 1.41472316, 'learning_rate': 5e-05, 'token_acc': 0.83174487, 'epoch': 0.04, 'global_step/max_steps': '1/27', 'percentage': '3.70%', 'elapsed_time': '46s', 'remaining_time': '20m 1s', 'memory(GiB)': 61.79, 'train_speed(iter/s)': 0.021641} +# Train: 11%|████████████ | 3/27 [01:52<14:28, 36.19s/it +# ... +# [rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB. GPU 1 has a total capacity of 63.59 GiB of which 0 bytes is free. Of the allocated memory 55.85 GiB is allocated by PyTorch, and 3.64 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) diff --git a/swift/plugin/activation_cpu_offload.py b/swift/plugin/activation_cpu_offload.py index dc348805dc..e160d71d81 100644 --- a/swift/plugin/activation_cpu_offload.py +++ b/swift/plugin/activation_cpu_offload.py @@ -16,7 +16,6 @@ from swift.utils import get_logger logger = get_logger() -logger.setLevel(logging.WARNING) def is_torch_npu_available() -> bool: From 7a3b7e9d7feb93a068d55dc4473026a7dff15de6 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Wed, 31 Dec 2025 17:51:05 +0800 Subject: [PATCH 3/3] feat: update training script with new dataset and configuration --- .../train/activation_cpu_offload/train.sh | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/examples/train/activation_cpu_offload/train.sh b/examples/train/activation_cpu_offload/train.sh index 9306c74032..b9b206748c 100644 --- a/examples/train/activation_cpu_offload/train.sh +++ b/examples/train/activation_cpu_offload/train.sh @@ -1,11 +1,10 @@ #!/bin/bash CUDA_VISIBLE_DEVICES=0,1 \ +NPROC_PER_NODE=2 \ swift sft \ --model 'Qwen/Qwen3-8B' \ - --dataset 'swift/self-cognition#1000' \ \ - --load_from_cache_file true \ - --split_dataset_ratio 0.01 \ --train_type lora \ + --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \ --torch_dtype bfloat16 \ --num_train_epochs 1 \ --per_device_train_batch_size 1 \ @@ -13,27 +12,43 @@ swift sft \ --learning_rate 1e-4 \ --lora_rank 8 \ --lora_alpha 32 \ + --gradient_checkpointing false \ + --weight_decay 0.1 \ --target_modules all-linear \ - --freeze_vit true \ --gradient_accumulation_steps 16 \ --eval_steps 100 \ - --save_steps 100 \ + --save_steps 5 \ --save_total_limit 2 \ --logging_steps 5 \ --max_length 2048 \ --output_dir output \ + --system You\ are\ a\ helpful\ assistant. \ --warmup_ratio 0.05 \ --dataloader_num_workers 4 \ --fsdp './examples/train/activation_cpu_offload/fsdp2.json' + +# --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' # activation_cpu_offload=true -# {'loss': 1.13790035, 'grad_norm': 1.41501045, 'learning_rate': 5e-05, 'token_acc': 0.83174487, 'epoch': 0.04, 'global_step/max_steps': '1/27', 'percentage': '3.70%', 'elapsed_time': '3m 36s', 'remaining_time': '1h 33m 43s', 'memory(GiB)': 32.54, 'train_speed(iter/s)': 0.004623} -# {'loss': 0.94536996, 'grad_norm': 0.85681218, 'learning_rate': 9.649e-05, 'token_acc': 0.84959215, 'epoch': 0.19, 'global_step/max_steps': '5/27', 'percentage': '18.52%', 'elapsed_time': '17m 16s', 'remaining_time': '1h 16m 1s', 'memory(GiB)': 39.92, 'train_speed(iter/s)': 0.004823} -# {'loss': 0.68646059, 'grad_norm': 0.25970718, 'learning_rate': 7.679e-05, 'token_acc': 0.85168261, 'epoch': 0.37, 'global_step/max_steps': '10/27', 'percentage': '37.04%', 'elapsed_time': '34m 34s', 'remaining_time': '58m 46s', 'memory(GiB)': 39.92, 'train_speed(iter/s)': 0.00482} + +# {'loss': 2.1327579, 'grad_norm': 1.72890568, 'learning_rate': 8.346e-05, 'token_acc': 0.58396158, 'epoch': 0.32, 'global_step/max_steps': '5/16', 'percentage': '31.25%', 'elapsed_time': '5m 28s', 'remaining_time': '12m 2s', 'memory(GiB)': 24.8, 'train_speed(iter/s)': 0.015218} +# Train: 31%|██████████████████████████████████████▍ | 5/16 [05:28<11:41, 63.77s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v45-20251231-160511/checkpoint-5 +# {'loss': 1.51323957, 'grad_norm': 0.39210615, 'learning_rate': 3.455e-05, 'token_acc': 0.62368014, 'epoch': 0.64, 'global_step/max_steps': '10/16', 'percentage': '62.50%', 'elapsed_time': '10m 22s', 'remaining_time': '6m 13s', 'memory(GiB)': 24.87, 'train_speed(iter/s)': 0.016054} +# Train: 62%|████████████████████████████████████████████████████████████████████████████▎ | 10/16 [10:22<05:37, 56.26s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v45-20251231-160511/checkpoint-10 +# {'loss': 1.36127844, 'grad_norm': 0.30676287, 'learning_rate': 1.09e-06, 'token_acc': 0.64411869, 'epoch': 0.96, 'global_step/max_steps': '15/16', 'percentage': '93.75%', 'elapsed_time': '15m 6s', 'remaining_time': '1m 0s', 'memory(GiB)': 24.87, 'train_speed(iter/s)': 0.016547} +# ... +# {'train_runtime': 962.7184, 'train_samples_per_second': 0.519, 'train_steps_per_second': 0.017, 'train_loss': 1.61728384, 'token_acc': 0.62789828, 'epoch': 1.0, 'global_step/max_steps': '16/16', 'percentage': '100.00%', 'elapsed_time': '16m 2s', 'remaining_time': '0s', 'memory(GiB)': 24.87, 'train_speed(iter/s)': 0.016624} +# Train: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [16:02<00:00, 60.16s/it] + # activation_cpu_offload=false -# OOM -# {'loss': 1.13790035, 'grad_norm': 1.41472316, 'learning_rate': 5e-05, 'token_acc': 0.83174487, 'epoch': 0.04, 'global_step/max_steps': '1/27', 'percentage': '3.70%', 'elapsed_time': '46s', 'remaining_time': '20m 1s', 'memory(GiB)': 61.79, 'train_speed(iter/s)': 0.021641} -# Train: 11%|████████████ | 3/27 [01:52<14:28, 36.19s/it + +# {'loss': 2.15452981, 'grad_norm': 1.7536869, 'learning_rate': 0.0001, 'token_acc': 0.61792799, 'epoch': 0.06, 'global_step/max_steps': '1/16', 'percentage': '6.25%', 'elapsed_time': '46s', 'remaining_time': '11m 39s', 'memory(GiB)': 26.14, 'train_speed(iter/s)': 0.021458} +# {'loss': 2.13306689, 'grad_norm': 1.7279824, 'learning_rate': 8.346e-05, 'token_acc': 0.58295639, 'epoch': 0.32, 'global_step/max_steps': '5/16', 'percentage': '31.25%', 'elapsed_time': '2m 55s', 'remaining_time': '6m 26s', 'memory(GiB)': 26.59, 'train_speed(iter/s)': 0.028456} +# Train: 31%|██████████████████████████████████████▍ | 5/16 [02:55<05:59, 32.65s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v44-20251231-155036/checkpoint-5 +# {'loss': 1.51308346, 'grad_norm': 0.39151499, 'learning_rate': 3.455e-05, 'token_acc': 0.62377399, 'epoch': 0.64, 'global_step/max_steps': '10/16', 'percentage': '62.50%', 'elapsed_time': '5m 18s', 'remaining_time': '3m 10s', 'memory(GiB)': 27.73, 'train_speed(iter/s)': 0.031432} +# Train: 62%|████████████████████████████████████████████████████████████████████████████▎ | 10/16 [05:18<02:51, 28.58s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v44-20251231-155036/checkpoint-10 +# {'loss': 1.36132231, 'grad_norm': 0.30557585, 'learning_rate': 1.09e-06, 'token_acc': 0.64442776, 'epoch': 0.96, 'global_step/max_steps': '15/16', 'percentage': '93.75%', 'elapsed_time': '7m 57s', 'remaining_time': '31s', 'memory(GiB)': 27.96, 'train_speed(iter/s)': 0.031437} # ... -# [rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB. GPU 1 has a total capacity of 63.59 GiB of which 0 bytes is free. Of the allocated memory 55.85 GiB is allocated by PyTorch, and 3.64 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) +# {'train_runtime': 507.5282, 'train_samples_per_second': 0.985, 'train_steps_per_second': 0.032, 'train_loss': 1.61732693, 'token_acc': 0.63051608, 'epoch': 1.0, 'global_step/max_steps': '16/16', 'percentage': '100.00%', 'elapsed_time': '8m 27s', 'remaining_time': '0s', 'memory(GiB)': 27.96, 'train_speed(iter/s)': 0.031543} +# Train: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [08:27<00:00, 31.70s/it]