
Commit 57e29ee: Update Int8DynActOnlyConfig
Parent: b48499c

4 files changed (+37, -12 lines)

test/prototype/test_parq.py

Lines changed: 1 addition & 12 deletions
@@ -39,11 +39,7 @@
     quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    TORCH_VERSION_AT_LEAST_2_6,
-    check_cpu_version,
-)
+from torchao.utils import check_cpu_version

 _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -208,7 +204,6 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)

-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
@@ -225,7 +220,6 @@ def test_int4_weight_only(self, group_size: int = 32):
             model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size
         )

-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @common_utils.parametrize("b", [2, 3, 4, 8])
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
@@ -243,7 +237,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         quantizer = UnifTorchaoQuantizer()
         compare_quantized_models(model, m_ref, quantizer, b, group_size)

-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
     def test_int4_weight_only_e2e(self, group_size: int = 32):
         model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE)
@@ -265,7 +258,6 @@
         )
         compare_parq_convert(model, m_ref, optimizer, config)

-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
     @common_utils.parametrize("b", [2, 3, 4, 8])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
@@ -315,7 +307,6 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
         torch.testing.assert_close(q, q_ref, atol=0, rtol=0)
         torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0)

-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @common_utils.parametrize("b", [2, 3])
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
@@ -337,7 +328,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):

         compare_quantized_models(model, m_ref, quantizer, b, group_size)

-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
     @common_utils.parametrize("b", [2, 3])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
@@ -369,7 +359,6 @@ class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)

-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @common_utils.parametrize("b", [2, 3, 4, 8])
     @common_utils.parametrize("model_dtype", [torch.float16, torch.float32])
     @common_utils.parametrize("group_size", [32, 128])

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@
     GemliteUIntXWeightOnlyConfig,
     Int4DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
+    Int8DynActOnlyConfig,
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
@@ -144,6 +145,7 @@
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
     "Float8DynamicActivationInt4WeightConfig",
+    "Int8DynActOnlyConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 8 additions & 0 deletions
@@ -292,6 +292,7 @@ def _infer_fake_quantize_configs(
     # avoid circular imports
     from torchao.quantization import (
         Int4WeightOnlyConfig,
+        Int8DynActOnlyConfig,
         Int8DynamicActivationInt4WeightConfig,
     )

@@ -315,5 +316,12 @@ def _infer_fake_quantize_configs(
             zero_point_domain=base_config.zero_point_domain,
         )
         return (None, weight_config)
+    elif isinstance(base_config, Int8DynActOnlyConfig):
+        act_config = IntxFakeQuantizeConfig(
+            dtype=torch.int8,
+            granularity="per_token",
+            is_symmetric=base_config.is_symmetric,
+        )
+        return (act_config, None)
     else:
         raise ValueError("Unexpected base config: %s" % base_config)
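
For QAT flows, the new branch maps an Int8DynActOnlyConfig base config to an
activation-only fake-quantize pair: an int8 per-token config for activations
and None for weights. A minimal sketch of that equivalence, assuming
IntxFakeQuantizeConfig is importable from torchao.quantization.qat:

import torch
from torchao.quantization import Int8DynActOnlyConfig
from torchao.quantization.qat import IntxFakeQuantizeConfig

base = Int8DynActOnlyConfig(is_symmetric=True)

# _infer_fake_quantize_configs returns the equivalent of this pair:
act_config = IntxFakeQuantizeConfig(
    dtype=torch.int8,
    granularity="per_token",
    is_symmetric=base.is_symmetric,
)
weight_config = None  # weights are not fake-quantized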

torchao/quantization/quant_api.py

Lines changed: 26 additions & 0 deletions
@@ -148,6 +148,7 @@
     "gemlite_uintx_weight_only",
     "float8_dynamic_activation_float8_weight",
     "float8_static_activation_float8_weight",
+    "Int8DynActOnlyConfig",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
@@ -1312,6 +1313,31 @@ def _float8_cutlass_quant_sparse(
     )


+@dataclass
+class Int8DynActOnlyConfig(AOBaseConfig):
+    """
+    Configuration for applying int8 dynamic per-token activation quantization to linear layers; weights are left unquantized.
+    Args:
+        is_symmetric: bool = False - Whether to use symmetric quantization for activations; asymmetric when False.
+    """
+
+    is_symmetric: bool = False
+
+
+@register_quantize_module_handler(Int8DynActOnlyConfig)
+def _int8_dynamic_activation_transform(
+    module: torch.nn.Module, config: Int8DynActOnlyConfig
+) -> torch.nn.Module:
+    weight = module.weight
+    if config.is_symmetric:
+        input_quant_func = _int8_symm_per_token_reduced_range_quant
+    else:
+        input_quant_func = _int8_asymm_per_token_quant
+    weight = to_linear_activation_quantized(weight, input_quant_func)
+    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return module
+
+
 @dataclass
 class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     """
