
Commit 97b090d

[bc-breaking] Generalize FakeQuantizeConfig beyond intx (#2628)
**Summary:** The existing `FakeQuantizeConfig` performs only intx quantization, but we plan to extend QAT to other dtypes such as fp8 and nvfp4 in the near future. This is the necessary refactor before that. Specifically:

```
# New abstract class FakeQuantizeConfigBase
# Rename FakeQuantizeConfig -> IntxFakeQuantizeConfig
```

In the future, we will have other types of `FakeQuantizeConfigBase` for float dtypes that users can pass in instead of the existing intx one.

**BC-breaking notes:** For BC, we keep around the old names to reference the new ones. However, this commit is still BC-breaking in the sense that a few APIs now accept the abstract `FakeQuantizeConfigBase` instead. For the most part, this abstract class will be hidden from the user.

Before:

```
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
```

After:

```
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
```

**Test Plan:** `python test/quantization/test_qat.py`
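
For orientation, here is a minimal sketch of the hierarchy the summary describes, assuming a dataclass-style config whose fields mirror the constructor arguments in the Before/After snippets; the real definitions live in `torchao/quantization/qat/fake_quantize_config.py` and may differ in detail.

```python
# Hypothetical sketch of the hierarchy described in the commit message;
# field names mirror the arguments used in the Before/After snippets.
import abc
from dataclasses import dataclass
from typing import Optional

import torch


class FakeQuantizeConfigBase(abc.ABC):
    """Abstract base for all fake-quantization configs (intx today, fp8/nvfp4 later)."""


@dataclass
class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
    """Config for fake-quantizing intx dtypes (sketch, not the real definition)."""

    dtype: torch.dtype
    granularity: Optional[str] = None   # e.g. "per_token"
    is_symmetric: bool = True
    group_size: Optional[int] = None


# Old name kept as an alias for backward compatibility, as the commit notes.
FakeQuantizeConfig = IntxFakeQuantizeConfig
```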
1 parent 8d4a5d8 commit 97b090d

File tree: 13 files changed, +432 −377 lines

README.md (3 additions, 3 deletions)

@@ -180,9 +180,9 @@ Post-training quantization can result in a fast and compact model, but may also
 
 ```python
 from torchao.quantization import quantize_
-from torchao.quantization.qat import FakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
-activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+from torchao.quantization.qat import IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
+activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
 qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
 quantize_(my_model, qat_config)
 ```
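
For reference, a self-contained version of the updated README snippet, assuming a toy two-layer model in place of `my_model`:

```python
# Self-contained version of the updated README snippet (toy model assumed).
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    IntxFakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

my_model = torch.nn.Sequential(
    torch.nn.Linear(256, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 256),
)

# int8 per-token asymmetric activations, int4 per-group symmetric weights
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(my_model, qat_config)  # swaps nn.Linear for FakeQuantizedLinear in place
```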

docs/source/api_ref_qat.rst (1 addition, 1 deletion)

@@ -24,7 +24,7 @@ Custom QAT APIs
    :toctree: generated/
    :nosignatures:
 
-   FakeQuantizeConfig
+   IntxFakeQuantizeConfig
    FakeQuantizedLinear
    FakeQuantizedEmbedding
    FakeQuantizer

test/prototype/test_parq.py (2 additions, 2 deletions)

@@ -30,8 +30,8 @@
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.qat import (
-    FakeQuantizeConfig,
     FromIntXQuantizationAwareTrainingConfig,
+    IntxFakeQuantizeConfig,
     IntXQuantizationAwareTrainingConfig,
 )
 from torchao.quantization.quant_api import (
@@ -393,7 +393,7 @@ def test_int8_dynamic_activation_intx_e2e(
         optimizer.step()
 
         # apply torchao quantized activations on top
-        activation_config = FakeQuantizeConfig(
+        activation_config = IntxFakeQuantizeConfig(
             torch.int8,
             granularity="per_token",
             mapping_type=config.act_mapping_type,

test/quantization/test_qat.py (90 additions, 83 deletions)

Large diffs are not rendered by default.

torchao/experimental/tests/test_embedding_xbit_quantizer.py (2 additions, 2 deletions)

@@ -21,9 +21,9 @@
 )
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FakeQuantizeConfig,
     FromIntXQuantizationAwareTrainingConfig,
     Int4WeightOnlyEmbeddingQATQuantizer,
+    IntxFakeQuantizeConfig,
     IntXQuantizationAwareTrainingConfig,
 )
 from torchao.quantization.quant_api import (
@@ -282,7 +282,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
         )
 
         embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding)
-        weight_config = FakeQuantizeConfig(
+        weight_config = IntxFakeQuantizeConfig(
             weight_dtype,
             group_size=group_size,
             is_symmetric=is_symmetric,
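
The test above exercises a common pattern: applying the QAT config only to embedding layers through `quantize_`'s filter function. A hedged sketch of that usage, assuming `IntXQuantizationAwareTrainingConfig` accepts `weight_config` as a keyword and tolerates a missing activation config:

```python
# Hedged sketch: fake-quantize only embedding weights, following the test's filter.
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    IntxFakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(
    torch.nn.Embedding(1024, 128),
    torch.nn.Linear(128, 128),
)

embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)

# Assumes activation_config may be left unset for weight-only fake quantization.
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
    filter_fn=embedding_filter,
)
```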

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py (3 additions, 3 deletions)

@@ -16,9 +16,9 @@
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FakeQuantizeConfig,
     FromIntXQuantizationAwareTrainingConfig,
     Int8DynActInt4WeightQATQuantizer,
+    IntxFakeQuantizeConfig,
     IntXQuantizationAwareTrainingConfig,
 )
 from torchao.quantization.quant_api import (
@@ -538,12 +538,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
         model = model.to(model_dtype)
         activations = activations.to(model_dtype)
 
-        activation_config = FakeQuantizeConfig(
+        activation_config = IntxFakeQuantizeConfig(
            torch.int8,
            "per_token",
            is_symmetric=is_act_symmetric,
         )
-        weight_config = FakeQuantizeConfig(
+        weight_config = IntxFakeQuantizeConfig(
            weight_dtype,
            group_size=group_size,
            is_symmetric=is_symmetric,

torchao/quantization/qat/README.md (6 additions, 6 deletions)

@@ -71,7 +71,7 @@ def train_loop(m: torch.nn.Module):
 
 The recommended way to run QAT in torchao is through the `quantize_` API:
 1. **Prepare:** specify how weights and/or activations are to be quantized through
-[`FakeQuantizeConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.FakeQuantizeConfig.html#torchao.quantization.qat.FakeQuantizeConfig) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntXQuantizationAwareTrainingConfig.html#torchao.quantization.qat.IntXQuantizationAwareTrainingConfig)
+[`IntxFakeQuantizeConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntxFakeQuantizeConfig.html#torchao.quantization.qat.IntxFakeQuantizeConfig) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntXQuantizationAwareTrainingConfig.html#torchao.quantization.qat.IntXQuantizationAwareTrainingConfig)
 2. **Convert:** quantize the model using the standard post-training quantization (PTQ)
 functions such as [`Int8DynamicActivationInt4WeightConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int8DynamicActivationInt4WeightConfig.html#torchao.quantization.Int8DynamicActivationInt4WeightConfig)
 
@@ -84,16 +84,16 @@ from torchao.quantization import (
     Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.qat import (
-    FakeQuantizeConfig,
+    IntxFakeQuantizeConfig,
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
 )
 model = get_model()
 
 # prepare: insert fake quantization ops
 # swaps `torch.nn.Linear` with `FakeQuantizedLinear`
-activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
 quantize_(
     model,
     IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
@@ -116,8 +116,8 @@ the following with a filter function during the prepare step:
 
 ```
 # first apply linear transformation to the model as above
-activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
 quantize_(
     model,
     IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
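
Putting the two steps from this README together, a hedged end-to-end sketch of the prepare, train, and convert flow; the convert path (`FromIntXQuantizationAwareTrainingConfig` followed by a PTQ config) and the `group_size` argument are assumptions based on the imports shown above:

```python
# Hedged end-to-end sketch of the prepare -> train -> convert flow.
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import (
    FromIntXQuantizationAwareTrainingConfig,
    IntxFakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256))

# 1. Prepare: insert fake-quantization ops into the linear layers
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, IntXQuantizationAwareTrainingConfig(activation_config, weight_config))

# ... run the usual training loop on `model` here ...

# 2. Convert: remove fake-quant wrappers, then apply the matching PTQ config
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```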

torchao/quantization/qat/__init__.py (10 additions, 3 deletions)

@@ -1,6 +1,5 @@
 from .api import (
     ComposableQATQuantizer,
-    FakeQuantizeConfig,
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
@@ -11,6 +10,11 @@
     FakeQuantizedEmbedding,
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
+from .fake_quantize_config import (
+    FakeQuantizeConfig,
+    FakeQuantizeConfigBase,
+    IntxFakeQuantizeConfig,
+)
 from .fake_quantizer import FakeQuantizer
 from .linear import (
     FakeQuantizedLinear,
@@ -21,7 +25,7 @@
 
 __all__ = [
     "ComposableQATQuantizer",
-    "FakeQuantizeConfig",
+    "FakeQuantizeConfigBase",
     "FakeQuantizedLinear",
     "FakeQuantizedEmbedding",
     "FakeQuantizer",
@@ -30,8 +34,11 @@
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "IntxFakeQuantizeConfig",
     "IntXQuantizationAwareTrainingConfig",
     "initialize_fake_quantizers",
-    "intx_quantization_aware_training",
+    # for BC
+    "FakeQuantizeConfig",
     "from_intx_quantization_aware_training",
+    "intx_quantization_aware_training",
 ]
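
Given these re-exports, a small check of the backward-compatibility path, assuming the old name resolves to the new class as the commit message states:

```python
# Old imports keep working; the commit keeps FakeQuantizeConfig around for BC.
import torch
from torchao.quantization.qat import (
    FakeQuantizeConfig,        # legacy name, still exported
    FakeQuantizeConfigBase,    # new abstract base
    IntxFakeQuantizeConfig,    # new preferred name
)

legacy = FakeQuantizeConfig(torch.int4, group_size=32)
new = IntxFakeQuantizeConfig(torch.int4, group_size=32)

# Both configs are instances of the new abstract base, so APIs that now accept
# FakeQuantizeConfigBase (the BC-breaking part) handle either one.
assert isinstance(legacy, FakeQuantizeConfigBase)
assert isinstance(new, FakeQuantizeConfigBase)
```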
