@@ -4,10 +4,11 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import tempfile
+import unittest
 from copy import deepcopy

-import pytest
 import torch
+from torch.testing._internal import common_utils

 from torchao.prototype.smoothquant import (
     SmoothQuantConfig,
@@ -25,9 +26,6 @@
     TORCH_VERSION_AT_LEAST_2_5,
 )

-if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
-

 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
@@ -53,143 +51,224 @@ def forward(self, x):
         return x


-bias_list = [True, False]
-alpha_list = [None, 0.5, 0.75]
-quant_mode_list = ["static", "dynamic"]
-devices = ["cpu"]
-if torch.cuda.is_available():
-    devices.append("cuda")
-idtypes = (torch.float, torch.bfloat16, torch.half)
-
-if TORCH_VERSION_AT_LEAST_2_5:
-    # This test case will trigger recompilation many times, so set a large cache_size_limit here
-    torch._dynamo.config.cache_size_limit = 128
-
-
-@pytest.mark.parametrize("bias", bias_list)
-@pytest.mark.parametrize("alpha", alpha_list)
-@pytest.mark.parametrize("quant_mode", quant_mode_list)
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("idtype", idtypes)
-@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it")
-def test_compute(bias, alpha, quant_mode, device, idtype):
-    class Linear(torch.nn.Module):
-        def __init__(self, bias: bool):
-            super().__init__()
-            self.fc = torch.nn.Linear(32, 32, bias)
-            self.fc.weight.data = torch.randn_like(self.fc.weight.data)
-
-        def forward(self, x):
-            return self.fc(x)
-
-    m = Linear(bias).eval().to(idtype).to(device)
-    m_ref = deepcopy(m)
-    data = torch.randn(2, 32, dtype=idtype, device=device)
-
-    # calibrate
-    insert_smooth_quant_observer_(m, alpha, quant_mode)
-    m(data)
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, SmoothQuantConfig(), is_observed_linear)
-    with torch.inference_mode():
+@unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
+class TestSmoothQuant(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """Set up class-level configuration for tests."""
+        if TORCH_VERSION_AT_LEAST_2_5:
+            # This test triggers recompilation many times, so set a large cache_size_limit here
+            torch._dynamo.config.cache_size_limit = 128
+
+    @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("alpha", [None, 0.5, 0.75])
+    @common_utils.parametrize("quant_mode", ["static", "dynamic"])
+    @common_utils.parametrize(
+        "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    )
+    @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
+    def test_smoothquant_accuracy(self, bias, alpha, quant_mode, device, input_dtype):
+        """Test SmoothQuant accuracy across bias, alpha, quant mode, device, and dtype."""
+
+        class SimpleLinear(torch.nn.Module):
+            def __init__(self, bias: bool):
+                super().__init__()
+                self.fc = torch.nn.Linear(32, 32, bias)
+                self.fc.weight.data = torch.randn_like(self.fc.weight.data)
+
+            def forward(self, x):
+                return self.fc(x)
+
+        # Create the model, a reference copy, and test data
+        m = SimpleLinear(bias).eval().to(input_dtype).to(device)
+        m_ref = deepcopy(m)
+        test_data = torch.randn(2, 32, dtype=input_dtype, device=device)
+
+        # Step 1: Set up the quantized model: insert observers, then calibrate
+        insert_smooth_quant_observer_(m, alpha, quant_mode)
+
+        # Perform calibration with the test data
+        m(test_data)
+
+        # Apply the quantization configuration
+        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+        quantize_(m, SmoothQuantConfig(), is_observed_linear)
+
+        # Apply compilation if supported
         if TORCH_VERSION_AT_LEAST_2_5:
             m = torch.compile(m, fullgraph=True)
-        out = m(data)
-
-    # reference
-    weight = m_ref.fc.weight.data.float()
-    b = m_ref.fc.bias if bias else None
-    x_abs_max_per_ic = torch.abs(data).max(dim=0).values
-    w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
-    smoothing_factor = (
-        1
-        if alpha is None
-        else (
-            torch.pow(x_abs_max_per_ic, alpha)
-            / torch.pow(w_abs_max_per_ic, 1 - alpha)
+
+        # Step 2: Run inference with the quantized model
+        with torch.inference_mode():
+            q_out = m(test_data)
+
+        # Step 3: Compute the reference result
+        weight = m_ref.fc.weight.data.float()
+        b = m_ref.fc.bias if bias else None
+        x_abs_max_per_ic = torch.abs(test_data).max(dim=0).values
+        w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
+
+        if alpha is not None:
+            # Apply the SmoothQuant smoothing factor
+            smoothing_factor = torch.pow(x_abs_max_per_ic, alpha) / torch.pow(
+                w_abs_max_per_ic, 1 - alpha
+            )
+        else:
+            smoothing_factor = torch.ones_like(x_abs_max_per_ic)
+
+        # Apply smoothing to activations and weights
+        smoothed_activation = test_data / smoothing_factor
+        smoothed_weight = weight * smoothing_factor
+
+        # Quantize weights using per-channel quantization
+        qw, w_scales, w_zps = dynamically_quantize_per_channel(
+            smoothed_weight, -127, 127, torch.int8
127
)
112
-    )
-    act = data / smoothing_factor
-    wei = weight * smoothing_factor
-    qw, w_scales, w_zps = dynamically_quantize_per_channel(
-        wei, -127, 127, torch.int8
-    )
-    fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype)
-    if quant_mode == "static":
-        # activation is quantized per-tensor
-        act_min, act_max = torch.aminmax(act.float())
-        max_val_pos = torch.max(-act_min, act_max)
-        act_scale = max_val_pos / 127.0
-        fq_act = (
-            torch.quantize_per_tensor(
-                act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8
+        fq_wei = dequantize_per_channel(qw, w_scales, w_zps, input_dtype)
+
+        # Handle activation quantization based on mode
+        if quant_mode == "static":
+            # activation is quantized per-tensor
+            act_min, act_max = torch.aminmax(smoothed_activation.float())
+            max_val_pos = torch.max(-act_min, act_max)
+            activation_scale = max_val_pos / 127.0
+
+            fq_act = (
+                torch.quantize_per_tensor(
+                    smoothed_activation.float(),
+                    scale=activation_scale.item(),
+                    zero_point=0,
+                    dtype=torch.qint8,
+                )
+                .dequantize()
+                .to(input_dtype)
+            )
+        else:
+            # activation is quantized per-row (batch * sequence_length)
+            qx, x_scales, x_zps = dynamically_quantize_per_channel(
+                smoothed_activation.float(), -127, 127, torch.int8
+            )
+            fq_act = dequantize_per_channel(
+                qx,
+                x_scales,
+                x_zps,
+                input_dtype,
             )
-            .dequantize()
-            .to(idtype)
+
+        # Compute the final linear operation
+        reference_out = torch.nn.functional.linear(fq_act, fq_wei, b)
+
+        # Step 4: Validate numerical accuracy (bfloat16/half need larger tolerances)
+        tolerance = (
+            0.1
+            if input_dtype == torch.float
+            else (0.2 if input_dtype == torch.half else 0.3)
         )
-        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
-    else:
-        # activation is quantized per-row (batch * sequence_length)
-        qx, x_scales, x_zps = dynamically_quantize_per_channel(
-            act.float(), -127, 127, torch.int8
+        torch.testing.assert_close(
+            q_out,
+            reference_out.to(input_dtype),
+            rtol=1e-5,  # assert_close requires rtol and atol together; matches the torch.allclose default
+            atol=tolerance,
+            msg=f"Quantized output differs from reference for "
+            f"bias={bias}, alpha={alpha}, quant_mode={quant_mode}, "
+            f"device={device}, dtype={input_dtype}",
         )
-        fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype)
-        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
-
-    # BFloat16 and Float16 have larger errors
-    atol = 0.1 if idtype == torch.float else (0.2 if idtype == torch.half else 0.3)
-    assert torch.allclose(out, out_ref.to(idtype), atol=atol)
-
-
-@pytest.mark.parametrize("alpha", alpha_list)
-@pytest.mark.parametrize("quant_mode", quant_mode_list)
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("idtype", idtypes)
-@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it")
-def test_save_load_recipe(alpha, quant_mode, device, idtype):
-    dataset_size = 20
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = idtype
-    n_calib_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    m_save_load = deepcopy(m)
-
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
+
+    @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
+    @common_utils.parametrize("alpha", [None, 0.5, 0.75])
+    @common_utils.parametrize("quant_mode", ["static", "dynamic"])
+    @common_utils.parametrize(
+        "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     )
-    calibration_data = dataset[:n_calib_examples]
-
-    # calibrate
-    insert_smooth_quant_observer_(m, alpha, quant_mode)
-    insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
-
-    for example in calibration_data:
-        m(example.to(device))
-        m_save_load(example.to(device))
-
-    with tempfile.NamedTemporaryFile() as fp:
-        save_path = fp.name
-        save_smooth_quant_recipe(m_save_load, save_path)
-        load_smooth_quant_recipe(m_save_load, save_path)
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, SmoothQuantConfig(), is_observed_linear)
-    if TORCH_VERSION_AT_LEAST_2_5:
-        # earlier versions are not compatible
-        m = torch.compile(m, fullgraph=True)
-        m_save_load = torch.compile(m_save_load, fullgraph=True)
-    out_list = [m(data.squeeze(0)) for data in dataset]
-    out = torch.cat(out_list)
-    save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset]
-    save_load_out = torch.cat(save_load_out_list)
-
-    assert out is not None
-    assert save_load_out is not None
-    assert torch.allclose(out, save_load_out)
+    @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
+    def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype):
+        """Test that a saved and reloaded recipe reproduces the original outputs."""
+        dataset_size = 20
+        layer_dims = (512, 256, 128)  # Input, hidden, and output dimensions
+        n_calib_examples = 10
+        sequence_length = 5
+
+        # Create two identical models for comparison
+        m = ToyLinearModel(*layer_dims).eval().to(input_dtype).to(device)
+        m_save_load = deepcopy(m)
+
+        # Generate the calibration dataset
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=input_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calib_examples]
+
+        # Step 1: Set up the first quantized model: insert observers and calibrate
+        insert_smooth_quant_observer_(m, alpha, quant_mode)
+
+        # Perform calibration with the calibration data
+        for data in calibration_data:
+            m(data)
+
+        # Apply the quantization configuration
+        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+        quantize_(m, SmoothQuantConfig(), is_observed_linear)
+
+        # Apply compilation if supported
+        if TORCH_VERSION_AT_LEAST_2_5:
+            m = torch.compile(m, fullgraph=True)
+
+        # Step 2: Set up the save/load model with observers and calibration
+        insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
+        for example in calibration_data:
+            m_save_load(example.to(device))
+
+        # Step 3: Round-trip the recipe through save and load
+        with tempfile.NamedTemporaryFile() as temp_file:
+            save_path = temp_file.name
+            save_smooth_quant_recipe(m_save_load, save_path)
+            load_smooth_quant_recipe(m_save_load, save_path)
+
+        # Step 4: Complete quantization for the save/load model
+        quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear)
+
+        if TORCH_VERSION_AT_LEAST_2_5:
+            m_save_load = torch.compile(m_save_load, fullgraph=True)
+
+        # Step 5: Validate outputs on the full dataset
+        with torch.inference_mode():
+            original_outputs = []
+            save_load_outputs = []
+
+            for data in dataset:
+                # Remove the batch dimension for model input
+                input_tensor = data.squeeze(0)
+
+                original_output = m(input_tensor)
+                save_load_output = m_save_load(input_tensor)
+
+                original_outputs.append(original_output)
+                save_load_outputs.append(save_load_output)
+
+            # Concatenate all outputs for comparison
+            original_result = torch.cat(original_outputs)
+            save_load_result = torch.cat(save_load_outputs)
+
+            self.assertIsNotNone(
+                original_result, "Original model output should not be None"
+            )
+            self.assertIsNotNone(
+                save_load_result, "Save/load model output should not be None"
+            )
+
+            torch.testing.assert_close(
+                original_result,
+                save_load_result,
+                msg=f"Save/load recipe should produce identical results for "
+                f"alpha={alpha}, quant_mode={quant_mode}, device={device}, dtype={input_dtype}",
+            )
+
+
+common_utils.instantiate_parametrized_tests(TestSmoothQuant)
+
+if __name__ == "__main__":
+    unittest.main()
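
A note on the reference computation in `test_smoothquant_accuracy`: SmoothQuant's smoothing is an exact float-level reparameterization, which is what allows the test to compare the quantized model against a smoothed-then-fake-quantized reference. A minimal standalone sketch of that identity (illustrative only, not part of this diff):

```python
# Sketch: with a per-input-channel smoothing factor
# s = max|X|^alpha / max|W|^(1 - alpha), dividing activations by s and
# multiplying weights by s leaves the float linear output unchanged;
# only the quantization difficulty moves from activations to weights.
import torch

alpha = 0.5
x = torch.randn(2, 32)   # activations: [batch, in_features]
w = torch.randn(32, 32)  # weights: [out_features, in_features]

s = torch.abs(x).max(dim=0).values.pow(alpha) / torch.abs(w).max(dim=0).values.pow(
    1 - alpha
)

y_plain = torch.nn.functional.linear(x, w)
y_smoothed = torch.nn.functional.linear(x / s, w * s)
torch.testing.assert_close(y_plain, y_smoothed, rtol=1e-4, atol=1e-4)
```

This is why the test quantizes `smoothed_weight` per channel and `smoothed_activation` per tensor (static) or per row (dynamic): dividing by `s` flattens activation outliers, so the error introduced by int8 quantization shrinks on the activation side.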
0 commit comments
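The migration pattern itself, in miniature: `common_utils.parametrize` stacks like `pytest.mark.parametrize`, and `common_utils.instantiate_parametrized_tests` expands each decorated method into one concrete test per parameter combination. A minimal sketch mirroring the usage in this diff (class and test names are hypothetical):

```python
import unittest

from torch.testing._internal import common_utils


class ExampleTest(unittest.TestCase):
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("mode", ["static", "dynamic"])
    def test_example(self, bias, mode):
        # One generated test method runs per (bias, mode) combination.
        self.assertIn(mode, ("static", "dynamic"))


# Replaces the parametrized method with concrete test_* methods on the class.
common_utils.instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    unittest.main()
```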