
CpuOffload pre_forward and post_forward should have @torch.compiler.disable(), otherwise torch.compile fails in many situations it ought to succeed #3885

@doctorpangloss

Description

System Info

- `Accelerate` version: 1.12.0
- Platform: Linux-6.14.0-36-generic-x86_64-with-glibc2.39
- `accelerate` bash location: .venv/bin/accelerate
- Python version: 3.12.11
- Numpy version: 2.2.6
- PyTorch version: 2.9.1+cu128
- PyTorch accelerator: CUDA
- System RAM: 125.71 GB
- GPU type: NVIDIA RTX A5000
- `Accelerate` default config:
	Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

This is kind of a nuanced problem.

First, note that the following patch fixes the issue I'm about to report:

import torch
from accelerate.hooks import CpuOffload
CpuOffload.pre_forward = torch.compiler.disable(CpuOffload.pre_forward)
CpuOffload.post_forward = torch.compiler.disable(CpuOffload.post_forward)

Let's delve into why.

Here is the breaking code:

import torch

from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
from sdnq.common import use_torch_compile as triton_is_available
from sdnq import sdnq_post_load_quant

torch._logging.set_logs(graph_breaks=True)
def test_sdnq_qwen_image_fast():
    model_id = "wavespeed/Qwen-Image-bf16"

    transformer = QwenImageTransformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )

    assert transformer.device == torch.device("cpu")

    # quantize
    transformer = sdnq_post_load_quant(
        transformer,
        weights_dtype="int8",
        quantized_matmul_dtype="int8",
        group_size=0,
        svd_rank=32,
        svd_steps=8,
        use_svd=False,
        use_quantized_matmul=triton_is_available,
        use_stochastic_rounding=False,
        dequantize_fp32=False,
        non_blocking=True,
        add_skip_keys=True,
        quantization_device=torch.device("cpu"),
        return_device=torch.device("cpu"),
        modules_to_not_convert=[],
        modules_dtype_dict={},
    )

    assert transformer.device == torch.device("cpu")

    # prepare the pipeline
    pipe = QwenImagePipeline.from_pretrained(
        model_id,
        transformer=None,
        torch_dtype=torch.bfloat16,
    )

    pipe.transformer = torch.compile(transformer)
    pipe.enable_model_cpu_offload()

    positive_magic = {
        "en": ", Ultra HD, 4K, cinematic composition.",
        "zh": ", 超清,4K,电影级构图."
    }

    prompt = "A man walking on the beach"
    negative_prompt = " "

    with torch.no_grad():
        image = pipe(
            prompt=prompt + positive_magic["en"],
            negative_prompt=negative_prompt,
            num_inference_steps=50,
            true_cfg_scale=4.0,
            width=512,
            height=512,
            generator=torch.Generator(device="cuda").manual_seed(42)
        ).images[0]

    assert image is not None
    image.save("sdnq_test_output.png")

What does this do?

  1. It loads Qwen Image in bf16.
  2. It quantizes Qwen Image's transformer to int8, using a quantization approach that supports torch.compile.
  3. It wraps the transformer with torch.compile.
  4. It enables model CPU offloading and runs 50 inference steps at 512x512.

Why?

  • When you have a low-rank adapter (LoRA), you want to fuse it into the bf16 weights, not into a quantized model.
  • Then you want to use the most effective quantization available. That is SVDQuant, so I use SDNQ, which implements it, mostly for diffusers.
  • I only want to quantize the transformer. I don't want to quantize the text_encoder for this model, because bf16 QwenVL2 performs much better than fp8 QwenVL2.
  • So I have set up a pipeline that uses an int8 SDNQ-quantized transformer and a bf16 VAE and text encoder.
  • Since I use a 24 GB Ampere GPU, I want enable_model_cpu_offload: offload the text encoder once it has done its job, then run the transformer, and so on.

A 20B-parameter model in int8 fits on a 24 GB Ampere GPU, and Ampere supports int8 matmul.
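A back-of-the-envelope check of that claim, counting weight storage only: at 1 byte per parameter, 20B parameters need about 18.6 GiB, which is under the 22.06 GiB total capacity that the OOM message reports for this card, while bf16 at 2 bytes per parameter would need about 37.3 GiB. This is a rough sketch: it ignores quantization scales, any layers kept in bf16, activations, and the CUDA context, so the real headroom is smaller than the difference suggests.

```python
GIB = 1024 ** 3

def weight_gib(num_params, bytes_per_param):
    """Approximate weight storage in GiB (weights only; no scales, activations, or buffers)."""
    return num_params * bytes_per_param / GIB

NUM_PARAMS = 20e9      # "20B" from the text, rounded
CAPACITY_GIB = 22.06   # total capacity reported in the OOM message for this GPU

for name, nbytes in [("bf16", 2), ("int8", 1)]:
    need = weight_gib(NUM_PARAMS, nbytes)
    print(f"{name}: {need:.2f} GiB, fits: {need < CAPACITY_GIB}")
# bf16: 37.25 GiB, fits: False
# int8: 18.63 GiB, fits: True
```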

Here's where this gets complicated. SDNQ is "correctly implemented" in the sense that when you quantize a transformer model like Qwen Image with it and torch.compile it, there are no graph breaks, and its int8 matmul is implemented in Triton. There is a difference between torch.compile merely not throwing exceptions (it can split the model at graph breaks and run the pieces in between eagerly) and torch.compile actually capturing the whole model. We're in the second scenario: it actually works.

But CpuOffload has hooked into the transformer's top-level forward. And since torch.compile is actually working, it will also try to compile accelerate's offload hook, which is bad.
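To make the mechanism concrete, here is a simplified, dependency-free model of the forward-wrapping pattern visible in the traceback, where accelerate's new_forward calls module._hf_hook.pre_forward before the original forward. ToyHook, ToyModule and add_hook are illustrative stand-ins, not accelerate's implementation (which also handles details such as no_grad). The point is that after wrapping, module.forward is the hook plus the original forward, so anything that traces forward also walks into pre_forward, unless that method is excluded from tracing, for example with torch.compiler.disable.

```python
import functools

class ToyHook:
    """Stand-in for an accelerate hook; CpuOffload.pre_forward would call module.to(device) here."""
    def pre_forward(self, module, *args, **kwargs):
        module.log.append("pre_forward")
        return args, kwargs

    def post_forward(self, module, output):
        module.log.append("post_forward")
        return output

def add_hook(module, hook):
    """Simplified version of the forward-wrapping pattern: hook runs around the old forward."""
    old_forward = module.forward

    def new_forward(*args, **kwargs):
        args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
        output = old_forward(*args, **kwargs)
        return module._hf_hook.post_forward(module, output)

    module._hf_hook = hook
    module._old_forward = old_forward
    # The instance attribute shadows the class's forward, so every call goes through the hook.
    module.forward = functools.update_wrapper(new_forward, old_forward)
    return module

class ToyModule:
    def __init__(self):
        self.log = []

    def forward(self, x):
        self.log.append("forward")
        return x * 2

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

m = add_hook(ToyModule(), ToyHook())
print(m(3), m.log)  # 6 ['pre_forward', 'forward', 'post_forward']
```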

On a 24GB GPU, torch.compile will cause an OOM in the accelerate hooks:

test_sdnq.py:77: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:120: in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py:691: in __call__
    noise_pred = self.transformer(
.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:414: in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/accelerate/hooks.py:170: in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/accelerate/hooks.py:731: in pre_forward
    module.to(self.execution_device)
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1371: in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:930: in _apply
    module._apply(fn)
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:930: in _apply
    module._apply(fn)
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:930: in _apply
    module._apply(fn)
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:930: in _apply
    module._apply(fn)
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:930: in _apply
    module._apply(fn)
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:957: in _apply
    param_applied = fn(param)
                    ^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

t = Parameter containing:
tensor([[-0.0126, -0.0064, -0.0781,  ..., -0.1089, -0.0364, -0.0056],
        [-0.0120, -0.0825,...        [ 0.0503, -0.0620,  0.0718,  ..., -0.0075, -0.0664, -0.0206]],
       dtype=torch.bfloat16, requires_grad=True)

    def convert(t):
        try:
            if convert_to_format is not None and t.dim() in (4, 5):
                return t.to(
                    device,
                    dtype if t.is_floating_point() or t.is_complex() else None,
                    non_blocking,
                    memory_format=convert_to_format,
                )
>           return t.to(
                device,
                dtype if t.is_floating_point() or t.is_complex() else None,
                non_blocking,
            )
E           torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 MiB. GPU 1 has a total capacity of 22.06 GiB of which 10.81 MiB is free. Including non-PyTorch memory, this process has 22.02 GiB memory in use. Of the allocated memory 21.77 GiB is allocated by PyTorch, and 10.58 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1357: OutOfMemoryError

The nuances here are key. SDNQ actually works. This won't occur with bitsandbytes, because its matmul is not Triton, so I don't think the whole model can be captured in a single graph. I'm not sure. But it is also much slower than the solution I am showing. When torch.compile works and tries to build one huge graph, it successfully traces accelerate's hooks too, and because copying from CPU to CUDA devices during compilation is radioactive (I don't know why), a lot of VRAM is "stolen" while compiling the hooks and torch.compile fails.

If you run

import torch
from accelerate.hooks import CpuOffload
CpuOffload.pre_forward = torch.compiler.disable(CpuOffload.pre_forward)
CpuOffload.post_forward = torch.compiler.disable(CpuOffload.post_forward)

at the very beginning, torch.compile succeeds with no graph breaks, and, predictably, an int8 model with 20 GB of weights in VRAM and a nice Triton int8 matmul is very fast (almost 3x faster than bf16 with async offloading, which we should really call concurrent offloading, on 24 GB).
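Since the patch has to run at the very beginning, it can be convenient to package it as a small idempotent helper, so that running it from more than one entry point does not stack wrappers. This is a sketch, not an accelerate or torch API: patch_methods and the marker attribute name are hypothetical, and the decorator is passed in as an argument so the helper itself has no dependencies.

```python
# Hypothetical marker set on wrappers this helper creates, used to skip re-wrapping.
_PATCHED_FLAG = "_compile_disabled_by_patch"

def patch_methods(cls, method_names, decorator):
    """Replace cls.<name> with decorator(cls.<name>) for each name, at most once."""
    for name in method_names:
        original = getattr(cls, name)  # a plain function when looked up on the class
        if getattr(original, _PATCHED_FLAG, False):
            continue  # already wrapped by an earlier call; avoid stacking wrappers
        wrapped = decorator(original)
        setattr(wrapped, _PATCHED_FLAG, True)
        setattr(cls, name, wrapped)  # writes to cls itself, never to a base class
    return cls
```

With torch and accelerate installed, the call would be patch_methods(CpuOffload, ("pre_forward", "post_forward"), torch.compiler.disable). Because setattr writes to CpuOffload itself, an inherited method is wrapped only on CpuOffload, and its base class is left untouched.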

This does not reproduce on a 40 GB GPU: such a GPU simply has enough spare VRAM for accelerate's hooks to "steal" during the loading step, so that case is uninteresting.

This does not reproduce with quantization schemes that cannot be torch.compiled end to end, so if the matmul is not Triton, it won't repro. I'm not sure which quantization schemes other than SDNQ would trigger it, but SDNQ is very good and should be upstreamed anyway.

Expected behavior

Accelerate's hooks should never be compiled (for example, CpuOffload.pre_forward and post_forward should be decorated with @torch.compiler.disable()). This would make torch.compile work much better in the many scenarios where these hooks are used.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests