
Commit 29281fd

Update Int8DynActOnlyConfig
1 parent b48499c commit 29281fd

3 files changed: +36 -0 lines changed

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@
     GemliteUIntXWeightOnlyConfig,
     Int4DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
+    Int8DynActOnlyConfig,
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
@@ -144,6 +145,7 @@
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
     "Float8DynamicActivationInt4WeightConfig",
+    "Int8DynActOnlyConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 8 additions & 0 deletions
@@ -292,6 +292,7 @@ def _infer_fake_quantize_configs(
     # avoid circular imports
     from torchao.quantization import (
         Int4WeightOnlyConfig,
+        Int8DynActOnlyConfig,
         Int8DynamicActivationInt4WeightConfig,
     )

@@ -315,5 +316,12 @@ def _infer_fake_quantize_configs(
             zero_point_domain=base_config.zero_point_domain,
         )
         return (None, weight_config)
+    elif isinstance(base_config, Int8DynActOnlyConfig):
+        act_config = IntxFakeQuantizeConfig(
+            dtype=torch.int8,
+            granularity="per_token",
+            is_symmetric=base_config.is_symmetric,
+        )
+        return (act_config, None)
     else:
         raise ValueError("Unexpected base config: %s" % base_config)
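
For context, a minimal sketch of what the new branch produces when the private helper is called directly with the new config; the call below is illustrative only and is not part of the commit:

from torchao.quantization import Int8DynActOnlyConfig
from torchao.quantization.qat.fake_quantize_config import _infer_fake_quantize_configs

# The new elif branch maps Int8DynActOnlyConfig to an activation-only
# fake-quantize config: int8, per-token granularity, symmetric only when
# is_symmetric=True; no weight fake quantization is inferred.
act_config, weight_config = _infer_fake_quantize_configs(Int8DynActOnlyConfig(is_symmetric=True))
assert weight_config is None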

torchao/quantization/quant_api.py

Lines changed: 26 additions & 0 deletions
@@ -148,6 +148,7 @@
     "gemlite_uintx_weight_only",
     "float8_dynamic_activation_float8_weight",
     "float8_static_activation_float8_weight",
+    "Int8DynActOnlyConfig",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
@@ -1312,6 +1313,31 @@ def _float8_cutlass_quant_sparse(
     )


+@dataclass
+class Int8DynActOnlyConfig(AOBaseConfig):
+    """
+    Configuration for applying int8 dynamic per-token activation quantization to linear layers.
+    Args:
+        is_symmetric: bool = False - Whether to use symmetric quantization for activations.
+    """
+
+    is_symmetric: bool = False
+
+
+@register_quantize_module_handler(Int8DynActOnlyConfig)
+def _int8_dynamic_activation_transform(
+    module: torch.nn.Module, config: Int8DynActOnlyConfig
+) -> torch.nn.Module:
+    weight = module.weight
+    if config.is_symmetric:
+        input_quant_func = _int8_symm_per_token_reduced_range_quant
+    else:
+        input_quant_func = _int8_asymm_per_token_quant
+    weight = to_linear_activation_quantized(weight, input_quant_func)
+    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return module
+
+
 @dataclass
 class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     """