 from torchao.prototype.parq.quant import (
     Int4UnifTorchaoQuantizer,
     LSBQuantizer,
+    StretchedUnifTorchaoQuantizer,
     TernaryUnifQuantizer,
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
+from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.qat import (
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
-    MappingType,
     _is_linear,
     int4_weight_only,
     quantize_,
 )
+from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_6,
@@ -74,6 +76,59 @@ def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
     ]


+def compare_quantized_models(
+    model: nn.Module,
+    m_ref: nn.Module,
+    quantizer: UnifTorchaoQuantizer,
+    b: int,
+    group_size: int,
+):
+    for n, module in model.named_children():
+        if not _is_linear(module):
+            continue
+
+        # simulate grouping from QuantOptimizer.step
+        p = module.weight
+        original_shape = p.shape
+        p = p.view(-1, group_size)
+
+        q, Q = quantizer.quantize(p, b=b, dim=-1)
+
+        # compare to AffineQuantizedTensor instance
+        q = q.view(original_shape)
+        ref = getattr(m_ref, n).weight.dequantize()
+        torch.testing.assert_close(q, ref, atol=0, rtol=0)
+
+
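As a quick orientation, here is a minimal sketch of the quantize contract this helper exercises, assuming the PARQ prototype API imported above; the shape and bit width are illustrative, not taken from the tests.

    import torch
    from torchao.prototype.parq.quant import UnifTorchaoQuantizer

    p = torch.randn(8, 32)  # each row acts as one quantization group
    q, Q = UnifTorchaoQuantizer().quantize(p, b=4, dim=-1)
    # q has p's shape with entries snapped to the quantization grid;
    # Q holds the grid values ("codebook") produced alongside q
    print(q.shape, Q.shape)
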
+def compare_parq_convert(
+    model: nn.Module,
+    m_ref: nn.Module,
+    optimizer: QuantOptimizer,
+    config: AOBaseConfig,
+):
+    # do not update model weights, just quantize
+    optimizer.zero_grad()
+    optimizer.step()
+
+    orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
+
+    # equivalent to torchao's convert step
+    model.eval()
+    optimizer.restore_latent_params()
+    quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))
+
+    for n, module in model.named_modules():
+        if not _is_linear(module):
+            continue
+
+        p_orig = getattr(orig_model, n).weight  # PARQ weight
+        p = module.weight.dequantize()  # PARQ weight after quantize_
+        p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_
+
+        torch.testing.assert_close(p_orig, p_ref, atol=0, rtol=0)
+        torch.testing.assert_close(p, p_ref, atol=0, rtol=0)
+
+
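For orientation, the train-then-convert flow that compare_parq_convert checks, condensed into one sketch. It assumes this file's context (M, build_param_groups, and the imports above); the IntxWeightOnlyConfig kwargs are an assumption modeled on the intx tests below, not a definitive recipe.

    model = M(m=512, n=512)
    base_optimizer = torch.optim.AdamW(build_param_groups(model, b=2, group_size=32))
    optimizer = QuantOptimizer(
        base_optimizer,
        UnifTorchaoQuantizer(),
        ProxHardQuant(),
        quant_per_channel=True,
    )
    optimizer.zero_grad()
    optimizer.step()  # prox step: latent weights land on the quantization grid

    # convert: restore latent weights, then let torchao quantize in place
    model.eval()
    optimizer.restore_latent_params()
    config = IntxWeightOnlyConfig(
        weight_dtype=_BIT_WIDTH_TO_DTYPE[2],  # assumed kwargs, per the intx tests
        granularity=PerGroup(32),
    )
    quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))
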
 class M(nn.Module):
     def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
         super().__init__()
@@ -143,59 +198,6 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)

-    def compare_quantized_models(
-        self,
-        model: nn.Module,
-        m_ref: nn.Module,
-        quantizer: UnifTorchaoQuantizer,
-        b: int,
-        group_size: int,
-    ):
-        for n, module in model.named_children():
-            if not _is_linear(module):
-                continue
-
-            # simulate grouping from QuantOptimizer.step
-            p = module.weight
-            original_shape = p.shape
-            p = p.view(-1, group_size)
-
-            q, Q = quantizer.quantize(p, b=b, dim=-1)
-
-            # compare to AffineQuantizedTensor instance
-            q = q.view(original_shape)
-            ref = getattr(m_ref, n).weight.dequantize()
-            torch.testing.assert_close(q, ref, atol=0, rtol=0)
-
-    def compare_parq_convert(
-        self,
-        model: nn.Module,
-        m_ref: nn.Module,
-        optimizer: QuantOptimizer,
-        config: AOBaseConfig,
-    ):
-        # do not update model weights, just quantize
-        optimizer.zero_grad()
-        optimizer.step()
-
-        orig_model = copy.deepcopy(model)  # save copy of PARQ quantized model
-
-        # equivalent to torchao's convert step
-        model.eval()
-        optimizer.restore_latent_params()
-        quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))
-
-        for n, module in model.named_modules():
-            if not _is_linear(module):
-                continue
-
-            p_orig = getattr(orig_model, n).weight  # PARQ weight
-            p = module.weight.dequantize()  # PARQ weight after quantize_
-            p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_
-
-            torch.testing.assert_true(p_orig, p_ref, atol=0, rtol=0)
-            torch.testing.assert_true(p, p_ref, atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
@@ -209,7 +211,7 @@ def test_int4_weight_only(self, group_size: int = 32):
         quantize_(m_ref, config)

         b = 4
-        self.compare_quantized_models(
+        compare_quantized_models(
             model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size
         )

@@ -229,7 +231,7 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         )

         quantizer = UnifTorchaoQuantizer()
-        self.compare_quantized_models(model, m_ref, quantizer, b, group_size)
+        compare_quantized_models(model, m_ref, quantizer, b, group_size)

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@@ -251,7 +253,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        self.compare_parq_convert(model, m_ref, optimizer, config)
+        compare_parq_convert(model, m_ref, optimizer, config)

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@@ -273,7 +275,84 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        self.compare_parq_convert(model, m_ref, optimizer, config)
+        compare_parq_convert(model, m_ref, optimizer, config)
+
+
+class TestStretchedUnifTorchaoQuantizer(common_utils.TestCase):
+    def setUp(self):
+        torch.manual_seed(123)
+
+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize("group_size", [32, 256])
+    def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
+
+        quantizer_ref = UnifQuantizer()
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+
+        for n, module in model.named_children():
+            if not _is_linear(module):
+                continue
+
+            # simulate grouping from QuantOptimizer.step
+            p = module.weight
+            p = p.view(-1, group_size)
+
+            q_ref, Q_ref = quantizer_ref.quantize(p, b=b, dim=-1)
+            q, Q = quantizer.quantize(p, b=b, dim=-1)
+
+            torch.testing.assert_close(q, q_ref, atol=0, rtol=0)
+            torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize("group_size", [32, 512])
+    def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        quantize_(
+            m_ref,
+            StretchedIntxWeightOnlyConfig(
+                b=b,
+                quant_min=quantizer.quant_min,
+                quant_max=quantizer.quant_max,
+                granularity=PerGroup(group_size),
+            ),
+        )
+
+        compare_quantized_models(model, m_ref, quantizer, b, group_size)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    @common_utils.parametrize("b", [2, 3])
+    def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
+        model = M(m=512, n=512).to(_DEVICE)
+        model.reset_parameters()
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        config = StretchedIntxWeightOnlyConfig(
+            b=b,
+            quant_min=quantizer.quant_min,
+            quant_max=quantizer.quant_max,
+            granularity=PerGroup(group_size),
+        )
+        quantize_(m_ref, config)
+
+        base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
+        optimizer = QuantOptimizer(
+            base_optimizer,
+            quantizer,
+            ProxHardQuant(),
+            quant_per_channel=True,
+        )
+        compare_parq_convert(model, m_ref, optimizer, config)
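Taken together, the stretched convert path these tests cover reduces to the following sketch, condensed from the test bodies above; the bit width and group size are illustrative, and model stands in for any module with linear layers, such as M.

    model = M(m=512, n=512)
    b = 2
    quantizer = StretchedUnifTorchaoQuantizer(b)
    config = StretchedIntxWeightOnlyConfig(
        b=b,
        quant_min=quantizer.quant_min,  # integer range implied by the stretched grid
        quant_max=quantizer.quant_max,
        granularity=PerGroup(32),
    )
    quantize_(model, config)
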


 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):