
Commit 15e501a

New multi-step QAT API (#2629)
* [bc-breaking] Generalize FakeQuantizeConfig beyond intx

**Summary:** The existing `FakeQuantizeConfig` performs only intx quantization, but we plan to extend QAT to other dtypes such as fp8 and nvfp4 in the near future. This is the necessary refactor before that. Specifically:

```
# New abstract class FakeQuantizeConfigBase
# Rename FakeQuantizeConfig -> IntxFakeQuantizeConfig
```

In the future, we will have other types of `FakeQuantizeConfigBase` for float dtypes that users can pass in instead of the existing Intx one.

**BC-breaking notes:** For BC, we keep around the old names to reference the new ones. However, this commit is still BC-breaking in the sense that a few APIs now accept the abstract `FakeQuantizeConfigBase` instead. For the most part, this abstract class will be hidden from the user.

Before:
```python
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
```

After:
```python
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
```

**Test Plan:**
```
python test/quantization/test_qat.py
```

* New multi-step QAT API

**Summary:** This commit adds a new multi-step QAT API with the main goal of simplifying the existing UX. The new API uses the same `QATConfig` for both the prepare and convert steps, and automatically infers the fake quantization configs from a PTQ base config provided by the user:

```python
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(m, QATConfig(base_config, step="prepare"))

# train (not shown)

# convert
quantize_(m, QATConfig(base_config, step="convert"))
```

The main improvements include:
- A single config for both the prepare and convert steps
- A single `quantize_` call for convert (instead of 2)
- No chance of incompatible prepare vs. convert configs
- Much less boilerplate code for the most common use case
- Simpler config names

For less common use cases such as experimentation, users can still specify arbitrary fake quantization configs for activations and/or weights as before. This is still important since there may not always be a corresponding PTQ base config. For example:

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

# train and convert same as above (not shown)
```

**BC-breaking notes:** This change by itself is technically not BC-breaking since we keep the old path around, but it will become so when we deprecate and remove the old path in the future.

Before:
```python
# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)

# train (not shown)

# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

After: (see above)

**Test Plan:**
```
python test/quantization/test_qat.py
```
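The "automatically infers the fake quantization configs" step above can be pictured with a small sketch. This is illustrative only and not the torchao source: the helper name `_to_fake_quantize_configs` and its mapping logic are assumptions, while `QATConfig`, `IntxFakeQuantizeConfig`, and `Int8DynamicActivationInt4WeightConfig` are the real names used in this commit:

```python
# Illustrative sketch only (not the torchao implementation): how a prepare-step
# QATConfig could derive activation/weight fake quantization configs from a PTQ base config.
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import IntxFakeQuantizeConfig


def _to_fake_quantize_configs(base_config):
    # Hypothetical helper: map a PTQ config to (activation, weight) fake-quantize configs.
    if isinstance(base_config, Int8DynamicActivationInt4WeightConfig):
        act = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
        weight = IntxFakeQuantizeConfig(torch.int4, group_size=base_config.group_size)
        return act, weight
    raise ValueError(f"No QAT fake quantization mapping for {type(base_config)}")
```

The convert step can then drop the fake-quantized modules and apply the same `base_config` directly through `quantize_`, which is why one config suffices for both steps.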
1 parent 7dbc816 commit 15e501a


9 files changed: +483 additions, -142 deletions

README.md

Lines changed: 11 additions & 6 deletions
@@ -179,12 +179,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup*
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):

 ```python
-from torchao.quantization import quantize_
-from torchao.quantization.qat import IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
-activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
-qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-quantize_(my_model, qat_config)
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import QATConfig
+
+# prepare
+base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+quantize_(my_model, QATConfig(base_config, step="prepare"))
+
+# train model (not shown)
+
+# convert
+quantize_(my_model, QATConfig(base_config, step="convert"))
 ```

 Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).

docs/source/api_ref_qat.rst

Lines changed: 7 additions & 4 deletions
@@ -6,7 +6,7 @@ torchao.quantization.qat

 .. currentmodule:: torchao.quantization.qat

-QAT Configs for quantize_
+Main Config for quantize_
 ---------------------------------------
 For a full example of how to use QAT with our main `quantize_` API,
 please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended>`__.
@@ -15,29 +15,32 @@ please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao
     :toctree: generated/
     :nosignatures:

-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
+    QATConfig
+    QATStep

 Custom QAT APIs
 ---------------
 .. autosummary::
     :toctree: generated/
     :nosignatures:

+    FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizer
     linear.enable_linear_fake_quant
     linear.disable_linear_fake_quant

-Legacy QAT Quantizers
+Legacy QAT APIs
 ---------------------

 .. autosummary::
     :toctree: generated/
     :nosignatures:

+    IntXQuantizationAwareTrainingConfig
+    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer
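`QATConfig` and the new `QATStep` enum are now the only entries under the main config section. As the new `test_qat_config_init` cases further down confirm, the `step` argument accepts either the enum or a case-insensitive string and defaults to the prepare step; a quick illustration of those test-confirmed behaviors:

```python
# Equivalent ways of selecting the QAT step, mirroring the test_qat_config_init
# assertions in the test diff below.
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig, QATStep

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

QATConfig(base_config, step=QATStep.PREPARE)                      # enum form
assert QATConfig(base_config).step == "prepare"                   # "prepare" is the default
assert QATConfig(base_config, step="CONVERT").step == "convert"   # strings are case-insensitive
```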

docs/source/finetuning.rst

Lines changed: 11 additions & 24 deletions
@@ -205,21 +205,14 @@ because we are not actually casting the fake quantized values.

 .. code:: py

-    from torchao.quantization import (
-        quantize_,
-    )
-    from torchao.quantization.qat import (
-        FakeQuantizeConfig,
-        IntXQuantizationAwareTrainingConfig,
-    )
+    from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+    from torchao.quantization.qat import QATConfig
+
     model = get_model()

-    # prepare: insert fake quantization ops
-    # swaps `torch.nn.Linear` with `FakeQuantizedLinear`
-    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
-    qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
-    quantize_(model, qat_config)
+    # prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
+    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+    quantize_(model, QATConfig(base_config, step="prepare"))

     # fine-tune
     train_loop(model)
@@ -232,18 +225,12 @@ The next step is to actually quantize the model:

 .. code:: py

-    from torchao.quantization import (
-        Int8DynamicActivationInt4WeightConfig,
-    )
-    from torchao.quantization.qat import (
-        FromIntXQuantizationAwareTrainingConfig,
-    )
+    from torchao.quantization import Int8DynamicActivationInt4WeightConfig

-    # convert: transform fake quantization ops into actual quantized ops
-    # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
-    # quantized activation and weight tensor subclasses
-    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-    quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+    # convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
+    quantize_(model, QATConfig(base_config, step="convert"))
+
+    # inference or generate

 Now our model is ready for serving, and will typically have higher quantized
 accuracy than if we did not apply the prepare step (fake quantization) during
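The updated doc still calls `train_loop(model)` without defining it. A minimal sketch of such a loop is shown below; it is an assumption for illustration (including the hypothetical `get_data_loader()` helper), not part of the torchao docs. The fake-quantized modules inserted by the prepare step train like any other `nn.Module`, with quantization error simulated in the forward pass:

```python
# Minimal illustrative fine-tuning loop (assumed, not from the torchao docs).
import torch

def train_loop(model, num_steps=100, lr=1e-5):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for step, (inputs, labels) in enumerate(get_data_loader()):  # hypothetical data helper
        if step >= num_steps:
            break
        optimizer.zero_grad()
        logits = model(inputs)
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()
```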

test/quantization/test_qat.py

Lines changed: 136 additions & 48 deletions
@@ -34,6 +34,8 @@
     ComposableQATQuantizer,
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
+    QATConfig,
+    QATStep,
     initialize_fake_quantizers,
 )
 from torchao.quantization.qat.embedding import (
@@ -59,7 +61,7 @@
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -1261,11 +1263,67 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_standalone(self):
+    def test_qat_config_init(self):
+        """
+        Test that the correct errors are thrown if `QATConfig` is not instantiated properly.
+        """
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
+
+        # OK
+        QATConfig(base_config, step="prepare")
+        QATConfig(base_config, step="convert")
+        QATConfig(base_config, step=QATStep.PREPARE)
+        QATConfig(base_config, step=QATStep.CONVERT)
+        QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
+        QATConfig(weight_config=fq_config, step="prepare")
+
+        # OK: good step values
+        self.assertEqual(QATConfig(base_config).step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="Prepare").step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="CONVERT").step, "convert")
+
+        # Bad step
+        with self.assertRaisesRegex(ValueError, "`step` must be one of"):
+            QATConfig(base_config, step="blah")
+
+        # Step was not a keyword arg
+        with self.assertRaisesRegex(
+            TypeError, "4 positional arguments but 5 were given"
+        ):
+            QATConfig(base_config, None, None, "prepare")
+
+        # No configs are provided
+        with self.assertRaisesRegex(
+            ValueError, "One of `base_config` or `weight_config` must be specified"
+        ):
+            QATConfig(step="prepare")
+
+        # Clashing configs are provided
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, weight_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, activation_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(
+            ValueError, "must be specified in the convert step"
+        ):
+            QATConfig(weight_config=fq_config, step="convert")
+
+        # FakeQuantizeConfigBase was specified as base_config
+        with self.assertRaisesRegex(
+            ValueError,
+            "was passed as `base_config`. Did you mean to do the following instead?",
+        ):
+            QATConfig(fq_config, step="prepare")
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api_prepare(self):
         """
         Test that the following:

-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, QATConfig(...))

         can produce the same results as `ComposableQATQuantizer`.
         """
@@ -1290,20 +1348,15 @@ def test_quantize_api_standalone(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)

         # quantize_ API
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config1 = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        qat_config2 = QATConfig(weight_config=weight_config)
+        quantize_(m, qat_config1)
         quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
-            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+            m, qat_config2, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding)
         )

         # Compare model values
@@ -1322,37 +1375,29 @@ def test_quantize_api_errors(self):
         Test that we throw exceptions with helpful error messages if `quantize_`
         runs into unexpected configurations.
         """
-        my_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        qat_config = QATConfig(activation_config=fq_config, weight_config=fq_config)
         m = M3()

         # Embedding currently only supports weight-only quantization
         with self.assertRaisesRegex(
             ValueError, "Activation fake quantization is not supported for embedding"
         ):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.Embedding),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.Embedding))

         # Only linear and embedding are supported currently
         with self.assertRaisesRegex(ValueError, "does not have QAT support"):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.ReLU),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU))

     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_convert_path(self):
+    def test_quantize_api_e2e(self):
         """
         Test that the following:

-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, int8_dynamic_activation_int4_weight())
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))

         can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
         """
@@ -1370,16 +1415,8 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)

         # quantize_ prepare
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-        )
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))

         # Compare prepared values
         torch.manual_seed(self.SEED)
@@ -1393,8 +1430,7 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.convert(baseline_model)

         # quantize_ convert
-        quantize_(m, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, QATConfig(base_config, step="convert"))

         # Compare converted values
         torch.manual_seed(self.SEED)
@@ -1447,14 +1483,12 @@ def test_qat_linear_bias(self):
         Test that QAT supports linear bias.
         """
         m = ModelWithLinearBias()
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        quantize_(m, qat_config)
         example_inputs = m.example_inputs()
         m(*example_inputs)

@@ -1653,7 +1687,7 @@ def test_qat_range_learning(self):
         )
         m = M()
         example_inputs = m.example_inputs()
-        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+        quantize_(m, QATConfig(weight_config=config))

         # Not initialized, should fail
         for t in m._get_all_weight_qparams():
@@ -1756,6 +1790,60 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))

+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_legacy_quantize_api_e2e(self):
+        """
+        Test that the following two APIs are numerically equivalent:
+
+        New API:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+
+        Old API:
+            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+            quantize_(model, Int8DynamicActivationInt4WeightConfig())
+        """
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline prepare
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
+        quantize_(baseline_model, old_qat_config)
+
+        # QATConfig prepare
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+        # Baseline convert
+        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(baseline_model, base_config)
+
+        # quantize_ convert
+        quantize_(m, QATConfig(base_config, step="convert"))
+
+        # Compare converted values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+

 if __name__ == "__main__":
     unittest.main()
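For modules that only support weight-only fake quantization, such as `torch.nn.Embedding` (per the error message exercised above), the tests apply a second, weight-only `QATConfig` through `quantize_`'s `filter_fn`. A standalone sketch of that pattern follows; the toy model is an assumption, while the configs and `filter_fn` usage mirror `test_quantize_api_prepare`:

```python
# Sketch of the linear + embedding prepare pattern exercised in the tests above.
# QATConfig and IntxFakeQuantizeConfig are from this commit; the toy model is illustrative.
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

model = torch.nn.Sequential(
    torch.nn.Embedding(1000, 64),
    torch.nn.Linear(64, 64),
)

act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)

# Linear layers (the default targets of quantize_): fake-quantize activations and weights
quantize_(
    model,
    QATConfig(activation_config=act_config, weight_config=weight_config, step="prepare"),
)

# Embedding layers: weight-only fake quantization, selected explicitly via filter_fn
quantize_(
    model,
    QATConfig(weight_config=weight_config, step="prepare"),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```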
