
Commit cf6df46

Move activation-only config to PARQ prototype
1 parent 701bc31 commit cf6df46


5 files changed: +41 -41 lines changed


test/prototype/test_parq.py

Lines changed: 5 additions & 3 deletions
@@ -26,12 +26,14 @@
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
-from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
+from torchao.prototype.parq.quant.quant_api import (
+    Int8DynamicActivationOnlyConfig,
+    StretchedIntxWeightOnlyConfig,
+)
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.qat import QATConfig
 from torchao.quantization.quant_api import (
-    Int8DynActOnlyConfig,
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     _is_linear,
@@ -392,7 +394,7 @@ def test_int8_dynamic_activation_intx_e2e(
         optimizer.step()

         # apply torchao quantized activations on top
-        qat_config = QATConfig(Int8DynActOnlyConfig(), step="prepare")
+        qat_config = QATConfig(Int8DynamicActivationOnlyConfig(), step="prepare")
         filter_fn = optimizer.get_filter_fn(model)
         quantize_(model, qat_config, filter_fn=filter_fn)
         out = model(x)
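
For context, the flow this test exercises can be sketched as below. This is a minimal illustration, assuming torchao's public quantize_ and QATConfig entry points; the toy model, tensor shapes, and lambda filter are placeholders, not taken from the test.

import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.prototype.parq.quant.quant_api import Int8DynamicActivationOnlyConfig

# Toy stand-in for the optimizer-managed model in the test.
model = torch.nn.Sequential(torch.nn.Linear(128, 128))

# Prepare step: fake-quantize activations (int8, dynamic, per-token) on linear
# modules; weights are left untouched by this activation-only config.
qat_config = QATConfig(Int8DynamicActivationOnlyConfig(), step="prepare")
quantize_(model, qat_config, filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear))

out = model(torch.randn(4, 128))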

torchao/prototype/parq/quant/quant_api.py

Lines changed: 33 additions & 5 deletions
@@ -10,19 +10,47 @@
 import torch
 from torch import nn

+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import AffineQuantizedTensor, Layout, QDQLayout
-from torchao.quantization.granularity import PerAxis, PerGroup
-from torchao.quantization.quant_api import IntxWeightOnlyConfig
-from torchao.quantization.quant_primitives import (
-    _SUB_BYTE_UINT_BOUNDS,
+from torchao.quantization import (
     MappingType,
+    PerAxis,
+    PerGroup,
     ZeroPointDomain,
-    _get_reduction_params,
     dequantize_affine,
+    to_linear_activation_quantized,
+)
+from torchao.quantization.quant_api import (
+    IntxWeightOnlyConfig,
+    _int8_asymm_per_token_quant,
+    _int8_symm_per_token_reduced_range_quant,
+)
+from torchao.quantization.quant_primitives import (
+    _SUB_BYTE_UINT_BOUNDS,
+    _get_reduction_params,
 )
 from torchao.quantization.transform_module import register_quantize_module_handler


+@dataclass
+class Int8DynamicActivationOnlyConfig(AOBaseConfig):
+    is_symmetric: bool = False
+
+
+@register_quantize_module_handler(Int8DynamicActivationOnlyConfig)
+def _int8_dynamic_activation_transform(
+    module: torch.nn.Module, config: Int8DynamicActivationOnlyConfig
+) -> torch.nn.Module:
+    weight = module.weight
+    if config.is_symmetric:
+        input_quant_func = _int8_symm_per_token_reduced_range_quant
+    else:
+        input_quant_func = _int8_asymm_per_token_quant
+    weight = to_linear_activation_quantized(weight, input_quant_func)
+    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return module
+
+
 def choose_qparams_stretched_affine(
     input_float: torch.Tensor,
     mapping_type: MappingType,
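
A minimal post-training usage sketch of the relocated config, assuming torchao's quantize_ entry point; the model and shapes below are illustrative only.

import torch
from torchao.quantization import quantize_
from torchao.prototype.parq.quant.quant_api import Int8DynamicActivationOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()

# Default is asymmetric int8 dynamic per-token activation quantization;
# is_symmetric=True selects the symmetric reduced-range variant instead.
quantize_(model, Int8DynamicActivationOnlyConfig())

out = model(torch.randn(2, 64))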

torchao/quantization/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -54,7 +54,6 @@
     GemliteUIntXWeightOnlyConfig,
     Int4DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
-    Int8DynActOnlyConfig,
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
@@ -145,7 +144,6 @@
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
     "Float8DynamicActivationInt4WeightConfig",
-    "Int8DynActOnlyConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 2 additions & 2 deletions
@@ -290,9 +290,9 @@ def _infer_fake_quantize_configs(
     Return a 2-tuple of (activation_config, weight_config) for fake quantization.
     """
     # avoid circular imports
+    from torchao.prototype.parq.quant.quant_api import Int8DynamicActivationOnlyConfig
     from torchao.quantization import (
         Int4WeightOnlyConfig,
-        Int8DynActOnlyConfig,
         Int8DynamicActivationInt4WeightConfig,
     )

@@ -316,7 +316,7 @@ def _infer_fake_quantize_configs(
             zero_point_domain=base_config.zero_point_domain,
         )
         return (None, weight_config)
-    elif isinstance(base_config, Int8DynActOnlyConfig):
+    elif isinstance(base_config, Int8DynamicActivationOnlyConfig):
         act_config = IntxFakeQuantizeConfig(
             dtype=torch.int8,
             granularity="per_token",
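
For reference, a hedged sketch of the (activation_config, weight_config) pair this new branch produces; the IntxFakeQuantizeConfig arguments beyond the two visible in the diff are truncated above and are left at their defaults here, which is an assumption.

import torch
from torchao.quantization.qat.fake_quantize_config import IntxFakeQuantizeConfig

# Activations: int8, dynamic, per-token fake quantization, matching the diff.
act_config = IntxFakeQuantizeConfig(dtype=torch.int8, granularity="per_token")
# Weights: none -- the config quantizes activations only.
weight_config = None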

torchao/quantization/quant_api.py

Lines changed: 1 addition & 29 deletions
@@ -26,9 +26,7 @@
 import torch.nn.utils.parametrize as parametrize

 import torchao
-from torchao.core.config import (
-    AOBaseConfig,
-)
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     AffineQuantizedTensor,
     CutlassInt4PackedLayout,
@@ -148,7+146,6 @@
     "gemlite_uintx_weight_only",
     "float8_dynamic_activation_float8_weight",
     "float8_static_activation_float8_weight",
-    "Int8DynActOnlyConfig",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
@@ -1313,31 +1310,6 @@ def _float8_cutlass_quant_sparse(
     )


-@dataclass
-class Int8DynActOnlyConfig(AOBaseConfig):
-    """
-    Configuration for applying int8 dynamic symmetric per-token activation quantization to linear layers.
-    Args:
-        is_symmetric: bool = False - Whether to use symmetric quantization for activations.
-    """
-
-    is_symmetric: bool = False
-
-
-@register_quantize_module_handler(Int8DynActOnlyConfig)
-def _int8_dynamic_activation_transform(
-    module: torch.nn.Module, config: Int8DynActOnlyConfig
-) -> torch.nn.Module:
-    weight = module.weight
-    if config.is_symmetric:
-        input_quant_func = _int8_symm_per_token_reduced_range_quant
-    else:
-        input_quant_func = _int8_asymm_per_token_quant
-    weight = to_linear_activation_quantized(weight, input_quant_func)
-    module.weight = torch.nn.Parameter(weight, requires_grad=False)
-    return module
-
-
 @dataclass
 class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     """
