@@ -21,6 +21,7 @@
     apply_dynamic_quant,
     apply_weight_only_int8_quant,
     change_linear_weights_to_dqtensors,
+    change_linear_weights_to_woqtensors,
     _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
@@ -42,6 +43,7 @@
 )
 from torchao.quantization.subclass import (
     DynamicallyQuantizedLinearWeight,
+    WeightOnlyQuantizedLinearWeight
 )
 from torchao.quantization.utils import (
     apply_logging_hook,
@@ -51,6 +53,10 @@
     LoggingTensorMode,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
+from torchao.quantization.weight_only import (
+    WeightOnlyInt8QuantLinear
+)
+
 
 torch.manual_seed(0)
 
@@ -782,84 +788,62 @@ def test_qlinear_per_channel_numerics_cuda(self):
 
 
 class TestSubclass(unittest.TestCase):
-    def test_dq_lin_weight_subclass_aot(self):
-        m, k, n = 32, 64, 32
-        x = torch.randn(m, k, device="cuda", dtype=torch.float32)
-        lin = torch.nn.Linear(k, n, device="cuda")
-
-        import copy
-
-        linq = DynamicallyPerAxisQuantizedLinear.from_float(copy.deepcopy(lin))
-
-        ref_f = lin(x)
-        ref_q = linq(x)
-
-        print(SQNR(ref_f, ref_q), "float to dq")
-
-        lin.weight = torch.nn.Parameter(
-            DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False
-        )
-        test = lin(x)
-        print(SQNR(ref_f, test), "float to dq class")
-        print(SQNR(ref_q, test), "dq to dq class")
-        assert SQNR(ref_f, test) > 35
-        assert SQNR(ref_q, test) > 35
-
-        lin_comp = torch.compile(lin, backend="aot_eager")
-        linq_comp = torch.compile(linq, backend="aot_eager")
-        test_comp = lin_comp(x)
-        ref_q_comp = linq_comp(x)
-        print(SQNR(ref_f, test_comp), "float to dq class compiled")
-        print(SQNR(ref_q_comp, test_comp), "dq compiled to dq class compiled")
-        assert SQNR(ref_f, test_comp) > 35
-        assert SQNR(ref_q_comp, test_comp) > 35
-
-    def test_dq_lin_weight_subclass_max_autotune(self):
-        m, k, n = 32, 64, 32
-        x = torch.randn(m, k, device="cuda", dtype=torch.float32)
-        lin = torch.nn.Linear(k, n, device="cuda")
-
-        import copy
-
-        linq = DynamicallyPerAxisQuantizedLinear.from_float(copy.deepcopy(lin))
-
-        ref_f = lin(x)
-        ref_q = linq(x)
+    def _test_lin_weight_subclass_impl(self,
+        test_subclass,
+        min_sqnr=35,
+        test_dtypes=[torch.float32, torch.float16, torch.bfloat16],
+        test_shape=[32, 64, 32]
+    ):
+        for test_dtype in test_dtypes:
+            m, k, n = test_shape
+            x = torch.randn(m, k, device="cuda", dtype=test_dtype)
+            lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype)
+            ref_f = lin(x)
+
+            lin.weight = torch.nn.Parameter(
+                test_subclass.from_float(lin.weight), requires_grad=False
+            )
+            test = lin(x)
+            self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{test_subclass.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}")
+            lin_comp = torch.compile(lin, mode='max-autotune')
+            test_comp = lin_comp(x)
+            self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{test_subclass.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}")
 
-        print(SQNR(ref_f, ref_q), "float to dq")
+    def test_int8_dynamic_quant_subclass(self):
+        self._test_lin_weight_subclass_impl(DynamicallyQuantizedLinearWeight, 35)
 
-        lin.weight = torch.nn.Parameter(
-            DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False
-        )
-        test = lin(x)
-        print(SQNR(ref_f, test), "float to dq class")
-        print(SQNR(ref_q, test), "dq to dq class")
-        assert SQNR(ref_f, test) > 35
-        assert SQNR(ref_q, test) > 35
-
-        lin_comp = torch.compile(lin, mode="max-autotune")
-        linq_comp = torch.compile(linq, mode="max-autotune")
-
-        test_comp = lin_comp(x)
-        ref_q_comp = linq_comp(x)
-        print(SQNR(ref_f, test_comp), "float to dq class compiled")
-        print(SQNR(ref_q_comp, test_comp), "dq compiled to dq class compiled")
-        assert SQNR(ref_f, test_comp) > 35
-        assert SQNR(ref_q_comp, test_comp) > 35
+    def test_int8_weight_only_quant_subclass(self):
+        self._test_lin_weight_subclass_impl(WeightOnlyQuantizedLinearWeight, 40)
 
     @torch.no_grad()
-    def test_dq_lin_weight_subclass_max_autotune_api(self):
-        m, k, n = 32, 64, 32
-        x = torch.randn(m, k, device="cuda", dtype=torch.float32)
-
-        mod = nn.Sequential(
-            nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda")
-        )
-        change_linear_weights_to_dqtensors(mod)
-        mod_qc = torch.compile(mod, mode="max-autotune")
-        mod_qc(x)
-        mod_qc(x)
-
+    def _test_lin_weight_subclass_api_impl(
+        self,
+        api,
+        min_sqnr=35,
+        test_dtypes=[torch.float32, torch.float16, torch.bfloat16],
+        test_shape=[32, 64, 32]
+    ):
+        for test_dtype in test_dtypes:
+            m, k, n = test_shape
+            x = torch.randn(m, k, device="cuda", dtype=test_dtype)
+            mod = nn.Sequential(
+                nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda")
+            ).to(test_dtype)
+            ref_f = mod(x)
+            api(mod)
+            test = mod(x)
+            self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}")
+
+            mod_qc = torch.compile(mod, mode="max-autotune")
+            test_comp = mod_qc(x)
+            self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}")
+
+
+    def test_int8_dynamic_quant_subclass_api(self):
+        self._test_lin_weight_subclass_api_impl(change_linear_weights_to_dqtensors, 35)
+
+    def test_int8_weight_only_quant_subclass_api(self):
+        self._test_lin_weight_subclass_api_impl(change_linear_weights_to_woqtensors, 40)
 
 class TestDynamicQuant(unittest.TestCase):
     def test_dynamic_quant(self):
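For reference, a minimal sketch of how the two module-level entry points exercised by the new subclass API tests might be applied outside the test harness. This is illustrative and not part of the diff; it assumes the helpers are importable from torchao.quantization.quant_api, and the toy model and shapes are made up.

import torch
import torch.nn as nn
# Assumed import path; these names appear in the imports patched above.
from torchao.quantization.quant_api import (
    change_linear_weights_to_dqtensors,
    change_linear_weights_to_woqtensors,
)

# Illustrative toy model; any module containing nn.Linear layers works.
mod = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32)).cuda()
x = torch.randn(16, 64, device="cuda")

# Swap nn.Linear weights in place for quantized tensor subclasses:
#   change_linear_weights_to_dqtensors  -> int8 dynamically quantized weights
#   change_linear_weights_to_woqtensors -> int8 weight-only quantized weights
change_linear_weights_to_woqtensors(mod)

# The swapped module still runs eagerly and can be compiled, as in the tests above.
mod_qc = torch.compile(mod, mode="max-autotune")
out = mod_qc(x)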