
Conversation

@jiqing-feng (Contributor) commented on Oct 24, 2025:

The test pytest -rA tests/quantization/torchao/test_torchao.py::TorchAoTest::test_model_memory_usage failed with

>       assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
E       assert (1416704 / 1382912) >= 2.0
tests/quantization/torchao/test_torchao.py:512: AssertionError                                                                       

on A100. I guess this is because the model is too small: most of the peak memory comes from CUDA kernel/launch overhead rather than the model weights. If we switch to a large model like black-forest-labs/FLUX.1-dev, the ratio becomes 24244073472 / 12473665536 = 1.9436206143278139.
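
For reference, here is the arithmetic behind the two ratios quoted above (a quick sketch, not part of the test itself):

# Tiny test model on A100: the peak is dominated by overheads other than the weights,
# so int8_weight_only barely changes it.
print(1416704 / 1382912)          # ~1.02, far below the expected_memory_saving_ratio of 2.0

# FLUX.1-dev on A100: the weights dominate, so bf16 -> int8_weight_only gets close to 2x,
# but still lands just under the current threshold.
print(24244073472 / 12473665536)  # ~1.94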

@sayakpaul, please review this PR. Thanks!

@jiqing-feng (Contributor, Author) commented:

To reproduce with black-forest-labs/FLUX.1-dev, you can run the following script:

import torch
import os
from diffusers import FluxTransformer2DModel
from transformers import TorchAoConfig

torch.use_deterministic_algorithms(True)

def get_dummy_tensor_inputs(device=None, seed: int = 0):
    # Shapes sized for FLUX.1-dev: 4096 image tokens of 64 (= height * width) channels,
    # 48 text tokens with a 4096-dim joint embedding, and a 768-dim pooled embedding.
    batch_size = 1
    num_latent_channels = 4096
    num_image_channels = 3
    height = width = 8
    sequence_length = 48
    embedding_dim = 768
    torch.manual_seed(seed)
    hidden_states = torch.randn((batch_size, num_latent_channels, height*width)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    encoder_hidden_states = torch.randn((batch_size, sequence_length, num_latent_channels)).to(
        device, dtype=torch.bfloat16
    )
    torch.manual_seed(seed)
    pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    # The text_ids/image_ids lengths look swapped (4096 vs. 48), but Flux only uses their
    # concatenated length for the rotary embeddings, so the forward pass still works.
    text_ids = torch.randn((num_latent_channels, num_image_channels)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    image_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
    timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
    return {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "pooled_projections": pooled_prompt_embeds,
        "txt_ids": text_ids,
        "img_ids": image_ids,
        "timestep": timestep,
        "guidance": timestep * 3054,
    }

# Helper that isolates the peak memory of a single forward pass. Note that it relies on
# `device_module`, which is defined below at module level, and that the measurements
# further down are done inline rather than through this helper.
@torch.no_grad()
@torch.inference_mode()
def get_memory_consumption_stat(model, inputs):
    device_module.reset_peak_memory_stats()
    device_module.empty_cache()
    model(**inputs)
    max_mem_allocated = device_module.max_memory_allocated()
    return max_mem_allocated

torch_device = "xpu" if torch.xpu.is_available() else "cuda"
if torch_device == "cuda":
    # Needed for deterministic cuBLAS because torch.use_deterministic_algorithms(True) is set above.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
device_module = torch.xpu if torch.xpu.is_available() else torch.cuda
model_id = "black-forest-labs/FLUX.1-dev"
print(f"max allocated memory before loading: {device_module.max_memory_allocated()}")
inputs = get_dummy_tensor_inputs(device=torch_device)
print(f"max allocated memory after get inputs: {device_module.max_memory_allocated()}")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=None, torch_dtype=torch.bfloat16
).to(torch_device)
print(f"max allocated memory after get bf16 model: {device_module.max_memory_allocated()}")

with torch.no_grad(), torch.inference_mode():
    transformer(**inputs)
print(f"max allocated memory after bf16 model inference: {device_module.max_memory_allocated()}")

# Free the bf16 model and reset the peak counter so it doesn't skew the int8 measurement.
del transformer
device_module.reset_peak_memory_stats()
device_module.empty_cache()

transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=TorchAoConfig("int8_weight_only"), torch_dtype=torch.bfloat16
).to(torch_device)
print(f"max allocated memory after get int8 model: {device_module.max_memory_allocated()}")

with torch.no_grad(), torch.inference_mode():
    transformer(**inputs)
print(f"max allocated memory after int8 model inference: {device_module.max_memory_allocated()}")

@sayakpaul (Member) commented:

I just ran the test on an H100 and it worked fine.

@jiqing-feng (Contributor, Author) commented on Oct 24, 2025:

> I just ran the test on an H100 and it worked fine.

It seems like a device-related issue. Could we either switch the test to a bigger model so other devices also pass, or relax the expected ratio as in this PR? We want the test to pass on XPU and A100.
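
To make "change the ratio" concrete, one option would be a per-backend expectation rather than a single 2.0 for every device. This is only a sketch of the idea, not the actual change in this PR; the dictionary, the helper name, and the placeholder values are mine:

import torch

# Hypothetical per-backend thresholds (placeholder values; they would need to be
# measured on each device before being used in the test).
EXPECTED_MEMORY_SAVING_RATIOS = {
    "cuda": 1.0,
    "xpu": 1.0,
}

def expected_memory_saving_ratio(default: float = 2.0) -> float:
    # Pick the backend the same way the reproduction script above does.
    backend = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"
    return EXPECTED_MEMORY_SAVING_RATIOS.get(backend, default)

print(expected_memory_saving_ratio())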

@jiqing-feng (Contributor, Author) commented:

> > I just ran the test on an H100 and it worked fine.
>
> It seems like a device-related issue. Could we either switch the test to a bigger model so other devices also pass, or relax the expected ratio as in this PR? We want the test to pass on XPU and A100.

Hi @sayakpaul, could you share the ratio you see on the H100?

