Commit 9f4ee3e

Add StretchedUnifTorchaoQuantizer (#2576)
* Add StretchedUnifTorchaoQuantizer
* Fix tinygemm test case
* Test equivalence to PARQ UnifQuantizer; custom choose_qparams, quantize, dequantize
* Remove dequantize_stretched_affine
1 parent a71c684 commit 9f4ee3e

File tree

4 files changed: +399 -70 lines changed

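For context, here is a minimal sketch of how the new quantizer and config are exercised, condensed from the tests added below. It assumes only the prototype APIs that appear in this diff; the `nn.Sequential` model is an illustrative stand-in for the test file's `M` module.

```python
import copy

import torch.nn as nn

from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer
from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

b, group_size = 2, 32
model = nn.Sequential(nn.Linear(512, 512))  # illustrative stand-in for the test's M

quantizer = StretchedUnifTorchaoQuantizer(b)
config = StretchedIntxWeightOnlyConfig(
    b=b,
    quant_min=quantizer.quant_min,
    quant_max=quantizer.quant_max,
    granularity=PerGroup(group_size),
)

# torchao reference path: quantize a copy of the model in place
m_ref = copy.deepcopy(model).eval()
quantize_(m_ref, config)

# PARQ path: group the weight rows and quantize along the last dim
p = model[0].weight.view(-1, group_size)
q, Q = quantizer.quantize(p, b=b, dim=-1)  # q: quantized weights, Q: quant grid
```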

test/prototype/test_parq.py

Lines changed: 137 additions & 58 deletions
@@ -21,10 +21,12 @@
 from torchao.prototype.parq.quant import (
     Int4UnifTorchaoQuantizer,
     LSBQuantizer,
+    StretchedUnifTorchaoQuantizer,
     TernaryUnifQuantizer,
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
+from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.qat import (
@@ -35,11 +37,11 @@
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
-    MappingType,
     _is_linear,
     int4_weight_only,
     quantize_,
 )
+from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_6,
@@ -74,6 +76,59 @@ def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
     ]
 
 
+def compare_quantized_models(
+    model: nn.Module,
+    m_ref: nn.Module,
+    quantizer: UnifTorchaoQuantizer,
+    b: int,
+    group_size: int,
+):
+    for n, module in model.named_children():
+        if not _is_linear(module):
+            continue
+
+        # simulate grouping from QuantOptimizer.step
+        p = module.weight
+        original_shape = p.shape
+        p = p.view(-1, group_size)
+
+        q, Q = quantizer.quantize(p, b=b, dim=-1)
+
+        # compare to AffineQuantizedTensor instance
+        q = q.view(original_shape)
+        ref = getattr(m_ref, n).weight.dequantize()
+        torch.testing.assert_close(q, ref, atol=0, rtol=0)
+
+
+def compare_parq_convert(
+    model: nn.Module,
+    m_ref: nn.Module,
+    optimizer: QuantOptimizer,
+    config: AOBaseConfig,
+):
+    # do not update model weights, just quantize
+    optimizer.zero_grad()
+    optimizer.step()
+
+    orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
+
+    # equivalent to torchao's convert step
+    model.eval()
+    optimizer.restore_latent_params()
+    quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))
+
+    for n, module in model.named_modules():
+        if not _is_linear(module):
+            continue
+
+        p_orig = getattr(orig_model, n).weight  # PARQ weight
+        p = module.weight.dequantize()  # PARQ weight after quantize_
+        p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_
+
+        torch.testing.assert_true(p_orig, p_ref, atol=0, rtol=0)
+        torch.testing.assert_true(p, p_ref, atol=0, rtol=0)
+
+
 class M(nn.Module):
     def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
         super().__init__()
@@ -143,59 +198,6 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
 
-    def compare_quantized_models(
-        self,
-        model: nn.Module,
-        m_ref: nn.Module,
-        quantizer: UnifTorchaoQuantizer,
-        b: int,
-        group_size: int,
-    ):
-        for n, module in model.named_children():
-            if not _is_linear(module):
-                continue
-
-            # simulate grouping from QuantOptimizer.step
-            p = module.weight
-            original_shape = p.shape
-            p = p.view(-1, group_size)
-
-            q, Q = quantizer.quantize(p, b=b, dim=-1)
-
-            # compare to AffineQuantizedTensor instance
-            q = q.view(original_shape)
-            ref = getattr(m_ref, n).weight.dequantize()
-            torch.testing.assert_close(q, ref, atol=0, rtol=0)
-
-    def compare_parq_convert(
-        self,
-        model: nn.Module,
-        m_ref: nn.Module,
-        optimizer: QuantOptimizer,
-        config: AOBaseConfig,
-    ):
-        # do not update model weights, just quantize
-        optimizer.zero_grad()
-        optimizer.step()
-
-        orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
-
-        # equivalent to torchao's convert step
-        model.eval()
-        optimizer.restore_latent_params()
-        quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))
-
-        for n, module in model.named_modules():
-            if not _is_linear(module):
-                continue
-
-            p_orig = getattr(orig_model, n).weight  # PARQ weight
-            p = module.weight.dequantize()  # PARQ weight after quantize_
-            p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_
-
-            torch.testing.assert_true(p_orig, p_ref, atol=0, rtol=0)
-            torch.testing.assert_true(p, p_ref, atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
@@ -209,7 +211,7 @@ def test_int4_weight_only(self, group_size: int = 32):
         quantize_(m_ref, config)
 
         b = 4
-        self.compare_quantized_models(
+        compare_quantized_models(
             model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size
         )
 
@@ -229,7 +231,7 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         )
 
         quantizer = UnifTorchaoQuantizer()
-        self.compare_quantized_models(model, m_ref, quantizer, b, group_size)
+        compare_quantized_models(model, m_ref, quantizer, b, group_size)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@@ -251,7 +253,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        self.compare_parq_convert(model, m_ref, optimizer, config)
+        compare_parq_convert(model, m_ref, optimizer, config)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@@ -273,7 +275,84 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        self.compare_parq_convert(model, m_ref, optimizer, config)
+        compare_parq_convert(model, m_ref, optimizer, config)
+
+
+class TestStretchedUnifTorchaoQuantizer(common_utils.TestCase):
+    def setUp(self):
+        torch.manual_seed(123)
+
+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize("group_size", [32, 256])
+    def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
+
+        quantizer_ref = UnifQuantizer()
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+
+        for n, module in model.named_children():
+            if not _is_linear(module):
+                continue
+
+            # simulate grouping from QuantOptimizer.step
+            p = module.weight
+            p = p.view(-1, group_size)
+
+            q_ref, Q_ref = quantizer_ref.quantize(p, b=b, dim=-1)
+            q, Q = quantizer.quantize(p, b=b, dim=-1)
+
+            torch.testing.assert_close(q, q_ref, atol=0, rtol=0)
+            torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize("group_size", [32, 512])
+    def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        quantize_(
+            m_ref,
+            StretchedIntxWeightOnlyConfig(
+                b=b,
+                quant_min=quantizer.quant_min,
+                quant_max=quantizer.quant_max,
+                granularity=PerGroup(group_size),
+            ),
+        )
+
+        compare_quantized_models(model, m_ref, quantizer, b, group_size)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    @common_utils.parametrize("b", [2, 3])
+    def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        config = StretchedIntxWeightOnlyConfig(
+            b=b,
+            quant_min=quantizer.quant_min,
+            quant_max=quantizer.quant_max,
+            granularity=PerGroup(group_size),
+        )
+        quantize_(m_ref, config)
+
+        base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optimizer = QuantOptimizer(
+            base_optimizer,
+            quantizer,
+            ProxHardQuant(),
+            quant_per_channel=True,
+        )
+        compare_parq_convert(model, m_ref, optimizer, config)
 
 
 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
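The end-to-end tests above follow PARQ's prepare/convert pattern: an optimizer step projects weights onto the quantization grid, and torchao's `quantize_` performs the final conversion. Below is a condensed sketch of that flow. The `torchao.prototype.parq.optim` import path and the param-group keys (`quant_bits`, `quant_block_size`) are assumptions standing in for the test's `build_param_groups` helper; everything else mirrors calls shown in the diff.

```python
import torch
import torch.nn as nn

# assumed import path, following the PARQ prototype layout
from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer
from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

b, group_size = 2, 32
model = nn.Sequential(nn.Linear(512, 512))  # illustrative stand-in for M

quantizer = StretchedUnifTorchaoQuantizer(b)

# hypothetical param groups; the exact keys are assumptions mirroring
# what the test's build_param_groups helper would produce
param_groups = [
    {"params": [model[0].weight], "quant_bits": b, "quant_block_size": group_size}
]
optimizer = QuantOptimizer(
    torch.optim.AdamW(param_groups),
    quantizer,
    ProxHardQuant(),
    quant_per_channel=True,
)

# "prepare": an optimizer step projects latent weights onto the quant grid
optimizer.zero_grad()
optimizer.step()

# "convert": restore latent weights, then hand off to torchao's quantize_
model.eval()
optimizer.restore_latent_params()
config = StretchedIntxWeightOnlyConfig(
    b=b,
    quant_min=quantizer.quant_min,
    quant_max=quantizer.quant_max,
    granularity=PerGroup(group_size),
)
quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))
```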

torchao/prototype/parq/quant/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,5 +13,6 @@
 )
 from .uniform_torchao import (  # noqa: F401
     Int4UnifTorchaoQuantizer,
+    StretchedUnifTorchaoQuantizer,
     UnifTorchaoQuantizer,
 )
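With this re-export, the new quantizer is importable from the subpackage root rather than the `uniform_torchao` module, e.g.:

```python
from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer

quantizer = StretchedUnifTorchaoQuantizer(2)  # b = 2 bits, as in the tests above
print(quantizer.quant_min, quantizer.quant_max)
```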
