
import torch
import torch.nn.functional as F
- from parameterized import parameterized
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+ from torch.testing._internal.common_utils import (
+     TestCase,
+     instantiate_parametrized_tests,
+     parametrize,
+ )

from torchao import quantize_
- from torchao.float8.config import ScalingGranularity
- from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
- from torchao.float8.float8_training_tensor import LinearMMConfig
+ from torchao.core.config import AOBaseConfig
+ from torchao.quantization import Float8Tensor
from torchao.quantization.granularity import (
+     Granularity,
    PerAxis,
    PerGroup,
    PerRow,
+     PerTensor,
    PerToken,
)
from torchao.quantization.linear_quant_modules import (
@@ -43,11 +48,12 @@
    FakeQuantizedEmbedding,
)
from torchao.quantization.qat.fake_quantize_config import (
+     Float8FakeQuantizeConfig,
    IntxFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
+     Float8FakeQuantizer,
    IntxFakeQuantizer,
-     _Float8RowwiseActivationFakeQuantizer,
)
from torchao.quantization.qat.linear import (
    FakeQuantizedLinear,
@@ -58,10 +64,11 @@
from torchao.quantization.qat.utils import (
    _fake_quantize_per_channel_group,
    _fake_quantize_per_token,
-     _Float8RowwiseFakeQuantize,
    _get_qmin_qmax,
)
from torchao.quantization.quant_api import (
+     Float8DynamicActivationFloat8WeightConfig,
+     Float8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.quant_primitives import (
@@ -83,6 +90,10 @@
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor,
)
+ from torchao.utils import (
+     _is_fbgemm_genai_gpu_available,
+     is_sm_at_least_89,
+ )


# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
@@ -193,7 +204,7 @@ def forward(self, x):
        return x


- class TestQAT(unittest.TestCase):
+ class TestQAT(TestCase):
    SEED = 123

    def test_fake_quantize_per_channel_group(self):
@@ -1420,7 +1431,7 @@ def test_qat_linear_bias(self):
        example_inputs = m.example_inputs()
        m(*example_inputs)

-     @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+     @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
    def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
        """
        Test that the following produce the exact same numerics:
@@ -1437,7 +1448,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
        baseline_out = per_token_dynamic_quant(x)
        torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)

-     @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+     @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
    def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
        """
        Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
@@ -1548,7 +1559,7 @@ def test_qat_8da4w_eps(self):
        actual_out = converted_model.linear1(x)
        torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)

-     @parameterized.expand([(True,), (False,)])
+     @parametrize("is_symmetric", [True, False])
    def test_fake_quantizer_range_learning(self, is_symmetric):
        """
        Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
@@ -1589,7 +1600,7 @@ def test_fake_quantizer_range_learning(self, is_symmetric):
        self.assertTrue(fake_quantizer.zero_point.requires_grad)
        fake_quantizer(*example_inputs)

-     @parameterized.expand([(True,), (False,)])
+     @parametrize("is_symmetric", [True, False])
    def test_qat_range_learning(self, is_symmetric):
        """
        Test end-to-end QAT flow with range learning.
@@ -1664,24 +1675,6 @@ def test_qat_range_learning(self, is_symmetric):
        self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
        self.assertFalse(torch.equal(new_weight, prev_weight))

-     def test_float8_rowwise_fake_quantize(self):
-         """
-         Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`.
-         """
-         torch.manual_seed(self.SEED)
-         dtype = torch.float8_e4m3fn
-         x = torch.randn(32, 64)
-         axiswise_dim = 0
-         out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
-         out_expected = hp_tensor_to_float8_dynamic(
-             x,
-             dtype,
-             LinearMMConfig(),
-             scaling_granularity=ScalingGranularity.AXISWISE,
-             axiswise_dim=axiswise_dim,
-         ).to_original_precision()
-         torch.testing.assert_close(out, out_expected, atol=0, rtol=0)
-
    def test_qat_fp8a4w_quantizer(self):
        """
        Test basic model training with `Float8ActInt4WeightQATQuantizer`.
@@ -1693,7 +1686,8 @@ def test_qat_fp8a4w_quantizer(self):
        for linear in [m.linear1, m.sub.linear, m.linear2]:
            self.assertIsInstance(linear, FakeQuantizedLinear)
            self.assertIsInstance(
-                 linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
+                 linear.activation_fake_quantizer,
+                 Float8FakeQuantizer,
            )
            self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
        prev_weight = copy.deepcopy(m.linear1.weight)
@@ -1833,6 +1827,113 @@ def test_qat_api_convert_no_quantization(self):
        baseline_out = baseline_model(*x2)
        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)

+     def test_float8_fake_quantize_config(self):
+         """
+         Test that the correct errors are thrown if `Float8FakeQuantizeConfig` is not instantiated properly.
+         """
+         # OK
+         Float8FakeQuantizeConfig(torch.float8_e4m3fn)
+         Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow())
+         Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerTensor())
+
+         with self.assertRaisesRegex(ValueError, "not a float8 dtype"):
+             Float8FakeQuantizeConfig(torch.int8)
+         with self.assertRaisesRegex(
+             ValueError, "Please specify the granularity object instead of the class"
+         ):
+             Float8FakeQuantizeConfig(granularity=PerRow)
+         with self.assertRaisesRegex(
+             ValueError, "Expected PerRow or PerTensor granularity"
+         ):
+             Float8FakeQuantizeConfig(granularity=PerToken())
+
+     @parametrize("granularity", [PerTensor(), PerRow()])
+     def test_float8_fake_quantize(self, granularity: Granularity):
+         """
+         Test that `Float8FakeQuantizer` is numerically close to `Float8Tensor`.
+         """
+         dtype = torch.float8_e4m3fn
+         fq_config = Float8FakeQuantizeConfig(dtype, granularity)
+         fake_quantizer = Float8FakeQuantizer(fq_config)
+         torch.manual_seed(self.SEED)
+         x = torch.randn(32, 64)
+         out = fake_quantizer(x)
+         out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize()
+         sqnr = compute_error(out, out_expected)
+         self.assertGreater(sqnr, 16)
+
+     def _test_quantize_api_against_ptq(
+         self,
+         base_config: AOBaseConfig,
+         target_prepare_sqnr: float,
+         target_convert_sqnr: float,
+     ):
+         """
+         Test the following:
+
+             quantize_(model, QATConfig(base_config, step="prepare"))
+             quantize_(model, QATConfig(base_config, step="convert"))
+
+         and compare model outputs of each step against:
+
+             quantize_(model, base_config)
+         """
+         torch.manual_seed(self.SEED)
+         m = M().to(torch.bfloat16).cuda()
+         example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
+
+         # baseline
+         m_baseline = copy.deepcopy(m)
+         quantize_(m_baseline, base_config)
+         out_baseline = m_baseline(*example_inputs)
+
+         # compare prepare
+         quantize_(m, QATConfig(base_config, step="prepare"))
+         out_prepared = m(*example_inputs)
+         prepare_sqnr = compute_error(out_prepared, out_baseline)
+         self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)
+
+         # compare convert
+         quantize_(m, QATConfig(base_config, step="convert"))
+         out_converted = m(*example_inputs)
+         convert_sqnr = compute_error(out_converted, out_baseline)
+         self.assertGreaterEqual(convert_sqnr, target_convert_sqnr)
+
+     @parametrize("granularity", [PerTensor(), PerRow()])
+     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+     @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+     def test_quantize_api_fp8_fp8(self, granularity: Granularity):
+         """
+         Test the following:
+             quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="prepare"))
+             quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="convert"))
+         """
+         self._test_quantize_api_against_ptq(
+             Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+             target_prepare_sqnr=15,
+             target_convert_sqnr=float("inf"),
+         )
+
+     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+     @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+     @unittest.skipIf(
+         not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+     )
+     def test_quantize_api_fp8_int4(self):
+         """
+         Test the following:
+             quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="prepare"))
+             quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert"))
+         """
+         self._test_quantize_api_against_ptq(
+             Float8DynamicActivationInt4WeightConfig(group_size=128),
+             target_prepare_sqnr=15,
+             target_convert_sqnr=float("inf"),
+         )
+
+
+ instantiate_parametrized_tests(TestQAT)
+


if __name__ == "__main__":
    unittest.main()