Skip to content

Conversation

@MekkCyber
Copy link

What does this PR do?

Adds finegrained FP8

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Copy link
Member

Just for bookkeeping, relaying stuff from our DM.

I had to make the following changes to make this PR work:

Expand
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 638c5fbfb..737525143 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1238,8 +1238,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         }
 
         # Dispatch model with hooks on all devices if necessary
-        print(model.transformer_blocks[0].attn.to_q.weight)
-        print(model.transformer_blocks[0].attn.to_q.weight_scale_inv)
+        # print(model.transformer_blocks[0].attn.to_q.weight)
+        # print(model.transformer_blocks[0].attn.to_q.weight_scale_inv)
         if device_map is not None:
             device_map_kwargs = {
                 "device_map": device_map,
diff --git a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py
index 5dec8b0b8..7212befcd 100644
--- a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py
+++ b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py
@@ -90,9 +90,9 @@ class FinegrainedFP8Quantizer(DiffusersQuantizer):
         Quantizes weights to FP8 format using Block-wise quantization
         """
         # print("############ create quantized param ########")
-        from accelerate.utils import set_module_tensor_to_device
+        # from accelerate.utils import set_module_tensor_to_device
 
-        set_module_tensor_to_device(model, param_name, target_device, param_value)
+        # set_module_tensor_to_device(model, param_name, target_device, param_value)
 
         module, tensor_name = get_module_from_name(model, param_name)
 
@@ -131,8 +131,8 @@ class FinegrainedFP8Quantizer(DiffusersQuantizer):
         scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
 
         # Load into the model
-        module._buffers[tensor_name] = quantized_param.to(target_device)
-        module._buffers["weight_scale_inv"] = scale.to(target_device)
+        module._parameters[tensor_name] = quantized_param.to(target_device)
+        module._parameters["weight_scale_inv"] = scale.to(target_device)
         # print("_buffers[0]", module._buffers["weight_scale_inv"])
 
     def check_if_quantized_param(

Inference code:

import torch
from diffusers import FluxPipeline, AutoModel, FinegrainedFP8Config
from diffusers.quantizers.finegrained_fp8.utils import FP8Linear

# Load FLUX.1-dev with a finegrained-FP8 quantized transformer, list any
# FP8Linear modules whose inverse weight scale is one-dimensional, then
# generate one image and report reserved CUDA memory before and after.
model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16


def _print_reserved_memory():
    """Print the peak CUDA memory reserved so far, in GiB."""
    print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")


# weight_block_size is not passed here, so the config's own default applies.
quantization_config = FinegrainedFP8Config(
    modules_to_not_convert=["norm", "proj_out", "x_embedder"],
)
transformer = AutoModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
    device_map="cuda",
)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.to("cuda")

# Report quantized linear layers whose scale tensor has a single dimension.
for name, module in pipe.transformer.named_modules():
    if not isinstance(module, FP8Linear):
        continue
    if getattr(module, "weight_scale_inv", None) is None:
        continue
    if module.weight_scale_inv.ndim == 1:
        print(name, module.weight_scale_inv.shape)

_print_reserved_memory()

prompt = "A cat holding a sign that says hello world"
sampling_kwargs = {
    "num_inference_steps": 50,
    "guidance_scale": 4.5,
    "max_sequence_length": 512,
}
image = pipe(prompt, **sampling_kwargs).images[0]
image.save("output.png")
_print_reserved_memory()

The modules_to_not_convert includes proj_out and x_embedder because otherwise, we violate the shape constraint on scale (scale.ndim == 2).

@sayakpaul sayakpaul requested review from DN6 and sayakpaul and removed request for sayakpaul June 16, 2025 08:12
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for starting this! Would be nice to also have some benchmarks:

  • With and without finegrained FP8 quant (with visual outputs)
  • With and without torch.compile

Comment on lines 484 to 498
def _check_serialization_expected_slice(self, expected_slice, device):
quantized_model = self.get_dummy_model(device)

with tempfile.TemporaryDirectory() as tmp_dir:
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
).to(device=torch_device)

inputs = self.get_dummy_tensor_inputs(torch_device)
output = loaded_quantized_model(**inputs)[0]

output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think instead of delegating certain calls to other methods, we can have all of the implementations under this one. This way, everything remains self-contained. Furthermore, since this test class doesn't have other tests, we don't have to modularize too much.

WDYT?

@sayakpaul sayakpaul requested a review from SunMarc June 16, 2025 12:24
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MekkCyber LMK if there are any unresolved comments on my part. Excited to see this getting merged soon!

Comment on lines +10 to +14
# FinegrainedFP8

## Overview

## Usage
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just get an example for this, and we can merge?


module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, FP8Linear):
if self.pre_quantized or tensor_name == "bias":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checks if the param needs to be quantized ?

This.

Example:

elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):
hf_quantizer.create_quantized_param(
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
)
else:
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

I think this is because during this step, we call

hf_quantizer.preprocess_model(

which does:

def _process_model_before_weight_loading(

(this does the job of replacing the regular linear layers with bnb linear layers)

LMK if you need more clarifications.

Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good 👍🏽 Just one minor comment on importing the quant config. And I think the test file needs clean up. Thanks!



# All tests were conducted on a H100 instance
@require_torch
Copy link
Member

@sayakpaul sayakpaul Jun 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just require_torch_gpu() should be enough no?

Additionally, should we not also add something like require_torch_gpu_capability(major=8, minor=9) (at least to the execution tests where we actually run the models)?

gc.collect()
torch.cuda.empty_cache()

def get_dummy_components(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why we cannot do pipe.from_pretrained(model_id) here?


return inputs

def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we using it anywhere?

Comment on lines +267 to +269
unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we test that all the modules from modules_to_not_convert go through these assertions?

@nightly
def test_torch_compile(self):
r"""Test that verifies if torch.compile works with fp8 quantization."""
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why we need to test for two model ids here?

transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
transformer_bf16.to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
del transformer_bf16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might wanna call .cpu() to free the GPU VRAM.


@require_torch
@require_torch_gpu
class SlowTorchAoTests(unittest.TestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To rename.

Comment on lines +444 to +459
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32

torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or we could just test the actual pipeline. That will reduce the code bloat a bit.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some comments. As an environment validation step, we could also detect whether the user's GPU has at least 8.9 compute capability and error out if it does not?

@sayakpaul
Copy link
Member

Benchmarked finegrained FP8 with torchao FP8:

{"time": 18.975, "memory": 24.484153747558594, "quant_type": "finegrained"}
{"time": 6.625, "memory": 22.85780096054077, "quant_type": "torchao"}
Method Visualization
Finegrained Finegrained
TorchAO TorchAO
Code
import torch
torch.set_grad_enabled(False)

import torch.utils.benchmark as benchmark
import argparse
import json
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import TorchAoConfig, FinegrainedFP8Config

def benchmark_fn(f, *args, **kwargs):
    """Benchmark ``f(*args, **kwargs)`` and return the rounded mean timing.

    The callable and its arguments are exposed to ``benchmark.Timer`` through
    its ``globals`` mapping and measured on a single thread with
    ``blocked_autorange()``.

    Returns:
        The measurement's ``mean`` as a float rounded to three decimals.
    """
    timer_scope = {"args": args, "kwargs": kwargs, "f": f}
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals=timer_scope,
        num_threads=1,
    )
    measurement = timer.blocked_autorange()
    # Round-trip through a 3-decimal string to round the reported value.
    return float(format(measurement.mean, ".3f"))

def get_pipeline(quant_type="int8wo"):
    """Build a FLUX.1-dev pipeline on CUDA with a quantized, compiled transformer.

    Args:
        quant_type: ``"torchao"`` selects ``TorchAoConfig`` with
            ``float8dq_e4m3_row``. Any other value selects
            ``FinegrainedFP8Config`` -- this includes the default
            ``"int8wo"``, which no branch matches explicitly.

    Returns:
        The pipeline moved to ``"cuda"``, with ``transformer.compile`` called
        using ``fullgraph=True`` and the progress bar disabled.
    """
    if quant_type != "torchao":
        transformer_quant = FinegrainedFP8Config(
            modules_to_not_convert=["x_embedder", "proj_out"],
            weight_block_size=(64, 64),
        )
    else:
        transformer_quant = TorchAoConfig(quant_type="float8dq_e4m3_row")

    # Only the transformer component is mapped to a quantization config.
    mapping = PipelineQuantizationConfig(quant_mapping={"transformer": transformer_quant})
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        quantization_config=mapping,
        torch_dtype=torch.bfloat16,
    )
    pipeline = pipeline.to("cuda")
    pipeline.transformer.compile(fullgraph=True)
    pipeline.set_progress_bar_config(disable=True)
    return pipeline

def run_inference(pipe, pipe_kwargs):
    """Call ``pipe`` with ``pipe_kwargs`` expanded as keyword arguments.

    The result of the call is discarded; this function returns ``None``.
    """
    pipe(**pipe_kwargs)

if __name__ == "__main__":
    # Benchmark driver: build the pipeline for the requested quantization
    # backend, time inference, record peak allocated CUDA memory, then save
    # one sample image and a JSON summary named after the backend.
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant_type", type=str, default="torchao", choices=["torchao", "finegrained"])
    args = parser.parse_args()

    pipe = get_pipeline(quant_type=args.quant_type)
    pipe_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "guidance_scale": 3.5,
        "num_inference_steps": 50,
        "max_sequence_length": 512,
        "generator": torch.manual_seed(0),
    }
    mean_time = benchmark_fn(run_inference, pipe, pipe_kwargs)
    inference_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)

    # The benchmark reuses the single generator stored in pipe_kwargs for
    # every timed call, so by now it is no longer in its seed-0 state.
    # Reseed before the final render so the saved image is a seed-0 sample
    # regardless of how many benchmark iterations ran.
    pipe_kwargs["generator"] = torch.manual_seed(0)
    image = pipe(**pipe_kwargs).images[0]

    artifact_dict = {"time": mean_time, "memory": inference_memory}
    artifact_dict.update(vars(args))
    file_prefix = f"quant@{args.quant_type}"
    image.save(f"{file_prefix}.png")
    with open(f"{file_prefix}.json", "w") as f:
        json.dump(artifact_dict, f)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants