Commit a9e0596

Adding subclass and api for weight-only quant

Summary: Reconfigured the subclasses to inherit from the int8 weight-quantized subclass, since all of the weight-manipulation code is the same; only the quantized op differs.

Test Plan: python test.py -k "subclass"

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 1d82a46
Pull Request resolved: #11

1 parent 27bf5bf commit a9e0596
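The subclass definitions themselves are not among the diffs expanded below, so the following is only a minimal sketch of the split the summary describes, read as the dynamic subclass deriving from the new int8 weight-only one. The class names and from_float match identifiers imported in this commit; quantized_op, the plain (non-torch.Tensor-subclass) classes, and all of the quantization math are illustrative assumptions rather than the torchao implementation.

import torch
import torch.nn.functional as F

# Illustrative only: plain classes stand in for the real tensor subclasses, and
# the quantization math below is assumed, not taken from this commit.
class WeightOnlyQuantizedLinearWeight:
    """Owns the shared weight handling: symmetric per-channel int8 quantization."""

    def __init__(self, int_data, scales):
        self.int_data = int_data  # int8 weight, shape (out_features, in_features)
        self.scales = scales      # per-output-channel scales, shape (out_features,)

    @classmethod
    def from_float(cls, w_fp):
        # Per-output-channel symmetric scales; the clamp avoids divide-by-zero rows.
        scales = (w_fp.abs().amax(dim=1) / 127.0).clamp(min=1e-8)
        int_data = torch.clamp(torch.round(w_fp / scales.unsqueeze(1)), -128, 127).to(torch.int8)
        return cls(int_data, scales)

    def quantized_op(self, act, bias=None):
        # Weight-only path: dequantize the weight and run an ordinary fp matmul.
        w_dq = self.int_data.to(act.dtype) * self.scales.unsqueeze(1).to(act.dtype)
        return F.linear(act, w_dq, bias)


class DynamicallyQuantizedLinearWeight(WeightOnlyQuantizedLinearWeight):
    """Reuses all of the parent's weight manipulation; only the op differs."""

    def quantized_op(self, act, bias=None):
        # Dynamic path: also quantize the activation at runtime, multiply in the
        # integer domain (emulated here in fp32), then rescale the accumulator.
        act_scales = (act.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
        act_int8 = torch.clamp(torch.round(act / act_scales), -128, 127).to(torch.int8)
        acc = torch.mm(act_int8.float(), self.int_data.float().t())
        out = (acc * act_scales * self.scales).to(act.dtype)
        return out + bias if bias is not None else out

However the inheritance is actually oriented in subclass.py, the point of the refactor is the same: from_float and the weight storage are shared, and the quantized op is the only override.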

4 files changed: +189 -123 lines changed


test/test.py

Lines changed: 58 additions & 74 deletions
@@ -21,6 +21,7 @@
     apply_dynamic_quant,
     apply_weight_only_int8_quant,
     change_linear_weights_to_dqtensors,
+    change_linear_weights_to_woqtensors,
     _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
@@ -42,6 +43,7 @@
 )
 from torchao.quantization.subclass import (
     DynamicallyQuantizedLinearWeight,
+    WeightOnlyQuantizedLinearWeight
 )
 from torchao.quantization.utils import (
     apply_logging_hook,
@@ -51,6 +53,10 @@
     LoggingTensorMode,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
+from torchao.quantization.weight_only import (
+    WeightOnlyInt8QuantLinear
+)
+

 torch.manual_seed(0)

@@ -782,84 +788,62 @@ def test_qlinear_per_channel_numerics_cuda(self):


 class TestSubclass(unittest.TestCase):
-    def test_dq_lin_weight_subclass_aot(self):
-        m, k, n = 32, 64, 32
-        x = torch.randn(m, k, device="cuda", dtype=torch.float32)
-        lin = torch.nn.Linear(k, n, device="cuda")
-
-        import copy
-
-        linq = DynamicallyPerAxisQuantizedLinear.from_float(copy.deepcopy(lin))
-
-        ref_f = lin(x)
-        ref_q = linq(x)
-
-        print(SQNR(ref_f, ref_q), "float to dq")
-
-        lin.weight = torch.nn.Parameter(
-            DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False
-        )
-        test = lin(x)
-        print(SQNR(ref_f, test), "float to dq class")
-        print(SQNR(ref_q, test), "dq to dq class")
-        assert SQNR(ref_f, test) > 35
-        assert SQNR(ref_q, test) > 35
-
-        lin_comp = torch.compile(lin, backend="aot_eager")
-        linq_comp = torch.compile(linq, backend="aot_eager")
-        test_comp = lin_comp(x)
-        ref_q_comp = linq_comp(x)
-        print(SQNR(ref_f, test_comp), "float to dq class compiled")
-        print(SQNR(ref_q_comp, test_comp), "dq compiled to dq class compiled")
-        assert SQNR(ref_f, test_comp) > 35
-        assert SQNR(ref_q_comp, test_comp) > 35
-
-    def test_dq_lin_weight_subclass_max_autotune(self):
-        m, k, n = 32, 64, 32
-        x = torch.randn(m, k, device="cuda", dtype=torch.float32)
-        lin = torch.nn.Linear(k, n, device="cuda")
-
-        import copy
-
-        linq = DynamicallyPerAxisQuantizedLinear.from_float(copy.deepcopy(lin))
-
-        ref_f = lin(x)
-        ref_q = linq(x)
+    def _test_lin_weight_subclass_impl(self,
+        test_subclass,
+        min_sqnr=35,
+        test_dtypes=[torch.float32, torch.float16, torch.bfloat16],
+        test_shape=[32, 64, 32]
+    ):
+        for test_dtype in test_dtypes:
+            m, k, n = test_shape
+            x = torch.randn(m, k, device="cuda", dtype=test_dtype)
+            lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype)
+            ref_f = lin(x)
+
+            lin.weight = torch.nn.Parameter(
+                test_subclass.from_float(lin.weight), requires_grad=False
+            )
+            test = lin(x)
+            self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{test_subclass.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}")
+            lin_comp = torch.compile(lin, mode='max-autotune')
+            test_comp = lin_comp(x)
+            self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{test_subclass.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}")

-        print(SQNR(ref_f, ref_q), "float to dq")
+    def test_int8_dynamic_quant_subclass(self):
+        self._test_lin_weight_subclass_impl(DynamicallyQuantizedLinearWeight, 35)

-        lin.weight = torch.nn.Parameter(
-            DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False
-        )
-        test = lin(x)
-        print(SQNR(ref_f, test), "float to dq class")
-        print(SQNR(ref_q, test), "dq to dq class")
-        assert SQNR(ref_f, test) > 35
-        assert SQNR(ref_q, test) > 35
-
-        lin_comp = torch.compile(lin, mode="max-autotune")
-        linq_comp = torch.compile(linq, mode="max-autotune")
-
-        test_comp = lin_comp(x)
-        ref_q_comp = linq_comp(x)
-        print(SQNR(ref_f, test_comp), "float to dq class compiled")
-        print(SQNR(ref_q_comp, test_comp), "dq compiled to dq class compiled")
-        assert SQNR(ref_f, test_comp) > 35
-        assert SQNR(ref_q_comp, test_comp) > 35
+    def test_int8_weight_only_quant_subclass(self):
+        self._test_lin_weight_subclass_impl(WeightOnlyQuantizedLinearWeight, 40)

     @torch.no_grad()
-    def test_dq_lin_weight_subclass_max_autotune_api(self):
-        m, k, n = 32, 64, 32
-        x = torch.randn(m, k, device="cuda", dtype=torch.float32)
-
-        mod = nn.Sequential(
-            nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda")
-        )
-        change_linear_weights_to_dqtensors(mod)
-        mod_qc = torch.compile(mod, mode="max-autotune")
-        mod_qc(x)
-        mod_qc(x)
-
+    def _test_lin_weight_subclass_api_impl(
+        self,
+        api,
+        min_sqnr=35,
+        test_dtypes=[torch.float32, torch.float16, torch.bfloat16],
+        test_shape=[32, 64, 32]
+    ):
+        for test_dtype in test_dtypes:
+            m, k, n = test_shape
+            x = torch.randn(m, k, device="cuda", dtype=test_dtype)
+            mod = nn.Sequential(
+                nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda")
+            ).to(test_dtype)
+            ref_f = mod(x)
+            api(mod)
+            test = mod(x)
+            self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}")
+
+            mod_qc = torch.compile(mod, mode="max-autotune")
+            test_comp = mod_qc(x)
+            self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}")
+
+
+    def test_int8_dynamic_quant_subclass_api(self):
+        self._test_lin_weight_subclass_api_impl(change_linear_weights_to_dqtensors, 35)
+
+    def test_int8_weight_only_quant_subclass_api(self):
+        self._test_lin_weight_subclass_api_impl(change_linear_weights_to_woqtensors, 40)

 class TestDynamicQuant(unittest.TestCase):
     def test_dynamic_quant(self):

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
     "apply_weight_only_int8_quant",
     "apply_dynamic_quant",
     "change_linear_weights_to_dqtensors",
+    "change_linear_weights_to_woqtensors",
     "insert_subclass",
     "safe_int_mm",
     "dynamically_quantize_per_tensor",
@@ -34,6 +35,7 @@
     "smooth_fq_linear_to_inference",
     "set_smooth_fq_attribute",
     "DynamicallyQuantizedLinearWeight",
+    "WeightOnlyQuantizedLinearWeight",
     "log_with_rank",
     "clear_logs",
     "compute_error",

torchao/quantization/quant_api.py

Lines changed: 26 additions & 6 deletions
@@ -21,6 +21,7 @@
 )
 from .subclass import (
     DynamicallyQuantizedLinearWeight,
+    WeightOnlyQuantizedLinearWeight,
 )
 from .weight_only import (
     WeightOnlyInt8QuantLinear,
@@ -30,6 +31,7 @@
     "apply_weight_only_int8_quant",
     "apply_dynamic_quant",
     "change_linear_weights_to_dqtensors",
+    "change_linear_weights_to_woqtensors",
 ]


@@ -74,17 +76,35 @@ def apply_dynamic_quant(model):
         lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
     )
+
+def _get_subclass_inserter(cls):
+    def insert_subclass(lin):
+        lin.weight = torch.nn.Parameter(
+            cls.from_float(lin.weight), requires_grad=False
+        )
+        return lin
+    return insert_subclass
+
 def change_linear_weights_to_dqtensors(model):
     """
     Converts all linear weight tensors to the `DynamicallyQuantizedLinearWeight`
     Tensor subclass, effectively applying the same form of quantization
     as apply_dynamic_quant while not modifying the linear modules.
     """
-    def insert_subclass(lin):
-        lin.weight = torch.nn.Parameter(
-            DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False
-        )
-        return lin
     _replace_with_custom_fn_if_matches_filter(
-        model, insert_subclass, lambda mod, fqn: isinstance(mod, torch.nn.Linear)
+        model,
+        _get_subclass_inserter(DynamicallyQuantizedLinearWeight),
+        lambda mod, fqn: isinstance(mod, torch.nn.Linear)
+    )
+
+def change_linear_weights_to_woqtensors(model):
+    """
+    Converts all linear weight tensors to the `WeightOnlyQuantizedLinearWeight`
+    Tensor subclass, effectively applying the same form of quantization
+    as apply_weight_only_int8_quant while not modifying the linear modules.
+    """
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(WeightOnlyQuantizedLinearWeight),
+        lambda mod, fqn: isinstance(mod, torch.nn.Linear)
     )
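
Usage of the new API mirrors the new test_int8_weight_only_quant_subclass_api test. A minimal sketch, assuming a CUDA device and an arbitrary toy model:

import torch
import torch.nn as nn
from torchao.quantization.quant_api import change_linear_weights_to_woqtensors

model = nn.Sequential(
    nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32)
).to("cuda")
x = torch.randn(8, 64, device="cuda")

# Swap every nn.Linear weight for the WeightOnlyQuantizedLinearWeight subclass in
# place; per the docstring, the linear modules themselves are not modified.
change_linear_weights_to_woqtensors(model)

# The new tests exercise the subclass weights both eagerly and under torch.compile.
out = model(x)
model_c = torch.compile(model, mode="max-autotune")
out_c = model_c(x)

Because the weights are swapped in place and the modules are left untouched, the call can be dropped into an existing model immediately before torch.compile, exactly like change_linear_weights_to_dqtensors for the dynamic case.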
