
Commit 6cfa477

Generalize FakeQuantizer beyond intx (pytorch#2714)
**Summary:** Similar to pytorch#2628, but for `FakeQuantizer`. It is cleaner to isolate the logic of each quantizer in separate classes, e.g. intx vs nvfp4 vs fp8.

Naming change:

```
FakeQuantizer -> IntxFakeQuantizer
```

**BC-breaking notes:** This is technically not BC-breaking yet since we are just deprecating the old APIs while keeping them around. It will be when we do remove the old APIs in the future according to pytorch#2630.

Before:

```
config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
FakeQuantizer(config)
```

After:

```
config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
IntxFakeQuantizer(config)  # or FakeQuantizerBase.from_config(config)
```

**Test Plan:**

```
python test/quantization/test_qat.py
```

[ghstack-poisoned]
1 parent c086ade commit 6cfa477
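
To make the migration concrete, here is a minimal sketch of the old and new spellings (assuming a torchao build that includes this commit; all names below are exported from `torchao.quantization.qat` per the `__init__.py` diff further down):

```python
import torch

from torchao.quantization.qat import (
    FakeQuantizerBase,
    IntxFakeQuantizeConfig,
    IntxFakeQuantizer,
)

config = IntxFakeQuantizeConfig(torch.int8, "per_channel")

# Old spelling (kept for BC, now logs a deprecation warning):
#   from torchao.quantization.qat import FakeQuantizer
#   fake_quantizer = FakeQuantizer(config)

# New spellings: construct the intx quantizer directly...
fake_quantizer = IntxFakeQuantizer(config)

# ...or let the base class dispatch on the config type, which is what
# FakeQuantizedLinear / FakeQuantizedEmbedding now do internally.
fake_quantizer = FakeQuantizerBase.from_config(config)
assert isinstance(fake_quantizer, IntxFakeQuantizer)
```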

8 files changed: +67 -31 lines changed

docs/source/api_ref_qat.rst

Lines changed: 2 additions & 1 deletion
@@ -28,7 +28,8 @@ Custom QAT APIs
     IntxFakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
-    FakeQuantizer
+    FakeQuantizerBase
+    IntxFakeQuantizer
     linear.enable_linear_fake_quant
     linear.disable_linear_fake_quant

test/quantization/test_qat.py

Lines changed: 11 additions & 9 deletions
@@ -46,7 +46,7 @@
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
-    FakeQuantizer,
+    IntxFakeQuantizer,
     _Float8RowwiseActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
@@ -1466,10 +1466,10 @@ def test_fake_quantize_config_torch_intx(self):
     )
     def test_fake_quantizer_repr(self):
         """
-        Test that `repr(FakeQuantizer(config))` exposes useful config details.
+        Test that `repr(IntxFakeQuantizer(config))` exposes useful config details.
         """
         config = IntxFakeQuantizeConfig(torch.int4, group_size=128)
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         fake_quantizer_repr = repr(fake_quantizer)
         self.assertTrue("dtype=torch.int4" in fake_quantizer_repr)
         self.assertTrue("group_size=128" in fake_quantizer_repr)
@@ -1500,15 +1500,15 @@ def test_qat_linear_bias(self):
     def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
-        1. FakeQuantizer with asymmetric per_token config
+        1. IntxFakeQuantizer with asymmetric per_token config
         2. torchao.quantization.utils.per_token_dynamic_quant
         """
         from torchao.quantization.utils import per_token_dynamic_quant
 
         torch.manual_seed(self.SEED)
         x = torch.randn(1, 235, 2048).to(dtype)
         config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         fake_quantizer_out = fake_quantizer(x)
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
@@ -1580,7 +1580,7 @@ def test_fake_quantize_config_eps(self):
             is_symmetric=False,
             eps=eps,
         )
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         actual_out = fake_quantizer(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
 
@@ -1638,7 +1638,7 @@ def test_qat_8da4w_eps(self):
     )
     def test_fake_quantizer_range_learning(self):
         """
-        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
         """
         config = IntxFakeQuantizeConfig(
             torch.int8,
@@ -1648,7 +1648,7 @@ def test_fake_quantizer_range_learning(self):
             scale_precision=torch.float32,
             zero_point_precision=torch.float32,
         )
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         example_inputs = (torch.randn(2, 3),)
 
         # Not initialized, should fail
@@ -1770,7 +1770,7 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertIsInstance(
             linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
         )
-        self.assertIsInstance(linear.weight_fake_quantizer, FakeQuantizer)
+        self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
         prev_weight = copy.deepcopy(m.linear1.weight)
 
         # Simulate training
@@ -1854,6 +1854,7 @@ def test_qat_api_deprecation(self):
         """
         from torchao.quantization.qat import (
             FakeQuantizeConfig,
+            FakeQuantizer,
             from_intx_quantization_aware_training,
             intx_quantization_aware_training,
         )
@@ -1868,6 +1869,7 @@ def test_qat_api_deprecation(self):
             intx_quantization_aware_training: (),
             from_intx_quantization_aware_training: (),
             FakeQuantizeConfig: (torch.int8, "per_channel"),
+            FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),),
         }
 
         with warnings.catch_warnings(record=True) as _warnings:

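The updated `test_qat_api_deprecation` entries above exercise the BC shim. Roughly, the behavior they check looks like the following simplified sketch (not a copy of the test):

```python
import warnings

import torch

from torchao.quantization.qat import (
    FakeQuantizer,
    IntxFakeQuantizeConfig,
    IntxFakeQuantizer,
)

config = IntxFakeQuantizeConfig(torch.int8, "per_channel")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fake_quantizer = FakeQuantizer(config)  # old name still constructs fine

# The shim is just a subclass of IntxFakeQuantizer, so existing isinstance
# checks keep passing, while a deprecation warning is recorded at construction.
assert isinstance(fake_quantizer, IntxFakeQuantizer)
assert len(caught) >= 1  # the warning emitted by _log_deprecation_warning
```
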
torchao/quantization/prototype/qat/fake_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from torchao.quantization.qat.fake_quantizer import (
-    FakeQuantizer,
+    IntxFakeQuantizer as FakeQuantizer,
 )
 
 __all__ = [

torchao/quantization/qat/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -17,7 +17,11 @@
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
 )
-from .fake_quantizer import FakeQuantizer
+from .fake_quantizer import (
+    FakeQuantizer,
+    FakeQuantizerBase,
+    IntxFakeQuantizer,
+)
 from .linear import (
     FakeQuantizedLinear,
     Float8ActInt4WeightQATQuantizer,
@@ -29,8 +33,9 @@
     "QATConfig",
     "QATStep",
     "FakeQuantizeConfigBase",
+    "FakeQuantizerBase",
     "IntxFakeQuantizeConfig",
-    "FakeQuantizer",
+    "IntxFakeQuantizer",
     "FakeQuantizedLinear",
     "FakeQuantizedEmbedding",
     # Prototype
@@ -42,6 +47,7 @@
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
     # for BC
+    "FakeQuantizer",
     "FakeQuantizeConfig",
     "from_intx_quantization_aware_training",
     "FromIntXQuantizationAwareTrainingConfig",

torchao/quantization/qat/api.py

Lines changed: 3 additions & 3 deletions
@@ -382,14 +382,14 @@ def initialize_fake_quantizers(
 ) -> None:
     """
     (Prototype) Initialize the scales and zero points on all
-    :class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
+    :class:`~torchao.quantization.qat.fake_quantizer.IntxFakeQuantizerBase`
     in the model based on the provided example inputs.
     """
     # avoid circular dependencies
-    from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+    from torchao.quantization.qat.fake_quantizer import IntxFakeQuantizer
 
     def _set_initialized(m: torch.nn.Module):
-        if isinstance(m, FakeQuantizer):
+        if isinstance(m, IntxFakeQuantizer):
             m._initialized = True
 
     model.apply(_set_initialized)

torchao/quantization/qat/embedding.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
 )
-from .fake_quantizer import FakeQuantizer
+from .fake_quantizer import FakeQuantizerBase
 from .utils import (
     _get_qmin_qmax,
 )
@@ -66,7 +66,7 @@ def __init__(
             **kwargs,
         )
         if weight_config is not None:
-            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+            self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
         else:
             self.weight_fake_quantizer = None

torchao/quantization/qat/fake_quantizer.py

Lines changed: 35 additions & 10 deletions
@@ -34,15 +34,37 @@
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _Float8RowwiseFakeQuantize,
+    _log_deprecation_warning,
 )
 
 
-class FakeQuantizer(torch.nn.Module):
+class FakeQuantizerBase(torch.nn.Module):
     """
     Generic module for applying fake quantization to a tensor, as specified in the config.
     """
 
-    def __init__(self, config: FakeQuantizeConfigBase):
+    config: FakeQuantizeConfigBase
+
+    def __repr__(self) -> str:
+        """
+        Return a human readable representation of this `FakeQuantizer` with config details.
+        """
+        return "FakeQuantizer(%s)" % self.config
+
+    @staticmethod
+    def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
+        if isinstance(config, IntxFakeQuantizeConfig):
+            return IntxFakeQuantizer(config)
+        else:
+            raise ValueError(f"Unknown config type: {config}")
+
+
+class IntxFakeQuantizer(FakeQuantizerBase):
+    """
+    Generic module for applying integer fake quantization to a tensor, as specified in the config.
+    """
+
+    def __init__(self, config: IntxFakeQuantizeConfig):
         super().__init__()
         self.config = config
         self.enabled = True
@@ -62,9 +84,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if not self.enabled:
             return x
 
-        if not isinstance(self.config, IntxFakeQuantizeConfig):
-            raise ValueError("Only IntxFakeQuantizeConfig is supported currently")
-
         if (
             self.config.range_learning
             and not self._initialized
@@ -186,13 +205,19 @@ def _maybe_update_qparams_for_range_learning(self) -> None:
         self.scale = torch.nn.Parameter(scale, requires_grad=True)
         self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)
 
-    def __repr__(self) -> str:
-        """
-        Return a human readable representation of this `FakeQuantizer` with config details.
-        """
-        return "FakeQuantizer(%s)" % self.config
+
+# For BC
+class FakeQuantizer(IntxFakeQuantizer):
+    """
+    (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizer` instead.
+    """
+
+    def __init__(self, config: FakeQuantizeConfigBase):
+        super().__init__(config)
+        _log_deprecation_warning(self)
 
 
+# TODO: make this a FakeQuantizerBase
 class _Float8RowwiseActivationFakeQuantizer(torch.nn.Module):
     """
     Simple fake quantizer for float8 rowwise fake quantization, intended for activations only.

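The `from_config` dispatcher introduced above is what the "beyond intx" generalization hinges on: each recognized config type maps to its own quantizer subclass, anything else is rejected, and future nvfp4/fp8 quantizers would be added as new `FakeQuantizerBase` subclasses plus one more branch. A small sketch of the dispatch behavior as of this commit (the bare `object()` below simply stands in for a not-yet-supported config type):

```python
import torch

from torchao.quantization.qat import (
    FakeQuantizerBase,
    IntxFakeQuantizeConfig,
    IntxFakeQuantizer,
)

# A recognized config dispatches to the matching quantizer subclass.
intx_config = IntxFakeQuantizeConfig(torch.int4, group_size=128)
quantizer = FakeQuantizerBase.from_config(intx_config)
assert type(quantizer) is IntxFakeQuantizer

# Anything unrecognized is rejected; new dtypes would add branches here.
try:
    FakeQuantizerBase.from_config(object())
except ValueError as e:
    print(e)  # Unknown config type: <object object at ...>
```
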
torchao/quantization/qat/linear.py

Lines changed: 5 additions & 3 deletions
@@ -32,7 +32,7 @@
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
-    FakeQuantizer,
+    FakeQuantizerBase,
     _Float8RowwiseActivationFakeQuantizer,
 )
 from .utils import (
@@ -84,7 +84,9 @@ def __init__(
         )
         # initialize activation fake quantizer
         if activation_config is not None:
-            self.activation_fake_quantizer = FakeQuantizer(activation_config)
+            self.activation_fake_quantizer = FakeQuantizerBase.from_config(
+                activation_config
+            )
         else:
             self.activation_fake_quantizer = None
 
@@ -97,7 +99,7 @@ def __init__(
                 "in_features (%s) %% group_size (%s) must be == 0"
                 % (in_features, group_size)
             )
-            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+            self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
         else:
             self.weight_fake_quantizer = None

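At the module level, the net effect of the `linear.py` change is that `FakeQuantizedLinear` now routes both configs through `FakeQuantizerBase.from_config`. A hedged usage sketch follows; it assumes a constructor of the form `FakeQuantizedLinear(in_features, out_features, activation_config=..., weight_config=...)`, which this diff only shows indirectly, and the shapes and group size are illustrative:

```python
import torch

from torchao.quantization.qat import FakeQuantizedLinear, IntxFakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import IntxFakeQuantizer

activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)

# in_features must be divisible by group_size (see the check in the hunk above).
linear = FakeQuantizedLinear(
    256,
    128,
    activation_config=activation_config,
    weight_config=weight_config,
)

# Both quantizers were built via FakeQuantizerBase.from_config, so with intx
# configs they come back as IntxFakeQuantizer instances.
assert isinstance(linear.activation_fake_quantizer, IntxFakeQuantizer)
assert isinstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
```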