Skip to content

Commit 6e9bf26

Browse files
authored
Support QAT int4 v1 path for BC (#2888)
**Summary:** `Int4WeightOnlyConfig` supports version 1 (targeting tinygemm) and version 2 (targeting fbgemm). However, the latter requires a new dependency (fbgemm_gpu_genai >= 1.2.0), which is problematic for torchao integrations with other frameworks. For now, we should continue to support the v1 path for backward compatibility (BC).

**Test Plan:**
```
python test/quantization/test_qat.py -k test_infer_int4_weight_only_config
```
1 parent 3bf21d0 commit 6e9bf26

File tree

2 files changed

+64
-7
lines changed

2 files changed

+64
-7
lines changed

test/quantization/test_qat.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
from torchao.quantization.quant_api import (
7070
Float8DynamicActivationFloat8WeightConfig,
7171
Float8DynamicActivationInt4WeightConfig,
72+
Int4WeightOnlyConfig,
7273
Int8DynamicActivationInt4WeightConfig,
7374
)
7475
from torchao.quantization.quant_primitives import (
@@ -1933,6 +1934,22 @@ def test_quantize_api_fp8_int4(self):
19331934
)
19341935

19351936
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
1937+
@unittest.skipIf(
1938+
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
1939+
)
1940+
@parametrize("version", [1, 2])
1941+
def test_quantize_api_int4(self, version: int):
1942+
"""
1943+
Test the following:
1944+
quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="prepare"))
1945+
quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="convert"))
1946+
"""
1947+
self._test_quantize_api_against_ptq(
1948+
Int4WeightOnlyConfig(version=version),
1949+
target_prepare_sqnr=12,
1950+
target_convert_sqnr=float("inf"),
1951+
)
1952+
19361953
def test_infer_fp8_int4_config(self):
19371954
"""
19381955
Test that fake quantize configs are correctly inferred from
@@ -1952,6 +1969,29 @@ def test_infer_fp8_int4_config(self):
19521969
self.assertEqual(weight_config.group_size, 128)
19531970
self.assertTrue(weight_config.is_symmetric)
19541971

1972+
def test_infer_int4_weight_only_config(self):
1973+
"""
1974+
Test that fake quantize configs are correctly inferred from `Int4WeightOnlyConfig`.
1975+
"""
1976+
from torchao.quantization.qat.fake_quantize_config import (
1977+
_infer_fake_quantize_configs,
1978+
)
1979+
1980+
base_config = Int4WeightOnlyConfig(version=1)
1981+
(act_config, weight_config) = _infer_fake_quantize_configs(base_config)
1982+
self.assertIsNone(act_config)
1983+
self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
1984+
self.assertEqual(weight_config.dtype, torch.uint4)
1985+
self.assertEqual(weight_config.group_size, 128)
1986+
self.assertFalse(weight_config.is_symmetric)
1987+
1988+
base_config = Int4WeightOnlyConfig(version=2)
1989+
(act_config, weight_config) = _infer_fake_quantize_configs(base_config)
1990+
self.assertIsNone(act_config)
1991+
self.assertEqual(weight_config.dtype, torch.int4)
1992+
self.assertEqual(weight_config.group_size, 128)
1993+
self.assertTrue(weight_config.is_symmetric)
1994+
19551995
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
19561996
def test_quantize_api_nvfp4(self):
19571997
"""

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,31 @@ def _infer_fake_quantize_configs(
358358
is_symmetric=base_config.mapping_type == MappingType.SYMMETRIC,
359359
)
360360
elif isinstance(base_config, Int4WeightOnlyConfig):
361-
if base_config.version != 2:
362-
raise ValueError(f"Only version 2 of {type(base_config)} is supported")
363361
act_config = None
364-
weight_config = IntxFakeQuantizeConfig(
365-
dtype=torch.int4,
366-
group_size=base_config.group_size,
367-
is_symmetric=True,
368-
)
362+
if base_config.version == 2:
363+
weight_config = IntxFakeQuantizeConfig(
364+
dtype=torch.int4,
365+
group_size=base_config.group_size,
366+
is_symmetric=True,
367+
)
368+
elif base_config.version == 1:
369+
# For BC
370+
from torchao.quantization.quant_api import (
371+
LAYOUT_TO_ZERO_POINT_DOMAIN,
372+
)
373+
374+
if base_config.zero_point_domain == ZeroPointDomain.NONE:
375+
zp_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(base_config.layout)][0]
376+
else:
377+
zp_domain = base_config.zero_point_domain
378+
weight_config = IntxFakeQuantizeConfig(
379+
dtype=torch.uint4,
380+
group_size=base_config.group_size,
381+
is_symmetric=False,
382+
zero_point_domain=zp_domain,
383+
)
384+
else:
385+
raise ValueError(f"Unknown version on base config {type(base_config)}")
369386
elif isinstance(base_config, Float8DynamicActivationFloat8WeightConfig):
370387
if base_config.version != 2:
371388
raise ValueError(f"Only version 2 of {type(base_config)} is supported")

0 commit comments

Comments (0)