
Regression - High memory usage when using transformers model with FSDP + LoRA #39795

@romitjain

Description


System Info

  • transformers version: 4.54.0
  • Platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.34.1
  • Safetensors version: 0.5.3
  • Accelerate version: 1.9.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.7.1+cu126 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: Yes, FSDP with accelerate
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@zach-huggingface @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

sft.py

import torch
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy


def main():

    model_name = "ibm-granite/granite-8b-code-base"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
    )
    dummy_input = tokenizer("This is a test sentence.", return_tensors="pt")

    accelerator = Accelerator()

    # if accelerator.is_main_process:
    #     torch.cuda.memory._record_memory_history(max_entries=100000)

    peft_config = LoraConfig(
        r=4,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "v_proj"]
    )

    model = get_peft_model(model, peft_config)

    fsdp_plugin = accelerator.state.fsdp_plugin
    fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)  # type: ignore

    if accelerator.is_main_process:
        model.print_trainable_parameters()

    model = accelerator.prepare(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    optimizer = accelerator.prepare(optimizer)

    model.train()

    torch.cuda.empty_cache()

    accelerator.print(f"Memory allocated after setup: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    # Move the dummy batch onto the accelerator device before the forward pass
    dummy_input = dummy_input.to(accelerator.device)
    outputs = model(**dummy_input, labels=dummy_input["input_ids"])

    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

    peak_memory = torch.cuda.max_memory_allocated() / 1e9

    accelerator.print(f"Peak memory during training step: {peak_memory:.2f} GB")

    accelerator.wait_for_everyone()
    accelerator.print("Debug script finished successfully.")

    # if accelerator.is_main_process:
    #     torch.cuda.memory._dump_snapshot("profile_449.pkl")
    #     torch.cuda.memory._record_memory_history(enabled=None)

if __name__ == "__main__":
    """
    accelerate launch --config_file fsdp.yaml -m sft
    """
    main()
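
Note on the commented-out profiling calls in the script above: they use PyTorch's CUDA memory snapshot API. A minimal sketch of how to enable them and inspect the result is below (assuming PyTorch >= 2.1; torch.cuda.memory._record_memory_history and _dump_snapshot are private APIs and may change between releases).

import torch

# Start recording allocation history before the training step
torch.cuda.memory._record_memory_history(max_entries=100000)

# ... run setup, forward, backward, and optimizer.step() here ...

# Dump the snapshot and stop recording; the resulting .pkl file can be
# opened in the viewer at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("profile_449.pkl")
torch.cuda.memory._record_memory_history(enabled=None)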

fsdp.yaml

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_cpu_ram_efficient_loading: false
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
mixed_precision: 'no'
machine_rank: 0
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true

Run

accelerate launch --config_file fsdp.yaml -m sft

Expected behavior

When I am doing LoRA fine-tuning with FSDP, I see much higher memory usage than with transformers v4.49.0. The issue is specific to versions 4.50.0 and above. For example:

With 4 GPUs, I see the following memory usage on transformers==4.49.0:

Memory allocated after setup: 4.03 GB
Peak memory during training step: 5.36 GB

versus the following on any later version, e.g. transformers==4.54.0:

Memory allocated after setup: 4.03 GB
Peak memory during training step: 20.16 GB

The peak memory usage is roughly 4x higher.

Keeping all other library versions constant, the bug appears only when upgrading transformers to any version above 4.49.0, which is why I have raised it here and not in accelerate. Downgrading to transformers==4.49.0 fixes the issue.
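
As a sanity check while bisecting (my own addition, not part of the original report), a few lines of Python printed at the top of the script confirm which versions are actually in the environment for each run:

import torch
import transformers
import accelerate
import peft

# Print the library versions relevant to this regression
print("transformers:", transformers.__version__)
print("accelerate:", accelerate.__version__)
print("peft:", peft.__version__)
print("torch:", torch.__version__)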

The issue report ends here, but I will share some of my findings in case they are helpful:

  1. I was able to reproduce this issue with other Llama-based models, too.
  2. The bug only appears with FSDP + LoRA; single-GPU jobs do not seem to have it.
  3. I have already tried the solution suggested in FSDP2 - High memory usage with LoRA (accelerate#3474), and it does not resolve the issue.
  4. The memory explosion happens during the backward pass, specifically at accelerator.backward(loss).
  5. Looking at the memory profiling results, it seems the full attention projections (q_proj, v_proj) are somehow treated as trainable and memory is reserved for their optimizer states, which leads to the ~4x spike (see the diagnostic sketch after this list). I am also attaching screenshots from the memory profiling.
  6. For the FSDP config, I have tried both values of fsdp_cpu_ram_efficient_loading and fsdp_use_orig_params, with and without setting fsdp_transformer_layer_cls_to_wrap.
[Attached: two memory profiling screenshots]
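
To check finding 5 directly, here is a rough diagnostic sketch (my own addition, not from the original report; the report_trainable helper name is hypothetical). It lists the parameters that still require gradients after accelerator.prepare(model) and estimates the AdamW state they imply; if the base q_proj / v_proj weights show up alongside the LoRA adapters, that would account for the extra optimizer-state memory.

def report_trainable(model, accelerator):
    # Collect every parameter that will receive gradients and optimizer state
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    n_params = sum(p.numel() for _, p in trainable)
    # AdamW keeps two extra tensors (exp_avg, exp_avg_sq) per trainable
    # parameter, allocated in the parameter's dtype by default
    est_state_gb = sum(2 * p.numel() * p.element_size() for _, p in trainable) / 1e9
    accelerator.print(f"trainable params: {n_params:,}")
    accelerator.print(f"estimated AdamW optimizer state: {est_state_gb:.2f} GB")
    for name, _ in trainable[:20]:
        accelerator.print(f"  {name}")

# Usage (in sft.py, after model = accelerator.prepare(model)):
#     report_trainable(model, accelerator)
# Only lora_A / lora_B tensors are expected in the output.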
