
Commit 10deb16

add more tests; add expected slices
1 parent 25d3cf8 commit 10deb16

2 files changed: +200 -32 lines changed

tests/quantization/torchao/README.md

Lines changed: 7 additions & 1 deletion
@@ -1,6 +1,6 @@
 The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/tests/quantization/torchao_integration/).
 
-They were conducted on a single H100. Below is `nvidia-smi`:
+The benchmarks were run on a single H100. Below is `nvidia-smi`:
 
 ```bash
 +---------------------------------------------------------------------------------------+
@@ -26,6 +26,12 @@ They were conducted on a single H100. Below is `nvidia-smi`:
 
 The benchmark results for Flux and CogVideoX can be found in [this](https://github.com/huggingface/diffusers/pull/10009) PR.
 
+The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent:
+
+```bash
+HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
+```
+
 `diffusers-cli`:
 
 ```bash

tests/quantization/torchao/test_torchao.py

Lines changed: 193 additions & 31 deletions
@@ -30,15 +30,20 @@
 )
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
+    enable_full_determinism,
     is_torch_available,
     is_torchao_available,
     require_torch,
     require_torch_gpu,
     require_torchao_version_greater,
+    slow,
     torch_device,
 )
 
 
+enable_full_determinism()
+
+
 if is_torch_available():
     import torch
     import torch.nn as nn
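For context, the newly imported `enable_full_determinism()` is invoked at module import so that the hard-coded expected slices are reproducible across runs. A minimal sketch of what a full-determinism helper typically configures (an illustration, not necessarily the exact diffusers implementation):

```python
import os

import torch


def enable_full_determinism_sketch(seed: int = 0) -> None:
    # Seed every RNG the tests rely on.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask PyTorch for deterministic kernels; cuBLAS needs this env var on CUDA >= 10.2.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```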
@@ -101,9 +106,21 @@ def test_repr(self):
         Check that there is no error in the repr
         """
         quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
-        repr(quantization_config)
-
-
+        expected_repr = """TorchAoConfig {
+            "modules_to_not_convert": [
+                "conv"
+            ],
+            "quant_method": "torchao",
+            "quant_type": "int4_weight_only",
+            "quant_type_kwargs": {
+                "group_size": 8
+            }
+        }""".replace(" ", "").replace("\n", "")
+        quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
+        self.assertEqual(quantization_repr, expected_repr)
+
+
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
 @require_torchao_version_greater("0.6.0")
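The repr assertion strips spaces and newlines so the comparison is insensitive to JSON indentation. A small usage sketch of the same normalization (assuming `TorchAoConfig` is importable from the top-level `diffusers` namespace, as in this test module):

```python
from diffusers import TorchAoConfig

config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)

# Extra kwargs such as group_size are collected under "quant_type_kwargs" in the repr.
normalized = repr(config).replace(" ", "").replace("\n", "")
print(normalized)
```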
@@ -202,32 +219,44 @@ def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: L
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_quantization(self):
-        # TODO(aryan): update these values from our CI
+        # fmt: off
         QUANTIZATION_TYPES_TO_TEST = [
-            ("int4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("int4dq", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("int8wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("int8dq", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("uint4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("int_a8w8", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("uint_a16w7", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+            ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
+            ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
+            ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
+            ("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
         ]
 
         if TorchAoConfig._is_cuda_capability_atleast_8_9():
-            QUANTIZATION_TYPES_TO_TEST.extend(
-                [
-                    ("float8wo_e5m2", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("float8wo_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("fp4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("fp6", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                ]
-            )
+            QUANTIZATION_TYPES_TO_TEST.extend([
+                ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
+                ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
+                # =====
+                # The following lead to an internal torch error:
+                # RuntimeError: mat2 shape (32x4 must be divisible by 16
+                # Skip these for now; TODO(aryan): investigate later
+                # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+                # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+                # =====
+                # Cutlass fails to initialize for below
+                # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+                # =====
+                ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+                ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ])
+        # fmt: on
 
         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
-            quantization_config = TorchAoConfig(quant_type=quantization_name)
+            quant_kwargs = {}
+            if quantization_name in ["uint4wo", "uint_a16w7"]:
+                # The dummy flux model that we use requires us to impose some restrictions on group_size here
+                quant_kwargs.update({"group_size": 16})
+            quantization_config = TorchAoConfig(
+                quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
+            )
             self._test_quant_type(quantization_config, expected_slice)
 
     def test_int4wo_quant_bfloat16_conversion(self):
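The expected arrays above were captured on the CI runners rather than written by hand. A hedged sketch of how a nine-value slice could be regenerated for a new quantization type, mirroring the `output.flatten()[-9:]` pattern used in `test_offload` below (`quantized_model` and `inputs` stand in for the dummy fixtures built by the test class):

```python
import numpy as np
import torch

# Placeholders: `quantized_model` and `inputs` come from the test's dummy components.
with torch.no_grad():
    output = quantized_model(**inputs)[0]

output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
# Round for readability before pasting into QUANTIZATION_TYPES_TO_TEST.
print(np.round(output_slice, 4).tolist())
```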
@@ -277,10 +306,9 @@ def test_offload(self):
         )
 
         output = quantized_model(**inputs)[0]
-
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-        # TODO(aryan): get slice from CI
-        expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+
+        expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
@@ -333,6 +361,7 @@ def test_training(self):
         self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
 
     def test_torch_compile(self):
+        r"""Test that verifies if torch.compile works with torchao quantization."""
         quantization_config = TorchAoConfig("int8_weight_only")
         components = self.get_dummy_components(quantization_config)
         pipe = FluxPipeline(**components)
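The compile test follows the usual `torch.compile` pattern for a quantized transformer; a rough sketch of the idea (the exact call is outside this hunk, so treat this as an assumption):

```python
import torch

# `pipe` is assumed to be a FluxPipeline built from the quantized components above.
pipe.transformer = torch.compile(pipe.transformer)

# Subsequent pipeline calls execute the compiled int8-weight-only transformer.
```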
@@ -348,7 +377,54 @@ def test_torch_compile(self):
         # Note: Seems to require higher tolerance
         self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
 
+    @staticmethod
+    def _get_memory_footprint(module):
+        quantized_param_memory = 0.0
+        unquantized_param_memory = 0.0
+
+        for param in module.parameters():
+            if param.__class__.__name__ == "AffineQuantizedTensor":
+                data, scale, zero_point = param.layout_tensor.get_plain()
+                quantized_param_memory += data.numel() + data.element_size()
+                quantized_param_memory += scale.numel() + scale.element_size()
+                quantized_param_memory += zero_point.numel() + zero_point.element_size()
+            else:
+                unquantized_param_memory += param.data.numel() * param.data.element_size()
+
+        total_memory = quantized_param_memory + unquantized_param_memory
+        return total_memory, quantized_param_memory, unquantized_param_memory
+
+    def test_memory_footprint(self):
+        r"""
+        A simple test to check if the model conversion has been done correctly by checking on the
+        memory footprint of the converted model and the class type of the linear layers of the converted models
+        """
+        transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"]
+        transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"]
+        transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
+        transformer_bf16 = self.get_dummy_components(None)["transformer"]
+
+        total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
+        total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
+            transformer_int4wo_gs32
+        )
+        total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
+        total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
 
+        self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
+        # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
+        self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
+        # int4 with default group size quantized very few linear layers compared to a smaller group size of 32
+        self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
+        # int8 quantizes more layers compare to int4 with default group size
+        self.assertTrue(quantized_int8wo < quantized_int4wo)
+
+    def test_wrong_config(self):
+        with self.assertRaises(ValueError):
+            self.get_dummy_components(TorchAoConfig("int42"))
+
+
+# This class is not to be run as a test by itself. See the tests that follow this class
 @require_torch
 @require_torch_gpu
 @require_torchao_version_greater("0.6.0")
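The footprint assertions are deliberately relative (for example `total_int8wo < total_bf16 < total_int4wo_gs32`) because absolute byte counts depend on the dummy model's layer shapes. A small sketch of the bytes-per-element reasoning behind them, using plain tensors rather than torchao's quantized classes:

```python
import torch

# bf16 stores 2 bytes per weight element.
bf16_weight = torch.zeros(64, 64, dtype=torch.bfloat16)
bf16_bytes = bf16_weight.numel() * bf16_weight.element_size()  # 64 * 64 * 2 = 8192

# int8 weight-only stores 1 byte per element plus a small per-channel scale overhead.
int8_weight = torch.zeros(64, 64, dtype=torch.int8)
scales = torch.zeros(64, dtype=torch.float32)
int8_bytes = int8_weight.numel() * int8_weight.element_size() + scales.numel() * scales.element_size()

print(bf16_bytes, int8_bytes)  # int8 comes out to roughly half of bf16 for this layer
```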
@@ -371,14 +447,15 @@ def get_dummy_model(self, device=None):
         )
         return quantized_model.to(device)
 
-    def get_dummy_tensor_inputs(self, device=None):
+    def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
         batch_size = 1
         num_latent_channels = 4
         num_image_channels = 3
         height = width = 4
         sequence_length = 48
         embedding_dim = 32
 
+        torch.manual_seed(seed)
         hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
         encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
             device, dtype=torch.bfloat16
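Seeding immediately before the `torch.randn` calls pins the dummy inputs, which is what makes hard-coded expected slices stable across runs. A quick sketch of the behavior being relied on:

```python
import torch

torch.manual_seed(0)
a = torch.randn(4)

torch.manual_seed(0)
b = torch.randn(4)

# Identical seeds produce identical tensors, so output slices can be hard-coded.
assert torch.equal(a, b)
```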
@@ -425,27 +502,112 @@ def test_serialization_expected_slice(self):
 
 class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
     quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
     serialized_expected_slice = expected_slice
     device = "cuda"
 
 
 class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
     quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
     serialized_expected_slice = expected_slice
     device = "cuda"
 
 
 class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
     quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
     serialized_expected_slice = expected_slice
     device = "cpu"
 
 
 class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
     quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
     serialized_expected_slice = expected_slice
     device = "cpu"
+
+
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
+@require_torch
+@require_torch_gpu
+@require_torchao_version_greater("0.6.0")
+@slow
+class SlowTorchAoTests(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_components(self, quantization_config: TorchAoConfig):
+        model_id = "black-forest-labs/FLUX.1-dev"
+        transformer = FluxTransformer2DModel.from_pretrained(
+            model_id,
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
+        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "transformer": transformer,
+            "vae": vae,
+        }
+
+    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator().manual_seed(seed)
+
+        inputs = {
+            "prompt": "an astronaut riding a horse in space",
+            "height": 512,
+            "width": 512,
+            "num_inference_steps": 20,
+            "output_type": "np",
+            "generator": generator,
+        }
+
+        return inputs
+
+    def _test_quant_type(self, quantization_config, expected_slice):
+        components = self.get_dummy_components(quantization_config)
+        pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
+        pipe.enable_model_cpu_offload()
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0].flatten()
+        output_slice = np.concatenate((output[:16], output[-16:]))
+
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+    def test_quantization(self):
+        # fmt: off
+        QUANTIZATION_TYPES_TO_TEST = [
+            ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
+            ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
+        ]
+
+        if TorchAoConfig._is_cuda_capability_atleast_8_9():
+            QUANTIZATION_TYPES_TO_TEST.extend([
+                ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
+                ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])),
+            ])
+        # fmt: on
+
+        for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
+            quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
+            self._test_quant_type(quantization_config, expected_slice)
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
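The 32-value slices in the slow tests concatenate the first and last 16 values of the flattened 512x512 output. A hedged sketch of how one such slice could be regenerated locally for a single quant type (assuming access to the `black-forest-labs/FLUX.1-dev` checkpoint and comparable GPU hardware):

```python
import numpy as np
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = TorchAoConfig("int8wo", modules_to_not_convert=["x_embedder"])
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=quantization_config, torch_dtype=torch.bfloat16
)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

output = pipe(
    "an astronaut riding a horse in space",
    height=512,
    width=512,
    num_inference_steps=20,
    output_type="np",
    generator=torch.Generator().manual_seed(0),
)[0].flatten()

# Same slicing as SlowTorchAoTests._test_quant_type; paste the rounded values into the test.
print(np.round(np.concatenate((output[:16], output[-16:])), 4).tolist())
```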
