
Model getting offloaded to CPU without user's intention #10914

@janzd

Description

Describe the bug

I ran into an issue where my model kept getting moved to the CPU after loading LoRA weights with the load_lora_weights() method.
I found that is_sequential_cpu_offload is set to True while the LoRA is being loaded (https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_base.py#L441), even though I never enable CPU offload in my code. Inference then takes about 15x longer than when the model stays on the GPU.

I'm using an 8-bit quantized FLUX model with a FLUX LoRA, and the model is supposed to stay on the GPU.

When I add the parameter that is commented out in the code below (device_map="balanced") to the pipeline initialization, the issue disappears and the model stays on the GPU.
Is it intended behavior that this doesn't work without the extra parameter?
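For reference, here is a simplified, self-contained sketch of how the offload check in lora_base.py appears to work, based on my reading of the linked code (not a verbatim copy): it scans the pipeline components for accelerate hooks, treats any AlignDevicesHook (on its own or as the first entry of a hook chain) as sequential CPU offload, and only runs when the pipeline has no hf_device_map. If that reading is right, it would explain why device_map="balanced" avoids the problem. The Fake* classes and detect_offloading() are hypothetical stand-ins so the logic runs without torch or accelerate.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


# Hypothetical stand-ins for accelerate's AlignDevicesHook, CpuOffload and
# SequentialHook, so the logic can run without torch or accelerate installed.
class FakeAlignDevicesHook:
    pass


class FakeCpuOffload:
    pass


@dataclass
class FakeSequentialHook:
    # A hook chain keeps its sub-hooks in a `hooks` list.
    hooks: List[Any] = field(default_factory=list)


@dataclass
class FakeModule:
    # accelerate attaches a module's hook as the `_hf_hook` attribute.
    name: str
    _hf_hook: Optional[Any] = None


def detect_offloading(
    components: Dict[str, FakeModule], hf_device_map: Optional[dict]
) -> Tuple[bool, bool]:
    """Simplified sketch (assumed, not verbatim) of the check in lora_base.py.

    Returns (is_model_cpu_offload, is_sequential_cpu_offload).
    """
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    # The scan is skipped entirely when the pipeline was loaded with a device_map.
    if hf_device_map is None:
        for component in components.values():
            hook = getattr(component, "_hf_hook", None)
            if hook is None:
                continue
            if isinstance(hook, FakeCpuOffload):
                is_model_cpu_offload = True
            # Any AlignDevicesHook, on its own or first in a chain, counts as
            # sequential CPU offload, whatever the reason it was attached.
            chain = getattr(hook, "hooks", None)
            if isinstance(hook, FakeAlignDevicesHook) or (
                chain and isinstance(chain[0], FakeAlignDevicesHook)
            ):
                is_sequential_cpu_offload = True
    return is_model_cpu_offload, is_sequential_cpu_offload


components = {
    "transformer": FakeModule("transformer", FakeAlignDevicesHook()),
    "vae": FakeModule("vae"),
}
print(detect_offloading(components, hf_device_map=None))  # (False, True)
print(detect_offloading(components, hf_device_map={"transformer": 0}))  # (False, False)
```

In this sketch, a hook on a single component is enough to flag the whole pipeline as sequentially offloaded, while any non-empty device map short-circuits the scan.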

This issue is related to #7539, but that issue is stale and I ran into the problem in a different way, so I decided to open a new one. In #7539, the author explicitly calls align_device_hook(); I'm only loading LoRA weights and never touch device hooks myself.
The author of #7539 also opened a PR (#8750) that solved my issue without requiring device_map="balanced", but that PR was never merged.

Reproduction

import torch
from diffusers import FluxPipeline, BitsAndBytesConfig, FluxTransformer2DModel
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=TransformersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=text_encoder_2_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    # Enabling the next line avoids the issue (the model stays on the GPU):
    # device_map="balanced",
)

pipe.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="anime_lora.safetensors")

prompt = "a hand holding a knife and cutting a cabbage, anime"
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    num_inference_steps=25,
).images[0]
out.save("image.png")
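To check which components of the pipeline above already carry an accelerate hook before load_lora_weights() runs, a small helper along these lines could be used. It is a sketch assuming that pipe.components returns a name-to-component dict (as on diffusers pipelines) and that accelerate stores a module's hook in the _hf_hook attribute; report_accelerate_hooks() is a hypothetical name, not a diffusers API.

```python
from typing import Any, Dict


def report_accelerate_hooks(pipe: Any) -> Dict[str, str]:
    """Map component name -> hook class name for every component carrying an accelerate hook.

    Assumes `pipe.components` is a dict of name -> component (as on diffusers
    pipelines) and that accelerate stores a module's hook in `_hf_hook`.
    """
    found: Dict[str, str] = {}
    for name, component in pipe.components.items():
        # Non-module components (schedulers, tokenizers) simply have no `_hf_hook`.
        hook = getattr(component, "_hf_hook", None)
        if hook is not None:
            found[name] = type(hook).__name__
    return found


# Usage with the pipeline from the reproduction, right before load_lora_weights():
#   print("hf_device_map:", getattr(pipe, "hf_device_map", None))
#   print("hooks:", report_accelerate_hooks(pipe))
```

Printing both values just before pipe.load_lora_weights(...) should show whether a component already has a hook while the pipeline has no device map, which appears to be the situation in which lora_base.py sets is_sequential_cpu_offload.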

Logs

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-5.15.0-126-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.11
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.1
  • Accelerate version: 1.3.0
  • PEFT version: 0.14.1.dev0
  • Bitsandbytes version: 0.45.3
  • Safetensors version: 0.4.5
  • xFormers version: 0.0.29.post2
  • Accelerator: NVIDIA L4, 23034 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

No response

Labels: bug (Something isn't working)
