Commit 7d9d1dc

committed
address review comments
1 parent b227189 commit 7d9d1dc

File tree

3 files changed, +14 -61 lines changed


docs/source/en/quantization/torchao.md

Lines changed: 7 additions & 22 deletions
@@ -45,7 +45,7 @@ pipe = FluxPipeline.from_pretrained(
 pipe.to("cuda")
 
 prompt = "A cat holding a sign that says hello world"
-image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
+image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
 image.save("output.png")
 ```
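The edited snippet is the quantized Flux inference example in the torchao docs; the step count is raised from 4 to 28, presumably because the example targets FLUX.1-dev rather than the 4-step schnell variant. Below is a minimal sketch of how the full example could read after this change; the checkpoint id, the `int8wo` quant type, and the quantized-transformer setup are assumptions drawn from the surrounding docs, not part of this diff.

```python
# Minimal sketch of the surrounding docs example after this change.
# Assumptions (not shown in this diff): FLUX.1-dev as the checkpoint and
# `int8wo` as the torchao quant type applied to the transformer.
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"  # assumed checkpoint
quantization_config = TorchAoConfig("int8wo")

transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image.save("output.png")
```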

@@ -75,27 +75,12 @@ Dynamic activation quantization stores the model weights in a low-bit dtype, whi
 
 The quantization methods supported are as follows:
 
-- **Integer quantization:**
-  - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight`
-  - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
-  - Documentation shorthands/Common speak: `int_a16w4`, `int_a8w4`, `int_a16w8`, `int_a8w8`
-
-- **Floating point 8-bit quantization:**
-  - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight`
-  - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row`, `float8sq`
-  - Documentation shorthands/Common speak: `float8_e5m2_a16w8`, `float8_e4m3_a16w8`, `float_a8w8`, `float_a16w8`
-
-- **Floating point X-bit quantization:**
-  - Full function names: `fpx_weight_only`
-  - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must be satisfied for a given shorthand notation.
-  - Documentation shorthands/Common speak: `float_a16w3`, `float_a16w4`, `float_a16w5`, `float_a16w6`, `float_a16w7`, `float_a16w8`
-
-- **Unsigned Integer quantization:**
-  - Full function names: `uintx_weight_only`
-  - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
-  - Documentation shorthands/Common speak: `uint_a16w1`, `uint_a16w2`, `uint_a16w3`, `uint_a16w4`, `uint_a16w5`, `uint_a16w6`, `uint_a16w7`
-
-The "Documentation shorthands/Common speak" refers to the underlying storage dtype with the number of bits for storing activations and weights, respectively. For example, int_a16w8 stores the activations in 16-bit and the weights in 8-bit.
+| **Category** | **Full Function Names** | **Shorthands** |
+|--------------|-------------------------|----------------|
+| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
+| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
+| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
+| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
 
 Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
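The table above is a presentation change only; full function names and their shorthands stay interchangeable when building a config. A minimal usage sketch (the specific quant types shown are illustrative):

```python
# Sketch: a full torchao function name and its shorthand alias request the same
# quantization method when constructing the config.
from diffusers import TorchAoConfig

config_full = TorchAoConfig("int8_weight_only")  # full function name
config_short = TorchAoConfig("int8wo")           # shorthand alias
# Pass either object as `quantization_config=` to a supported model's
# `from_pretrained` call; both select int8 weight-only quantization.
```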

src/diffusers/quantizers/quantization_config.py

Lines changed: 1 addition & 38 deletions
@@ -405,29 +405,22 @@ class TorchAoConfig(QuantizationConfigMixin):
                 - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
                   `int8_weight_only`, `int8_dynamic_activation_int8_weight`
                 - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
-                - Documentation shorthands/Common speak: `int_a16w4`, `int_a8w4`, `int_a16w8`, `int_a8w8`
 
             - **Floating point 8-bit quantization:**
                 - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`,
                   `float8_static_activation_float8_weight`
                 - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
-                  `float8_e4m3_tensor`, `float8_e4m3_row`, `float8sq`
-                - Documentation shorthands/Common speak: `float8_e5m2_a16w8`, `float8_e4m3_a16w8`, `float_a8w8`,
-                  `float_a16w8`
+                  `float8_e4m3_tensor`, `float8_e4m3_row`,
 
             - **Floating point X-bit quantization:**
                 - Full function names: `fpx_weight_only`
                 - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
                   of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
                   be satisfied for a given shorthand notation.
-                - Documentation shorthands/Common speak: `float_a16w3`, `float_a16w4`, `float_a16w5`,
-                  `float_a16w6`, `float_a16w7`, `float_a16w8`
 
             - **Unsigned Integer quantization:**
                 - Full function names: `uintx_weight_only`
                 - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
-                - Documentation shorthands/Common speak: `uint_a16w1`, `uint_a16w2`, `uint_a16w3`, `uint_a16w4`,
-                  `uint_a16w5`, `uint_a16w6`, `uint_a16w7`
         modules_to_not_convert (`List[str]`, *optional*, default to `None`):
             The list of modules to not quantize, useful for quantizing models that explicitly require to have some
             modules left in their original precision.
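The docstring hunk above ends at the `modules_to_not_convert` argument; a minimal sketch of passing it (the quant type and module name are illustrative, not taken from this diff):

```python
# Sketch: quantize with int8 weight-only while leaving a named module in its
# original precision. "proj_out" is a hypothetical module name.
from diffusers import TorchAoConfig

quantization_config = TorchAoConfig(
    quant_type="int8wo",
    modules_to_not_convert=["proj_out"],
)
```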
@@ -584,7 +577,6 @@ def generate_fpx_quantization_types(bits: int):
             **generate_float8dq_types(torch.float8_e4m3fn),
             # float8 weight + float8 activation (static)
             "float8_static_activation_float8_weight": float8_static_activation_float8_weight,
-            "float8sq": float8_static_activation_float8_weight,
             # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
             # fpx weight + bfloat16/float16 activation
             **generate_fpx_quantization_types(3),
@@ -606,42 +598,13 @@ def generate_fpx_quantization_types(bits: int):
             # "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
         }
 
-        SHORTHAND_QUANTIZATION_TYPES = {
-            "int_a16w4": int4_weight_only,
-            "int_a8w4": int8_dynamic_activation_int4_weight,
-            "int_a16w8": int8_weight_only,
-            "int_a8w8": int8_dynamic_activation_int8_weight,
-            "uint_a16w1": partial(uintx_weight_only, dtype=torch.uint1),
-            "uint_a16w2": partial(uintx_weight_only, dtype=torch.uint2),
-            "uint_a16w3": partial(uintx_weight_only, dtype=torch.uint3),
-            "uint_a16w4": partial(uintx_weight_only, dtype=torch.uint4),
-            "uint_a16w5": partial(uintx_weight_only, dtype=torch.uint5),
-            "uint_a16w6": partial(uintx_weight_only, dtype=torch.uint6),
-            "uint_a16w7": partial(uintx_weight_only, dtype=torch.uint7),
-            # "uint_a16w8": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
-        }
-
-        SHORTHAND_FLOAT_QUANTIZATION_TYPES = {
-            "float_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
-            "float_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
-            "float_a8w8": float8_dynamic_activation_float8_weight,
-            "float_a16w3": partial(fpx_weight_only, ebits=2, mbits=0),
-            "float_a16w4": partial(fpx_weight_only, ebits=2, mbits=1),
-            "float_a16w5": partial(fpx_weight_only, ebits=3, mbits=1),
-            "float_a16w6": partial(fpx_weight_only, ebits=3, mbits=2),
-            "float_a16w7": partial(fpx_weight_only, ebits=4, mbits=2),
-            "float_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
-        }
-
         QUANTIZATION_TYPES = {}
         QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
         QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
         QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
-        QUANTIZATION_TYPES.update(SHORTHAND_QUANTIZATION_TYPES)
 
         if cls._is_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
-            QUANTIZATION_TYPES.update(SHORTHAND_FLOAT_QUANTIZATION_TYPES)
 
         return QUANTIZATION_TYPES
     else:
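With the `float8sq` alias and the `SHORTHAND_*` tables removed, the floating point X-bit shorthands still follow the `fpX_eAwB` pattern under the `X == A + B + 1` constraint documented above. The helper below is not the library's implementation, only a sketch of enumerating names that satisfy that constraint:

```python
# Sketch (not diffusers' code): list fpX_eAwB shorthand names that satisfy
# X == A + B + 1, where A is exponent bits, B is mantissa bits, and the +1
# accounts for the sign bit.
def fpx_shorthands(total_bits: int) -> list:
    names = []
    for ebits in range(1, total_bits):
        mbits = total_bits - ebits - 1
        names.append(f"fp{total_bits}_e{ebits}w{mbits}")
    return names

print(fpx_shorthands(6))
# ['fp6_e1w4', 'fp6_e2w3', 'fp6_e3w2', 'fp6_e4w1', 'fp6_e5w0']
```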

tests/quantization/torchao/test_torchao.py

Lines changed: 6 additions & 1 deletion
@@ -33,6 +33,7 @@
     enable_full_determinism,
     is_torch_available,
     is_torchao_available,
+    nightly,
     require_torch,
     require_torch_gpu,
     require_torchao_version_greater,
@@ -280,7 +281,8 @@ def test_int4wo_quant_bfloat16_conversion(self):
 
     def test_offload(self):
         """
-        Test if the quantized model int4 weight-only is working properly with cpu/disk offload.
+        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
+        that the device map is correctly set (in the `hf_device_map` attribute of the model).
         """
 
         device_map_offload = {
@@ -306,6 +308,8 @@ def test_offload(self):
             offload_folder=offload_folder,
         )
 
+        self.assertTrue(quantized_model.hf_device_map == device_map_offload)
+
         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
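The strengthened `test_offload` now also asserts that the device map used for cpu/disk offload is reflected on the loaded model. A minimal sketch of the same check (the checkpoint, quant type, and single-entry device map below are illustrative placeholders, not the test's fixtures, which use int4 weight-only and a per-module map):

```python
# Sketch: load a torchao-quantized model with an explicit device map and check
# that the resolved map is exposed on `hf_device_map`. Checkpoint, quant type,
# and the trivial all-CPU map are assumptions for illustration only.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

device_map_offload = {"": "cpu"}  # hypothetical: place the whole model on CPU

quantized_model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8wo"),
    device_map=device_map_offload,
    torch_dtype=torch.bfloat16,
)
assert quantized_model.hf_device_map == device_map_offload
```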

@@ -539,6 +543,7 @@ class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
 @require_torch_gpu
 @require_torchao_version_greater("0.6.0")
 @slow
+@nightly
 class SlowTorchAoTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
