diff --git a/examples/train/activation_cpu_offload/fsdp2.json b/examples/train/activation_cpu_offload/fsdp2.json
new file mode 100644
index 0000000000..73d856389a
--- /dev/null
+++ b/examples/train/activation_cpu_offload/fsdp2.json
@@ -0,0 +1,26 @@
+{
+  "_description": "FSDP2 configuration for distributed training (PyTorch native FSDP v2)",
+  "_requires": "torch>=2.4.0",
+  "_note": "Recommended configuration for multi-GPU training with activation CPU offloading (parameters and optimizer states stay on GPU). NOTE: When using FSDP2, do NOT use --gradient_checkpointing; use activation_checkpointing in fsdp_config instead.",
+
+  "_param_docs": {
+    "fsdp": "FSDP strategy string. Options: 'full_shard' (ZeRO-3 style, shards params+grads+optimizer), 'shard_grad_op' (ZeRO-2 style, shards grads+optimizer only). Add 'auto_wrap' to enable automatic layer wrapping. Add 'offload' to enable CPU offloading.",
+    "fsdp_version": "FSDP version. Use 2 for PyTorch native FSDP2 (recommended). FSDP2 uses DTensor for per-parameter sharding and supports LoRA/QLoRA natively.",
+    "auto_wrap_policy": "How to wrap model layers. 'TRANSFORMER_BASED_WRAP' wraps transformer decoder layers (from model._no_split_modules). 'SIZE_BASED_WRAP' wraps modules exceeding min_num_params.",
+    "cpu_ram_efficient_loading": "If true, only rank 0 loads the full model weights and then broadcasts them to the other ranks. Reduces CPU RAM usage during initialization.",
+    "state_dict_type": "'SHARDED_STATE_DICT' (recommended): each rank saves its own shard without extra communication. 'FULL_STATE_DICT': gathers the full model on rank 0 (higher memory, slower).",
+    "reshard_after_forward": "true = FULL_SHARD (ZeRO-3), reshards params after the forward pass. false = SHARD_GRAD_OP (ZeRO-2), keeps params gathered during forward/backward.",
+    "activation_checkpointing": "Use FSDP's native activation checkpointing instead of gradient_checkpointing. This is the correct way to save memory with FSDP.",
+    "activation_cpu_offload": "true = offload activations to CPU, false = keep activations on GPU. Can be enabled together with activation_checkpointing."
+  },
+  "fsdp": "full_shard auto_wrap",
+  "fsdp_config": {
+    "fsdp_version": 2,
+    "reshard_after_forward": true,
+    "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+    "cpu_ram_efficient_loading": true,
+    "state_dict_type": "SHARDED_STATE_DICT",
+    "activation_checkpointing": false,
+    "activation_cpu_offload": true
+  }
+}
diff --git a/examples/train/activation_cpu_offload/train.sh b/examples/train/activation_cpu_offload/train.sh
new file mode 100644
index 0000000000..b9b206748c
--- /dev/null
+++ b/examples/train/activation_cpu_offload/train.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+CUDA_VISIBLE_DEVICES=0,1 \
+NPROC_PER_NODE=2 \
+swift sft \
+    --model 'Qwen/Qwen3-8B' \
+    --train_type lora \
+    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-4 \
+    --lora_rank 8 \
+    --lora_alpha 32 \
+    --gradient_checkpointing false \
+    --weight_decay 0.1 \
+    --target_modules all-linear \
+    --gradient_accumulation_steps 16 \
+    --eval_steps 100 \
+    --save_steps 5 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --system You\ are\ a\ helpful\ assistant. 
\ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --fsdp './examples/train/activation_cpu_offload/fsdp2.json' + + +# --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' +# activation_cpu_offload=true + +# {'loss': 2.1327579, 'grad_norm': 1.72890568, 'learning_rate': 8.346e-05, 'token_acc': 0.58396158, 'epoch': 0.32, 'global_step/max_steps': '5/16', 'percentage': '31.25%', 'elapsed_time': '5m 28s', 'remaining_time': '12m 2s', 'memory(GiB)': 24.8, 'train_speed(iter/s)': 0.015218} +# Train: 31%|██████████████████████████████████████▍ | 5/16 [05:28<11:41, 63.77s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v45-20251231-160511/checkpoint-5 +# {'loss': 1.51323957, 'grad_norm': 0.39210615, 'learning_rate': 3.455e-05, 'token_acc': 0.62368014, 'epoch': 0.64, 'global_step/max_steps': '10/16', 'percentage': '62.50%', 'elapsed_time': '10m 22s', 'remaining_time': '6m 13s', 'memory(GiB)': 24.87, 'train_speed(iter/s)': 0.016054} +# Train: 62%|████████████████████████████████████████████████████████████████████████████▎ | 10/16 [10:22<05:37, 56.26s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v45-20251231-160511/checkpoint-10 +# {'loss': 1.36127844, 'grad_norm': 0.30676287, 'learning_rate': 1.09e-06, 'token_acc': 0.64411869, 'epoch': 0.96, 'global_step/max_steps': '15/16', 'percentage': '93.75%', 'elapsed_time': '15m 6s', 'remaining_time': '1m 0s', 'memory(GiB)': 24.87, 'train_speed(iter/s)': 0.016547} +# ... +# {'train_runtime': 962.7184, 'train_samples_per_second': 0.519, 'train_steps_per_second': 0.017, 'train_loss': 1.61728384, 'token_acc': 0.62789828, 'epoch': 1.0, 'global_step/max_steps': '16/16', 'percentage': '100.00%', 'elapsed_time': '16m 2s', 'remaining_time': '0s', 'memory(GiB)': 24.87, 'train_speed(iter/s)': 0.016624} +# Train: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [16:02<00:00, 60.16s/it] + + +# activation_cpu_offload=false + +# {'loss': 2.15452981, 'grad_norm': 1.7536869, 'learning_rate': 0.0001, 'token_acc': 0.61792799, 'epoch': 0.06, 'global_step/max_steps': '1/16', 'percentage': '6.25%', 'elapsed_time': '46s', 'remaining_time': '11m 39s', 'memory(GiB)': 26.14, 'train_speed(iter/s)': 0.021458} +# {'loss': 2.13306689, 'grad_norm': 1.7279824, 'learning_rate': 8.346e-05, 'token_acc': 0.58295639, 'epoch': 0.32, 'global_step/max_steps': '5/16', 'percentage': '31.25%', 'elapsed_time': '2m 55s', 'remaining_time': '6m 26s', 'memory(GiB)': 26.59, 'train_speed(iter/s)': 0.028456} +# Train: 31%|██████████████████████████████████████▍ | 5/16 [02:55<05:59, 32.65s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v44-20251231-155036/checkpoint-5 +# {'loss': 1.51308346, 'grad_norm': 0.39151499, 'learning_rate': 3.455e-05, 'token_acc': 0.62377399, 'epoch': 0.64, 'global_step/max_steps': '10/16', 'percentage': '62.50%', 'elapsed_time': '5m 18s', 'remaining_time': '3m 10s', 'memory(GiB)': 27.73, 'train_speed(iter/s)': 0.031432} +# Train: 62%|████████████████████████████████████████████████████████████████████████████▎ | 10/16 [05:18<02:51, 28.58s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v44-20251231-155036/checkpoint-10 +# {'loss': 1.36132231, 'grad_norm': 0.30557585, 'learning_rate': 1.09e-06, 'token_acc': 0.64442776, 'epoch': 0.96, 'global_step/max_steps': '15/16', 'percentage': '93.75%', 'elapsed_time': '7m 57s', 'remaining_time': '31s', 'memory(GiB)': 27.96, 'train_speed(iter/s)': 0.031437} +# ... 
+# {'train_runtime': 507.5282, 'train_samples_per_second': 0.985, 'train_steps_per_second': 0.032, 'train_loss': 1.61732693, 'token_acc': 0.63051608, 'epoch': 1.0, 'global_step/max_steps': '16/16', 'percentage': '100.00%', 'elapsed_time': '8m 27s', 'remaining_time': '0s', 'memory(GiB)': 27.96, 'train_speed(iter/s)': 0.031543} +# Train: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [08:27<00:00, 31.70s/it] diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 841bdb9ffa..9f6a2990a1 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -265,6 +265,7 @@ def train(self, trainer): @RayHelper.function(group='default') def _prepare_callbacks(self): from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback + from swift.plugin import ActivationCpuOffloadCallBack args = self.args callbacks = [] if args.lisa_activated_layers > 0: @@ -275,6 +276,10 @@ def _prepare_callbacks(self): model=self.model) lisa_callback.switch_active_layers() # Make trainable parameters printing a correct value callbacks.append(lisa_callback) + # Check activation_cpu_offload from fsdp_config + fsdp_config = getattr(self.args, 'fsdp_config', {}) + if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False): + callbacks.append(ActivationCpuOffloadCallBack()) if args.is_adapter and args.train_type == 'adalora': callbacks.append(TrainerAdapterCallback(args)) diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index 870ece61cd..63cb2fc0e7 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -17,6 +17,7 @@ from .rm_plugin import rm_plugins from .env import envs, Env from .context_manager import context_managers, ContextManager + from swift.plugin.activation_cpu_offload import ActivationCpuOffloadCallBack else: _import_structure = { @@ -34,6 +35,7 @@ 'rm_plugin': ['rm_plugins'], 'env': ['envs', 'Env'], 'context_manager': ['context_managers', 'ContextManager'], + 'activation_cpu_offload': ['ActivationCpuOffloadCallBack'], } import sys diff --git a/swift/plugin/activation_cpu_offload.py b/swift/plugin/activation_cpu_offload.py new file mode 100644 index 0000000000..e160d71d81 --- /dev/null +++ b/swift/plugin/activation_cpu_offload.py @@ -0,0 +1,611 @@ +"""Functionality for CPU offloading of tensors saved for backward pass.""" +from __future__ import annotations +import functools +import logging +import os +from contextlib import nullcontext +from typing import Any, Dict, Optional + +import torch +from torch.distributed.fsdp import FSDPModule as FSDP2 +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import TrainerCallback +from transformers.trainer_callback import TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from swift.utils import get_logger + +logger = get_logger() + + +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + if hasattr(torch, 'npu') and callable(getattr(torch.npu, 'is_available', None)): + return torch.npu.is_available() + return False + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def _get_unique_tensor_key(tensor): + key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype) + return key + + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. 
+    This currently only supports CPU, CUDA and NPU.
+    Returns:
+        device type string: 'cuda', 'npu' or 'cpu'
+    """
+    if is_cuda_available:
+        device = 'cuda'
+    elif is_npu_available:
+        device = 'npu'
+    else:
+        device = 'cpu'
+    return device
+
+
+class FSDPParameterFilter:
+
+    def __init__(self):
+        self.model_parameters_storage = set()
+
+    def __call__(self, tensor):
+        return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage
+
+    def update_model_parameters(self, model):
+        new_storage = set()
+        for p in model.parameters():
+            new_storage.add(p.data.untyped_storage().data_ptr())
+        self.model_parameters_storage = new_storage
+
+
+def get_torch_device() -> Any:
+    """Return the corresponding torch device namespace based on the device type string.
+    Returns:
+        module: The corresponding torch device namespace, or torch.cuda if not found.
+    """
+    device_name = get_device_name()
+    try:
+        return getattr(torch, device_name)
+    except AttributeError:
+        logger.warning(f"Device namespace '{device_name}' not found in torch, falling back to torch.cuda.")
+        return torch.cuda
+
+
+class CpuOffloadHookWithOffloadHandler:
+    """Context-manager that offloads/recovers tensors through an offload handler.
+
+    The hook just offloads/recovers the tensor object to the handler through the `tensor_push`
+    and `tensor_pop` interface. How the offload handler manages the offloading, recovering
+    or prefetching timing is transparent to this hook.
+    """
+
+    def __init__(
+        self,
+        offload_handler: OffloadHandler,
+        handler_extra_kwargs: Optional[dict[str, Any]] = None,
+    ) -> None:
+        if handler_extra_kwargs is None:
+            handler_extra_kwargs = {}
+        self.offload_handler: OffloadHandler = offload_handler
+        self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs
+        self.inside_context = False
+
+    def __enter__(self):
+        self.inside_context = True
+        torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor)
+
+    def __exit__(self, *args: Any):
+        self.inside_context = False
+        torch._C._autograd._pop_saved_tensors_default_hooks()
+
+    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
+        retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
+        return retrieve_identifier
+
+    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
+        tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)
+        return tensor
+
+
+class OffloadHandler:
+    """A base class for CPU offload handlers."""
+
+    def __init__(self) -> None:
+        pass
+
+    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
+        """Tensor push."""
+        raise NotImplementedError(
+            '`tensor_push` is not implemented in OffloadHandler class. Inherit this class and implement your '
+            'custom tensor_push.')
+
+    def tensor_pop(self, tensor_tag: Any, **kwargs):
+        """Tensor pop."""
+        raise NotImplementedError(
+            '`tensor_pop` is not implemented in OffloadHandler class. Inherit this class and implement your '
+            'custom tensor_pop.')
+
+
+class GroupCommitFunction(torch.autograd.Function):
+    """This is a dummy op whose output is identical to its input.
+    It is needed to mark a timepoint at which the offload handler can
+    complete all synchronizations. It is implemented as an autograd Function
+    because we need to take actions in both the forward and the backward pass.
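+
+    Illustrative sketch of how the handler in this file applies it to a layer's
+    output (mirrors ActivationHandler.forward; the names are illustrative only):
+
+        out = decoder_layer(hidden_states)   # activations of this group get saved
+        out = group_prefetch_offload_commit(out, cpu_offload_handler)
+        # forward: lets the handler synchronize the d2h stream and start offloading
+        # the next group; backward: triggers the reload/prefetch of saved groups.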
+ """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. + # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device='cpu', + pin_memory=pin_memory, + ) + cpu_backup.copy_(src_tensor, non_blocking=True) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the 
overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + self.group_offload_mapping = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + # allocate streams and events for synchronization + self.d2h_stream = get_torch_device().Stream() + self.h2d_stream = get_torch_device().Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + torch_stray_tensor = isinstance( + tensor, + torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor, + ) + need_offload = not torch_stray_tensor + need_offload = need_offload and self.tensor_need_offloading_checker(tensor) + + if need_offload: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + + assert tensor_tag not in self.tensor_tag_to_state + self.tensor_tag_to_state[tensor_tag] = tensor + + if self.current_group < self.num_offload_group: + self.tensor_tag_to_buf[tensor_tag] = tensor + else: + tensor_tag = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + if isinstance(tensor_tag, torch.Tensor): + return tensor_tag + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. 
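+        # bulk_reload_group() replaces the (key, shape) placeholder that
+        # bulk_offload_group() stored in tensor_tag_to_state with the reloaded
+        # device tensor, so at this point the entry must be a plain tensor again.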
+ assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + offload_mapping = {} + offload_size = 0 + with get_torch_device().stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + key = _get_unique_tensor_key(state) + if key not in offload_mapping: + offload_mapping[key] = state + # if offload, return the reference to cpu copy + self.tensor_tag_to_state[tensor_tag] = (key, state.shape) + for key, tensor in offload_mapping.items(): + state = SynchronizedGroupOffloadHandler.offload(tensor) + offload_size += tensor.numel() * tensor.element_size() + offload_mapping[key] = state + + self.group_offload_mapping[group_to_offload] = offload_mapping + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + # Stream synchronization both ways + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + @torch.no_grad + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + with get_torch_device().stream(self.h2d_stream): + # move back tensors + offload_mapping = self.group_offload_mapping.pop(group_to_reload) + assert offload_mapping is not None + for key, state in offload_mapping.items(): + offload_mapping[key] = SynchronizedGroupOffloadHandler.reload(state) + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload and not isinstance(state, torch.Tensor): + assert isinstance(state, tuple), f'{group_id} {state}' + key, shape = state + recovered_tensor = offload_mapping[key].view(shape) + self.tensor_tag_to_state[tensor_label] = recovered_tensor + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. 
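+        # e.g. with N group commits per step, current_group counts 0 -> N over the
+        # forward commits and back from N -> 0 over the backward commits;
+        # layer_window_map decides at which of these counts the next group is reloaded.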
+ self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + # Stream synchronization both ways + self.h2d_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + get_torch_device().current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_activation_offload_context(num_layers: int = 1, + model_layers: int = 1, + tensor_need_offloading_checker=(lambda t: True)): + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + + +class ActivationHandler: + + def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt): + self._offload_ctx = offload_ctx + self._sync_func = sync_func + self._enable_ckpt = enable_ckpt + self._tensor_filter = tensor_filter + if enable_ckpt: + self.checkpoint_fn = functools.partial( + torch.utils.checkpoint.checkpoint, + use_reentrant=True, + ) + + def pre_forward(self, module): + if module.training: + self._offload_ctx.__enter__() + self._tensor_filter.update_model_parameters(module) + + def post_forward(self, module): + if module.training: + self._offload_ctx.__exit__(None, None, None) + + def _pack_kwargs(self, *args, **kwargs): + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + + return tuple(flat_args), tuple(kwarg_keys) + + def _unpack_kwargs(self, flat_args, kwarg_keys): + assert len(kwarg_keys) <= len(flat_args), f'too many keys {len(kwarg_keys)} vs. 
{len(flat_args)}'
+        if len(kwarg_keys) == 0:
+            return flat_args, {}
+        args = flat_args[:-len(kwarg_keys)]
+        kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):], strict=True))
+        return args, kwargs
+
+    def _ckpt_forward(self, forward_method, *args, **kwargs):
+        flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs)
+
+        def my_function(*inputs):
+            # unpack back into args and kwargs
+            nonlocal forward_method, kwarg_keys
+            unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys)
+            # run original module
+            return forward_method(*unpacked_args, **unpacked_kwargs)
+
+        return self.checkpoint_fn(
+            my_function,
+            *flat_args,
+        )
+
+    def forward(self, module, forward_method, *args, **kwargs):
+        if not module.training:
+            return forward_method(*args, **kwargs)
+        if not self._enable_ckpt:
+            ret = forward_method(*args, **kwargs)
+        else:
+            ret = self._ckpt_forward(forward_method, *args, **kwargs)
+        binded_tensor = ret
+        if isinstance(ret, tuple):
+            binded_tensor = ret[0]
+        binded_tensor = self._sync_func(binded_tensor)
+        final_ret = binded_tensor
+        if isinstance(ret, tuple):
+            final_ret = (final_ret, ) + ret[1:]
+        return final_ret
+
+    def wrap_module_forward_method(self, module):
+        orig_method = module.forward
+        handler = self
+
+        @functools.wraps(orig_method)
+        def wrapped_method(model_self, *args, **kwargs):
+            nonlocal handler
+            handler.pre_forward(model_self)
+            out = handler.forward(model_self, orig_method, *args, **kwargs)
+            handler.post_forward(model_self)
+            return out
+
+        module.forward = wrapped_method.__get__(module, type(module))
+
+
+def enable_activation_offloading(model, strategy, enable_ckpt=False):
+    """
+    Enable activation offloading for the model. It groups activations by transformer layer and offloads activation
+    groups asynchronously: the offloading of the i-th activation group overlaps with the computation of the
+    (i+1)-th group, so at most two activation groups reside in GPU memory at a time.
+
+    Args:
+        model: the model to enable activation offloading for
+        strategy: the training strategy of the model, such as "fsdp"
+        enable_ckpt: whether activation checkpointing (also called gradient checkpointing) has been enabled
+            for the model
+
+    Note:
+        For best efficiency, activation offloading is usually combined with activation checkpointing. However, this
+        implementation of activation offloading conflicts with the activation checkpointing implementation of some
+        training strategies. This function resolves that conflict, which is why it requires the "strategy" and
+        "enable_ckpt" arguments.
+    """
+
+    assert strategy == 'fsdp' or strategy == 'fsdp2', 'activation offloading only supports fsdp strategy'
+    layers = []
+
+    def get_layers(module):
+        for name, child in module.named_children():
+            if not isinstance(child, FSDP | FSDP2):
+                get_layers(child)
+            else:
+                wrapped_module = child
+                if isinstance(child, FSDP):
+                    wrapped_module = child._fsdp_wrapped_module
+                # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation
+                # size of torch.nn.Embedding is small, so it's not necessary to offload it.
+                if not isinstance(wrapped_module, torch.nn.Embedding):
+                    layers.append(child)
+
+    get_layers(model)
+    if len(layers) < 3:
+        logger.warning(f'Found only {len(layers)} FSDP layers, no need to enable async activation offloading')
+        return
+
+    tensor_filter = FSDPParameterFilter()
+    context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter)
+    if enable_ckpt:
+        # The activation checkpointing implementation in the transformers library is incompatible with this
+        # activation offloading, so it is disabled here. ActivationHandler applies its own
+        # torch.utils.checkpoint-based checkpointing instead, so both features can still be used together.
+        for module in model.modules():
+            if hasattr(module, 'gradient_checkpointing_disable'):
+                module.gradient_checkpointing_disable()
+
+    handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt)
+    for layer in layers:
+        module = layer
+        if isinstance(layer, FSDP):
+            module = module._fsdp_wrapped_module
+        handler.wrap_module_forward_method(module)
+
+
+class ActivationCpuOffloadCallBack(TrainerCallback):
+
+    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the beginning of training.
+        """
+        model = kwargs['model']
+
+        # Check if the model is wrapped with FSDP
+        if isinstance(model, FSDP) or isinstance(model, FSDP2):
+            if args is not None and hasattr(args, 'fsdp_config'):
+                fsdp_config = args.fsdp_config
+                # Check if fsdp_config is a dictionary and has activation_cpu_offload enabled
+                if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False):
+                    # Get the FSDP version from fsdp_config
+                    strategy = fsdp_config.get('version', None)
+                    if strategy is not None:
+                        fsdp_version = 'fsdp' if strategy == 1 else 'fsdp2'
+                        # Get the activation checkpointing setting from fsdp_config
+                        enable_ckpt = fsdp_config.get('activation_checkpointing', False)
+                        if enable_ckpt and hasattr(model, 'enable_input_require_grads'):
+                            model.enable_input_require_grads()
+                        enable_activation_offloading(model, strategy=fsdp_version, enable_ckpt=enable_ckpt)
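
For readers who want the gist of swift/plugin/activation_cpu_offload.py without the FSDP plumbing: the plugin installs saved-tensor hooks around each FSDP-wrapped layer, skips FSDP parameters via FSDPParameterFilter, and overlaps the CPU<->GPU copies with compute on dedicated streams. A minimal, synchronous sketch of the same idea, built only on the public torch.autograd.graph.saved_tensors_hooks API (torch.autograd.graph.save_on_cpu provides similar built-in behaviour), would look roughly like this; it is an illustration with made-up names, not part of the patch:

    import torch
    import torch.nn as nn

    def pack_to_cpu(saved: torch.Tensor):
        # Device-to-host copy of a tensor saved for backward. A real implementation
        # (see FSDPParameterFilter above) would skip parameters and only offload activations.
        cpu_copy = torch.empty(
            saved.size(), dtype=saved.dtype, layout=saved.layout,
            device='cpu', pin_memory=torch.cuda.is_available())
        cpu_copy.copy_(saved)
        return saved.device, cpu_copy

    def unpack_from_cpu(packed):
        # Host-to-device copy, performed lazily when autograd needs the tensor again.
        device, cpu_copy = packed
        return cpu_copy.to(device, non_blocking=cpu_copy.is_pinned())

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = nn.Sequential(nn.Linear(256, 256), nn.GELU(), nn.Linear(256, 256)).to(device)
    x = torch.randn(8, 256, device=device)

    with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
        loss = model(x).sum()
    loss.backward()  # saved activations are reloaded from CPU on demand
    print('grad norm:', model[0].weight.grad.norm().item())

The asynchronous, group-wise version implemented above is what turns this into a memory/time trade-off: in the train.sh logs, activation_cpu_offload=true peaks at about 24.9 GiB versus about 28.0 GiB without offloading, at roughly twice the per-step time.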