
Commit 3e72979

improve tests
1 parent: 77a3456

2 files changed: +79 additions, -9 deletions


docs/source/en/quantization/torchao.md

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
 The example below only quantizes the weights to int8.
 
 ```python
+import torch
 from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
 
 model_id = "black-forest-labs/FLUX.1-dev"
@@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained(
 )
 pipe.to("cuda")
 
+# Without quantization: ~31.447 GB
+# With quantization: ~20.40 GB
+print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
+
 prompt = "A cat holding a sign that says hello world"
 image = pipe(
     prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
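With both hunks applied, the surrounding example in torchao.md plausibly reads as follows. This is a reconstruction from the diff context, not the full file: the `quantization_config=TorchAoConfig("int8wo")` and `subfolder="transformer"` arguments, the `.images[0]` access, and the `output.png` filename are assumed from the standard diffusers quantization workflow rather than shown in the hunks.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

# Assumed: int8 weight-only quantization of the transformer, as the doc describes
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.to("cuda")

# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```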

tests/quantization/torchao/test_torchao.py

Lines changed: 74 additions & 9 deletions
@@ -443,21 +443,29 @@ def test_memory_footprint(self):
         transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
         transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
 
-        self.assertTrue(
-            isinstance(transformer_int4wo.transformer_blocks[0].ff.net[2].weight, AffineQuantizedTensor)
-        )
-        self.assertTrue(
-            isinstance(transformer_int4wo_gs32.transformer_blocks[0].ff.net[2].weight, AffineQuantizedTensor)
-        )
-        self.assertTrue(
-            isinstance(transformer_int8wo.transformer_blocks[0].ff.net[2].weight, AffineQuantizedTensor)
-        )
+        # Will not quantize all the layers by default because the model weight shapes are not divisible by group_size=64
+        for block in transformer_int4wo.transformer_blocks:
+            self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
+            self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))
+
+        # Will quantize all the linear layers except x_embedder
+        for name, module in transformer_int4wo_gs32.named_modules():
+            if name == "x_embedder":
+                print(module)
+            if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
+                self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
+
+        # Will quantize all the linear layers
+        for module in transformer_int8wo.modules():
+            if isinstance(module, nn.Linear):
+                self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
 
         total_int4wo = get_model_size_in_bytes(transformer_int4wo)
         total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
         total_int8wo = get_model_size_in_bytes(transformer_int8wo)
         total_bf16 = get_model_size_in_bytes(transformer_bf16)
 
+        # TODO: refactor to align with other quantization tests
         # Latter has smaller group size, so more groups -> more scales and zero points
         self.assertTrue(total_int4wo < total_int4wo_gs32)
         # int8 quantizes more layers compared to int4 with default group size
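The ordering asserted here, `total_int4wo < total_int4wo_gs32`, follows from per-group quantization overhead: halving the group size doubles the number of groups, and each group carries its own scale and zero point. A rough back-of-the-envelope sketch, assuming one 2-byte scale and one 2-byte zero point per group and example layer dimensions (torchao's actual packing differs in detail):

```python
def approx_int4wo_linear_bytes(out_features: int, in_features: int, group_size: int) -> int:
    """Rough size of one int4 weight-only quantized linear weight (illustrative only)."""
    packed_weight = out_features * in_features // 2            # two 4-bit values per byte
    num_groups = out_features * (in_features // group_size)    # one group per `group_size` inputs per output row
    scales_and_zero_points = num_groups * (2 + 2)              # assumed 2 bytes for the scale, 2 for the zero point
    return packed_weight + scales_and_zero_points


# Example dimensions only; smaller group size -> more groups -> larger total
print(approx_int4wo_linear_bytes(3072, 3072, group_size=64))  # default group size
print(approx_int4wo_linear_bytes(3072, 3072, group_size=32))  # "gs32": strictly larger
```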
@@ -735,3 +743,60 @@ def test_memory_footprint_int8wo(self):
         )
         int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
         self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
+
+
+@require_torch
+@require_torch_gpu
+@require_torchao_version_greater_or_equal("0.7.0")
+@slow
+@nightly
+class SlowTorchAoPreserializedModelTests(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator().manual_seed(seed)
+
+        inputs = {
+            "prompt": "an astronaut riding a horse in space",
+            "height": 512,
+            "width": 512,
+            "num_inference_steps": 20,
+            "output_type": "np",
+            "generator": generator,
+        }
+
+        return inputs
+
+    def test_transformer_int8wo(self):
+        # fmt: off
+        expected_slice = np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])
+        # fmt: on
+
+        # This is just for convenience, so that we can modify it in one place for custom environments and local testing
+        cache_dir = None
+        transformer = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer",
+            torch_dtype=torch.bfloat16,
+            use_safetensors=False,
+            cache_dir=cache_dir,
+        )
+        pipe = FluxPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir
+        )
+        pipe.enable_model_cpu_offload()
+
+        # Verify that all linear layer weights are quantized
+        for name, module in pipe.transformer.named_modules():
+            if isinstance(module, nn.Linear):
+                self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
+
+        # Verify outputs match expected slice
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0].flatten()
+        output_slice = np.concatenate((output[:16], output[-16:]))
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
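For context, the checkpoint this test loads (`hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer`) is a quantized transformer serialized with pickle-based weights, which is why the test passes `use_safetensors=False`. A sketch of how such a checkpoint could plausibly be produced; this is not the script that created it, and `output_dir` is a hypothetical path:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Quantize the FLUX transformer to int8 weight-only while loading
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8wo"),
    torch_dtype=torch.bfloat16,
)

# Hypothetical output path; pickle-based serialization is used for the torchao
# tensor subclasses, matching the `use_safetensors=False` in the test above.
output_dir = "./FLUX.1-Dev-TorchAO-int8wo-transformer"
transformer.save_pretrained(output_dir, safe_serialization=False)
```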
