
Commit 715ea9f

Add float8 FakeQuantizeConfig and FakeQuantizer (#2735)
**Summary:** This commit adds a QAT path for float8, using the same primitives as `torchao.quantization.Float8Tensor`, targeting the following PTQ configs:

- `Float8DynamicActivationFloat8WeightConfig`
- `Float8DynamicActivationInt4WeightConfig`

Usage:

```python
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import quantize_, QATConfig

base_config = Float8DynamicActivationFloat8WeightConfig(
    torch.float8_e4m3fn,
    PerRow(),
)
quantize_(model, QATConfig(base_config, step="prepare"))
quantize_(model, QATConfig(base_config, step="convert"))
```

OR

```python
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import (
    Float8FakeQuantizeConfig,
    QATConfig,
    quantize_,
)

dtype = torch.float8_e4m3fn
granularity = PerRow()
quantize_(model, QATConfig(
    activation_config=Float8FakeQuantizeConfig(dtype, granularity),
    weight_config=Float8FakeQuantizeConfig(dtype, granularity),
    step="prepare",
))
# convert (same as above, not shown)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_float8_fake_quantize_config
python test/quantization/test_qat.py -k test_float8_fake_quantize
python test/quantization/test_qat.py -k test_quantize_api_fp8_fp8
python test/quantization/test_qat.py -k test_quantize_api_fp8_int4
```
1 parent a1a9632 commit 715ea9f
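For background, "fake quantization" here means emulating the float8 quantize → dequantize round trip in high precision during training, so the model learns under float8 rounding error while gradients flow through a straight-through estimator. A minimal illustrative sketch of the per-row (`PerRow`) case with absmax scaling follows; the helper name is hypothetical and this is not the actual torchao implementation:

```python
import torch

def fake_quantize_fp8_per_row(x: torch.Tensor, dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Hypothetical sketch; the real logic lives in torchao's Float8FakeQuantizer.
    # Scale each row so its absmax maps to the dtype's max representable value.
    fp8_max = torch.finfo(dtype).max
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max
    # Quantize -> dequantize round trip: output keeps x's dtype but carries
    # float8 rounding error, which is what the model "sees" during QAT.
    return (x / scale).to(dtype).to(x.dtype) * scale

x = torch.randn(32, 64)
x_fq = fake_quantize_fp8_per_row(x)  # same shape/dtype as x
```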

File tree

8 files changed: +261 −100 lines

docs/source/api_ref_qat.rst

Lines changed: 2 additions & 0 deletions
```diff
@@ -26,10 +26,12 @@ Custom QAT APIs

     FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
+    Float8FakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizerBase
     IntxFakeQuantizer
+    Float8FakeQuantizer
     linear.enable_linear_fake_quant
     linear.disable_linear_fake_quant
```

test/quantization/test_qat.py

Lines changed: 131 additions & 30 deletions
```diff
@@ -14,17 +14,22 @@

 import torch
 import torch.nn.functional as F
-from parameterized import parameterized
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+)

 from torchao import quantize_
-from torchao.float8.config import ScalingGranularity
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
-from torchao.float8.float8_training_tensor import LinearMMConfig
+from torchao.core.config import AOBaseConfig
+from torchao.quantization import Float8Tensor
 from torchao.quantization.granularity import (
+    Granularity,
     PerAxis,
     PerGroup,
     PerRow,
+    PerTensor,
     PerToken,
 )
 from torchao.quantization.linear_quant_modules import (
@@ -43,11 +48,12 @@
     FakeQuantizedEmbedding,
 )
 from torchao.quantization.qat.fake_quantize_config import (
+    Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
+    Float8FakeQuantizer,
     IntxFakeQuantizer,
-    _Float8RowwiseActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
@@ -58,10 +64,11 @@
 from torchao.quantization.qat.utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _Float8RowwiseFakeQuantize,
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
@@ -83,6 +90,10 @@
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
 )
+from torchao.utils import (
+    _is_fbgemm_genai_gpu_available,
+    is_sm_at_least_89,
+)

 # TODO: put this in a common test utils file
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
@@ -193,7 +204,7 @@ def forward(self, x):
         return x


-class TestQAT(unittest.TestCase):
+class TestQAT(TestCase):
     SEED = 123

     def test_fake_quantize_per_channel_group(self):
@@ -1420,7 +1431,7 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)

-    @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
@@ -1437,7 +1448,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)

-    @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
@@ -1548,7 +1559,7 @@ def test_qat_8da4w_eps(self):
         actual_out = converted_model.linear1(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)

-    @parameterized.expand([(True,), (False,)])
+    @parametrize("is_symmetric", [True, False])
     def test_fake_quantizer_range_learning(self, is_symmetric):
         """
         Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
@@ -1589,7 +1600,7 @@ def test_fake_quantizer_range_learning(self, is_symmetric):
         self.assertTrue(fake_quantizer.zero_point.requires_grad)
         fake_quantizer(*example_inputs)

-    @parameterized.expand([(True,), (False,)])
+    @parametrize("is_symmetric", [True, False])
     def test_qat_range_learning(self, is_symmetric):
         """
         Test end-to-end QAT flow with range learning.
@@ -1664,24 +1675,6 @@ def test_qat_range_learning(self, is_symmetric):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))

-    def test_float8_rowwise_fake_quantize(self):
-        """
-        Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`.
-        """
-        torch.manual_seed(self.SEED)
-        dtype = torch.float8_e4m3fn
-        x = torch.randn(32, 64)
-        axiswise_dim = 0
-        out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
-        out_expected = hp_tensor_to_float8_dynamic(
-            x,
-            dtype,
-            LinearMMConfig(),
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=axiswise_dim,
-        ).to_original_precision()
-        torch.testing.assert_close(out, out_expected, atol=0, rtol=0)
-
     def test_qat_fp8a4w_quantizer(self):
         """
         Test basic model training with `Float8ActInt4WeightQATQuantizer`.
@@ -1693,7 +1686,8 @@ def test_qat_fp8a4w_quantizer(self):
         for linear in [m.linear1, m.sub.linear, m.linear2]:
             self.assertIsInstance(linear, FakeQuantizedLinear)
             self.assertIsInstance(
-                linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
+                linear.activation_fake_quantizer,
+                Float8FakeQuantizer,
             )
             self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
         prev_weight = copy.deepcopy(m.linear1.weight)
@@ -1833,6 +1827,113 @@ def test_qat_api_convert_no_quantization(self):
         baseline_out = baseline_model(*x2)
         torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)

+    def test_float8_fake_quantize_config(self):
+        """
+        Test that the correct errors are thrown if `Float8FakeQuantizeConfig` is not instantiated properly.
+        """
+        # OK
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn)
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow())
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerTensor())
+
+        with self.assertRaisesRegex(ValueError, "not a float8 dtype"):
+            Float8FakeQuantizeConfig(torch.int8)
+        with self.assertRaisesRegex(
+            ValueError, "Please specify the granularity object instead of the class"
+        ):
+            Float8FakeQuantizeConfig(granularity=PerRow)
+        with self.assertRaisesRegex(
+            ValueError, "Expected PerRow or PerTensor granularity"
+        ):
+            Float8FakeQuantizeConfig(granularity=PerToken())
+
+    @parametrize("granularity", [PerTensor(), PerRow()])
+    def test_float8_fake_quantize(self, granularity: Granularity):
+        """
+        Test that `Float8FakeQuantizer` is numerically close to `Float8Tensor`.
+        """
+        dtype = torch.float8_e4m3fn
+        fq_config = Float8FakeQuantizeConfig(dtype, granularity)
+        fake_quantizer = Float8FakeQuantizer(fq_config)
+        torch.manual_seed(self.SEED)
+        x = torch.randn(32, 64)
+        out = fake_quantizer(x)
+        out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize()
+        sqnr = compute_error(out, out_expected)
+        self.assertGreater(sqnr, 16)
+
+    def _test_quantize_api_against_ptq(
+        self,
+        base_config: AOBaseConfig,
+        target_prepare_sqnr: float,
+        target_convert_sqnr: float,
+    ):
+        """
+        Test the following:
+
+            quantize_(model, QATConfig(base_config, step="prepare"))
+            quantize_(model, QATConfig(base_config, step="convert"))
+
+        and compare model outputs of each step against:
+
+            quantize_(model, base_config)
+        """
+        torch.manual_seed(self.SEED)
+        m = M().to(torch.bfloat16).cuda()
+        example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
+
+        # baseline
+        m_baseline = copy.deepcopy(m)
+        quantize_(m_baseline, base_config)
+        out_baseline = m_baseline(*example_inputs)
+
+        # compare prepare
+        quantize_(m, QATConfig(base_config, step="prepare"))
+        out_prepared = m(*example_inputs)
+        prepare_sqnr = compute_error(out_prepared, out_baseline)
+        self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)
+
+        # compare convert
+        quantize_(m, QATConfig(base_config, step="convert"))
+        out_converted = m(*example_inputs)
+        convert_sqnr = compute_error(out_converted, out_baseline)
+        self.assertGreaterEqual(convert_sqnr, target_convert_sqnr)
+
+    @parametrize("granularity", [PerTensor(), PerRow()])
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    def test_quantize_api_fp8_fp8(self, granularity: Granularity):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Float8DynamicActivationFloat8WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Float8DynamicActivationFloat8WeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+            target_prepare_sqnr=15,
+            target_convert_sqnr=float("inf"),
+        )
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_quantize_api_fp8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Float8DynamicActivationInt4WeightConfig(group_size=128),
+            target_prepare_sqnr=15,
+            target_convert_sqnr=float("inf"),
+        )
+
+
+instantiate_parametrized_tests(TestQAT)

 if __name__ == "__main__":
     unittest.main()
```
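The `target_prepare_sqnr`/`target_convert_sqnr` thresholds above are in decibels; `compute_error` is torchao's SQNR helper. A sketch of the metric for reference, assuming the standard signal-to-quantization-noise definition (not copied from the library):

```python
import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: higher means closer agreement;
    # +inf means the two tensors are bitwise identical (the convert-step target).
    signal = torch.linalg.norm(reference.float())
    noise = torch.linalg.norm((reference - candidate).float())
    return 20 * torch.log10(signal / noise)
```

A prepare target of 15 dB tolerates the gap between high-precision fake quantization and the real quantized kernels, while convert must match PTQ exactly (infinite SQNR).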

torchao/quantization/qat/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -15,11 +15,13 @@
 from .fake_quantize_config import (
     FakeQuantizeConfig,
     FakeQuantizeConfigBase,
+    Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
     FakeQuantizer,
     FakeQuantizerBase,
+    Float8FakeQuantizer,
     IntxFakeQuantizer,
 )
 from .linear import (
@@ -34,6 +36,8 @@
     "QATStep",
     "FakeQuantizeConfigBase",
     "FakeQuantizerBase",
+    "Float8FakeQuantizeConfig",
+    "Float8FakeQuantizer",
     "IntxFakeQuantizeConfig",
     "IntxFakeQuantizer",
     "FakeQuantizedLinear",
```
