
Commit 11f1a76

Clean up QAT API surface + add separate API ref (#2567)
This commit does a few things:

1. Make AffineFakeQuantizedTensor and associated functions private. These are not meant to be exposed to users yet.
2. Expose some commonly used APIs at the top level (e.g. FakeQuantizer).
3. Deprecate some QAT APIs.
4. Add a separate API ref to better categorize QAT APIs.

As of this commit, all APIs under `torchao.quantization.qat` should be either public and documented, deprecated, or private.

To preview the docs: https://docs-preview.pytorch.org/pytorch/ao/2567/api_ref_qat.html
1 parent 460aaed commit 11f1a76
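In practical terms, the import surface after this commit looks like the sketch below. The module paths and names are taken from the diffs that follow; the snippet only illustrates which names are public and which are now underscore-private:

```python
# Public: commonly used QAT classes are now re-exported at the top level.
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FakeQuantizedEmbedding,
    FakeQuantizedLinear,
    FakeQuantizer,
)

# Private: renamed with a leading underscore, not meant for users yet.
from torchao.quantization.qat.affine_fake_quantized_tensor import (
    _AffineFakeQuantizedTensor,
    _to_affine_fake_quantized,
)
```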

File tree: 12 files changed (+118, −101 lines)

docs/source/api_ref_qat.rst

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+.. _api_qat:
+
+========================
+torchao.quantization.qat
+========================
+
+.. currentmodule:: torchao.quantization.qat
+
+QAT Configs for quantize_
+---------------------------------------
+For a full example of how to use QAT with our main `quantize_` API,
+please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended>`__.
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    IntXQuantizationAwareTrainingConfig
+    FromIntXQuantizationAwareTrainingConfig
+
+Custom QAT APIs
+---------------
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    FakeQuantizeConfig
+    FakeQuantizedLinear
+    FakeQuantizedEmbedding
+    FakeQuantizer
+    linear.enable_linear_fake_quant
+    linear.disable_linear_fake_quant
+
+Legacy QAT Quantizers
+---------------------
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    Int4WeightOnlyQATQuantizer
+    linear.Int4WeightOnlyQATLinear
+    Int8DynActInt4WeightQATQuantizer
+    linear.Int8DynActInt4WeightQATLinear
+    Int4WeightOnlyEmbeddingQATQuantizer
+    embedding.Int4WeightOnlyQATEmbedding
+    embedding.Int4WeightOnlyEmbedding
+    Float8ActInt4WeightQATQuantizer
+    ComposableQATQuantizer
+
+Prototype
+---------
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    initialize_fake_quantizers
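For orientation alongside the new doc page, here is a condensed sketch of the `quantize_` flow these configs plug into, modeled on the QAT README linked in the file above. The toy model and the config values are illustrative, not prescriptive:

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 128))  # toy model

# Prepare: insert fake quantization ops (int8 per-token asymmetric
# activations, int4 per-group symmetric weights, as in the README example).
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, IntXQuantizationAwareTrainingConfig(activation_config, weight_config))

# ... train the model with fake quantization in the loop ...

# Convert: strip the fake quantization ops again, so a real post-training
# quantization config can be applied afterwards.
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
```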

docs/source/api_ref_quantization.rst

Lines changed: 0 additions & 18 deletions

@@ -34,24 +34,6 @@ Inference APIs for quantize\_
     UIntXWeightOnlyConfig
     FPXWeightOnlyConfig

-.. currentmodule:: torchao.quantization.qat
-
-QAT APIs
-----------------------
-
-.. autosummary::
-    :toctree: generated/
-    :nosignatures:
-
-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
-    FakeQuantizeConfig
-    Int4WeightOnlyQATQuantizer
-    Int8DynActInt4WeightQATQuantizer
-    Int4WeightOnlyEmbeddingQATQuantizer
-    ComposableQATQuantizer
-    initialize_fake_quantizers
-
 .. currentmodule:: torchao.quantization

 Quantization Primitives

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ for an overall introduction to the library and recent highlight and updates.

    api_ref_dtypes
    api_ref_quantization
+   api_ref_qat
    api_ref_sparsity
    api_ref_float8
test/quantization/test_qat.py

Lines changed: 2 additions & 2 deletions

@@ -1223,8 +1223,8 @@ def test_qat_prototype_bc(self):
             Int8DynActInt4WeightQATQuantizerModuleSwap,
         )
         from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (  # noqa: F401, F811
-            AffineFakeQuantizedTensor,
-            to_affine_fake_quantized,
+            _AffineFakeQuantizedTensor,
+            _to_affine_fake_quantized,
         )
         from torchao.quantization.prototype.qat.api import (  # noqa: F401, F811
             ComposableQATQuantizer,

torchao/prototype/quantization/autoquant_v2.py

Lines changed: 2 additions & 2 deletions

@@ -74,7 +74,7 @@
 def _is_linear(mod, *args):
     # avoid circular dependencies
     from torchao.quantization.qat.affine_fake_quantized_tensor import (
-        AffineFakeQuantizedTensor,
+        _AffineFakeQuantizedTensor,
     )

     # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
@@ -86,7 +86,7 @@ def _is_linear(mod, *args):
         and not isinstance(mod.weight, AutoQuantizableLinearWeightV1)
         and not isinstance(mod.weight, AffineQuantizedTensor)
         and not isinstance(mod.weight, LinearActivationQuantizedTensor)
-        and not isinstance(mod.weight, AffineFakeQuantizedTensor)
+        and not isinstance(mod.weight, _AffineFakeQuantizedTensor)
         and not isinstance(mod, torch.nn.modules.linear.NonDynamicallyQuantizableLinear)
     )

Lines changed: 4 additions & 4 deletions

@@ -1,9 +1,9 @@
 from torchao.quantization.qat.affine_fake_quantized_tensor import (
-    AffineFakeQuantizedTensor,
-    to_affine_fake_quantized,
+    _AffineFakeQuantizedTensor,
+    _to_affine_fake_quantized,
 )

 __all__ = [
-    "AffineFakeQuantizedTensor",
-    "to_affine_fake_quantized",
+    "_AffineFakeQuantizedTensor",
+    "_to_affine_fake_quantized",
 ]

torchao/quantization/qat/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -8,9 +8,12 @@
     intx_quantization_aware_training,
 )
 from .embedding import (
+    FakeQuantizedEmbedding,
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
+from .fake_quantizer import FakeQuantizer
 from .linear import (
+    FakeQuantizedLinear,
     Float8ActInt4WeightQATQuantizer,
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
@@ -19,6 +22,9 @@
 __all__ = [
     "ComposableQATQuantizer",
     "FakeQuantizeConfig",
+    "FakeQuantizedLinear",
+    "FakeQuantizedEmbedding",
+    "FakeQuantizer",
     "Float8ActInt4WeightQATQuantizer",
     "FromIntXQuantizationAwareTrainingConfig",
     "Int4WeightOnlyEmbeddingQATQuantizer",

torchao/quantization/qat/affine_fake_quantized_tensor.py

Lines changed: 18 additions & 22 deletions

@@ -20,16 +20,12 @@
 )
 from torchao.utils import TorchAOBaseTensor

-from .utils import (
-    _UnwrapAffineFakeQuantizedTensor,
-)
-
 aten = torch.ops.aten


 class _ToAffineFakeQuantized(torch.autograd.Function):
     """
-    Differentiable constructor for `AffineFakeQuantizedTensor`,
+    Differentiable constructor for `_AffineFakeQuantizedTensor`,
     needed for input activation fake quantization.
     """

@@ -47,12 +43,12 @@ def forward(
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-    ) -> "AffineFakeQuantizedTensor":
+    ) -> "_AffineFakeQuantizedTensor":
         if zero_point_domain is None:
             raise ValueError("Please use ZeroPointDomain.NONE instead of None")

         def apply_fake_quant_fn(t: torch.Tensor):
-            assert isinstance(t, AffineFakeQuantizedTensor)
+            assert isinstance(t, _AffineFakeQuantizedTensor)
             qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
             if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
                 scale, zero_point = _choose_qparams_affine_tinygemm(
@@ -102,7 +98,7 @@ def apply_fake_quant_fn(t: torch.Tensor):
             )
             return fq

-        return AffineFakeQuantizedTensor(
+        return _AffineFakeQuantizedTensor(
             original_tensor,
             apply_fake_quant_fn,
             fake_quant_enabled=True,
@@ -113,7 +109,7 @@ def backward(ctx, gy):
         return gy, None, None, None, None, None, None, None, None, None, None


-class AffineFakeQuantizedTensor(TorchAOBaseTensor):
+class _AffineFakeQuantizedTensor(TorchAOBaseTensor):
     """
     Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor
     with an affine transformation:
@@ -212,7 +208,7 @@ def get_value(self) -> torch.Tensor:
         if self.fake_quant_enabled:
             return self.apply_fake_quant_fn(self)
         else:
-            return _UnwrapAffineFakeQuantizedTensor.apply(self)
+            return self.original_tensor

     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
@@ -243,14 +239,14 @@ def to(self, *args, **kwargs):

     def _apply_fn_to_data(self, fn: Callable):
         """
-        Create a new `AffineFakeQuantizedTensor` with `fn` applied to the
+        Create a new `_AffineFakeQuantizedTensor` with `fn` applied to the
         original tensor, to be called within __torch_dispatch__.
         """
         return self._create_new(fn(self.original_tensor))

     def _create_new(self, new_value: torch.Tensor):
         """
-        Create a new `AffineFakeQuantizedTensor` with a new value,
+        Create a new `_AffineFakeQuantizedTensor` with a new value,
         to be called within __torch_dispatch__.

         Note: `requires_grad` must be False here because tensors created
@@ -267,7 +263,7 @@ def _create_new(self, new_value: torch.Tensor):
         )


-implements = AffineFakeQuantizedTensor.implements
+implements = _AffineFakeQuantizedTensor.implements


 @implements(torch.nn.functional.linear)
@@ -277,9 +273,9 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    if isinstance(input_tensor, AffineFakeQuantizedTensor):
+    if isinstance(input_tensor, _AffineFakeQuantizedTensor):
         input_tensor = input_tensor.get_value()
-    if isinstance(weight_tensor, AffineFakeQuantizedTensor):
+    if isinstance(weight_tensor, _AffineFakeQuantizedTensor):
         weight_tensor = weight_tensor.get_value()
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@@ -288,9 +284,9 @@
 def _(func, types, args, kwargs):
     input_tensor = args[0]
     weight_tensor = args[1]
-    if isinstance(input_tensor, AffineFakeQuantizedTensor):
+    if isinstance(input_tensor, _AffineFakeQuantizedTensor):
         input_tensor = input_tensor.get_value()
-    if isinstance(weight_tensor, AffineFakeQuantizedTensor):
+    if isinstance(weight_tensor, _AffineFakeQuantizedTensor):
         weight_tensor = weight_tensor.get_value()
     return func(input_tensor, weight_tensor)

@@ -300,9 +296,9 @@ def _(func, types, args, kwargs):
     bias = args[0]
     input_tensor = args[1]
     weight_tensor = args[2]
-    if isinstance(input_tensor, AffineFakeQuantizedTensor):
+    if isinstance(input_tensor, _AffineFakeQuantizedTensor):
         input_tensor = input_tensor.get_value()
-    if isinstance(weight_tensor, AffineFakeQuantizedTensor):
+    if isinstance(weight_tensor, _AffineFakeQuantizedTensor):
         weight_tensor = weight_tensor.get_value()
     return func(bias, input_tensor, weight_tensor)

@@ -348,10 +344,10 @@ def _(func, types, args, kwargs):
 def _(func, types, args, kwargs):
     assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}"
     new_args = pytree.tree_map_only(
-        AffineFakeQuantizedTensor, lambda x: x.original_tensor, args
+        _AffineFakeQuantizedTensor, lambda x: x.original_tensor, args
     )
     first_afq_tensor = (
-        args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1]
+        args[0] if isinstance(args[0], _AffineFakeQuantizedTensor) else args[1]
     )
     new_value = func(*new_args, **kwargs)
     out = first_afq_tensor._create_new(new_value)

@@ -384,4 +380,4 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, out)


-to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float
+_to_affine_fake_quantized = _AffineFakeQuantizedTensor.from_float
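Besides the renames, the one behavioral edit in this file is `get_value`: with fake quantization disabled, the wrapped high-precision tensor is now returned directly rather than routed through the removed `_UnwrapAffineFakeQuantizedTensor` autograd function. A sketch of the two resulting paths (illustrative only, mirroring the diff above):

```python
# Illustrative only: how the dispatch handlers above resolve the subclass.
def resolve_value(afq_tensor):
    if afq_tensor.fake_quant_enabled:
        # Run the stored fake quantization function (quantize + dequantize).
        return afq_tensor.apply_fake_quant_fn(afq_tensor)
    # New path after this commit: hand back the wrapped tensor as-is.
    return afq_tensor.original_tensor
```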

torchao/quantization/qat/api.py

Lines changed: 2 additions & 2 deletions

@@ -34,7 +34,7 @@ class FakeQuantizeConfig:
     """
     Config for how to fake quantize weights or activations.

-    args:
+    Args:
         dtype: dtype to simulate during fake quantization, e.g. torch.int8.
             For PyTorch versions older than 2.6, you may use `TorchAODType` to represent
             torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4.
@@ -54,7 +54,7 @@ class FakeQuantizeConfig:
         range_learning (prototype): whether to learn scale and zero points during training
             (default false), not compatible with `is_dynamic`.

-    kwargs (optional):
+    Keyword args:
         group_size: size of each group in per group fake quantization,
             can be set instead of `granularity`
         is_symmetric: whether to use symmetric or asymmetric quantization,
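Per the corrected docstring, the keyword arguments are shorthands for explicit granularity and mapping settings. A small sketch, assuming the `PerGroup` granularity class from `torchao.quantization.granularity`:

```python
import torch

from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import FakeQuantizeConfig

# Explicit granularity object...
config_a = FakeQuantizeConfig(torch.int4, PerGroup(32))

# ...or the equivalent `group_size` keyword shorthand from the docstring.
config_b = FakeQuantizeConfig(torch.int4, group_size=32)
```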
