#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
- import os
- from copy import deepcopy
+ import copy
+ import tempfile
+ import unittest

- import pytest
import torch
+ from torch.testing._internal.common_utils import (
+     TestCase,
+     run_tests,
+ )

- from torchao.quantization import quantize_
- from torchao.testing.utils import skip_if_rocm
+ from torchao.prototype.awq import AWQConfig, AWQStep
+ from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_

from torchao.utils import (
-     TORCH_VERSION_AT_LEAST_2_3,
-     TORCH_VERSION_AT_LEAST_2_5,
+     TORCH_VERSION_AT_LEAST_2_6,
+     _is_fbgemm_genai_gpu_available,
)

- if TORCH_VERSION_AT_LEAST_2_3:
-     from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
-

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=512, n=256, k=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)
-         self.linear3 = torch.nn.Linear(k, 1, bias=False)
+         self.linear3 = torch.nn.Linear(k, 64, bias=False)

    def example_inputs(
        self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
@@ -44,137 +45,197 @@ def forward(self, x):
        return x


- devices = ["cpu", "cuda"]
- # torch.uintx dtypes are introduced in 2.3
- if TORCH_VERSION_AT_LEAST_2_3:
-     qdtypes = (torch.uint4, torch.uint7)
- else:
-     qdtypes = ()
-
-
- @pytest.fixture(autouse=True)
- def run_before_and_after_tests():
-     yield
-     torch._dynamo.reset()  # reset cache between tests
-
-
- @pytest.mark.parametrize("device", devices)
- @pytest.mark.parametrize("qdtype", qdtypes)
- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
- @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
- @pytest.mark.skip("Temporarily skipping to unpin nightiles")
- def test_awq_loading(device, qdtype):
-     if qdtype == torch.uint4 and device == "cpu":
-         pytest.skip("uint4 not supported on cpu")
-
-     dataset_size = 100
-     l1, l2, l3 = 512, 256, 128
-     original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-     quant_dtype = qdtype
-     group_size = 128
-     n_calibration_examples = 10
-     n_validation_examples = 10
-     sequence_length = 5
-
-     m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-     dataset = m.example_inputs(
-         dataset_size,
-         sequence_length=sequence_length,
-         dtype=original_dtype,
-         device=device,
-     )
-     calibration_data = dataset[:n_calibration_examples]
-
-     # calibrate
-     insert_awq_observer_(
-         m,
-         n_validation_examples,
-         sequence_length,
-         quant_dtype=quant_dtype,
-         group_size=group_size,
-     )
-
-     for example in calibration_data:
-         m(example.to(device))
-
-     # quantize
-     is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-     quantize_(
-         m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-     )
-
-     model_save_path = "awq_model.pth"
-     torch.save(m, model_save_path)
-     loaded_model = torch.load(model_save_path)
-     os.remove(model_save_path)
-
-     if torch.cuda.is_available():
+ @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+ @unittest.skipIf(
+     not _is_fbgemm_genai_gpu_available(),
+     reason="need to install fbgemm_gpu_genai package",
+ )
+ @unittest.skipIf(
+     not TORCH_VERSION_AT_LEAST_2_6,
+     reason="torch.int4 needs torch 2.6+, can remove after we are not using FbgemmConfig",
+ )
+ class TestAWQ(TestCase):
+     def test_awq_config(self):
+         base_config = Int4WeightOnlyConfig()
+         AWQConfig(base_config, step=AWQStep.PREPARE)
+         AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+         AWQConfig(base_config, step=AWQStep.CONVERT)
+
+         AWQConfig(base_config, step="prepare")
+         AWQConfig(base_config, step="prepare_for_loading")
+         AWQConfig(base_config, step="convert")
+
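+         # an unrecognized step name should be rejected with a ValueError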
+         with self.assertRaisesRegex(ValueError, "is not one of"):
+             AWQConfig(base_config, step="not_supported")
+
+     def test_awq_functionality(self):
+         device = "cuda"
+         dataset_size = 100
+         l1, l2, l3 = 512, 256, 128
+         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+         group_size = 128
+         n_calibration_examples = 10
+         sequence_length = 5
+
+         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+
+         # baseline quantization
+         base_config = FbgemmConfig(
+             input_dtype=torch.bfloat16,
+             weight_dtype=torch.int4,
+             output_dtype=torch.bfloat16,
+             block_size=[1, group_size],
+             preshuffle=False,
+         )
+         m_baseline = copy.deepcopy(m)
+         quantize_(m_baseline, base_config)
+
+         # awq quantization
+         dataset = m.example_inputs(
+             dataset_size,
+             sequence_length=sequence_length,
+             dtype=original_dtype,
+             device=device,
+         )
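+         # record float reference outputs now, before m is quantized in place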
+         ref_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+
+         calibration_data = dataset[:n_calibration_examples]
+
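+         # prepare: swap in observed linears that record activation statistics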
+         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+         quantize_(m, quant_config)
+
+         for example in calibration_data:
+             m(example)
+
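+         # convert: use the collected statistics to compute AWQ scales and quantize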
+         quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+         quantize_(m, quant_config)
+
+         awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+         baseline_out = torch.cat([m_baseline(d.squeeze(0)) for d in dataset])
+
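+         # AWQ should track the float reference more closely than plain int4 (lower MSE)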
+         loss_awq = (ref_out - awq_out).pow(2).mean().item()
+         loss_base = (ref_out - baseline_out).pow(2).mean().item()
+         assert loss_awq < loss_base
+
+     def test_awq_loading(self):
+         device = "cuda"
+         dataset_size = 100
+         l1, l2, l3 = 512, 256, 128
+         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+         group_size = 128
+         n_calibration_examples = 10
+         sequence_length = 5
+
+         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+         dataset = m.example_inputs(
+             dataset_size,
+             sequence_length=sequence_length,
+             dtype=original_dtype,
+             device=device,
+         )
+         calibration_data = dataset[:n_calibration_examples]
+
+         # calibrate
+         base_config = FbgemmConfig(
+             input_dtype=torch.bfloat16,
+             weight_dtype=torch.int4,
+             output_dtype=torch.bfloat16,
+             block_size=[1, group_size],
+             preshuffle=False,
+         )
+         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+         quantize_(m, quant_config)
+
+         for example in calibration_data:
+             m(example)
+
+         # quantize
+         quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+         quantize_(m, quant_config)
+
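+         # round-trip the quantized state dict through torch.save/torch.load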
+         with tempfile.NamedTemporaryFile() as f:
+             torch.save(m.state_dict(), f)
+             f.seek(0)
+             state_dict = torch.load(f)
+
+         loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+         loaded_model.load_state_dict(state_dict, assign=True)
+
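+         # compile both; the reloaded model should match the original within quantization tolerance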
+         m = torch.compile(m, fullgraph=True)
+         loaded_model = torch.compile(loaded_model, fullgraph=True)
+
+         awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+         awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+         assert awq_out is not None
+         assert awq_save_load_out is not None
+         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+     def test_awq_loading_vllm(self):
+         """Simulate weight loading in vllm:
+         * prepare the model weights in the same format (AWQ weight)
+         * use weight.copy_(state_dict["weight"]) to copy the quantized weights over from the checkpoint
+
+         A slicing op is omitted here; the full end-to-end flow is covered by tests in the vllm repo.
+         """
+         device = "cuda"
+         dataset_size = 100
+         l1, l2, l3 = 512, 256, 128
+         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+         group_size = 128
+         n_calibration_examples = 10
+         sequence_length = 5
+
+         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+         dataset = m.example_inputs(
+             dataset_size,
+             sequence_length=sequence_length,
+             dtype=original_dtype,
+             device=device,
+         )
+         calibration_data = dataset[:n_calibration_examples]
+
+         # calibrate
+         base_config = FbgemmConfig(
+             input_dtype=torch.bfloat16,
+             weight_dtype=torch.int4,
+             output_dtype=torch.bfloat16,
+             block_size=[1, group_size],
+             preshuffle=False,
+         )
+         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+         quantize_(m, quant_config)
+
+         for example in calibration_data:
+             m(example)
+
+         # quantize
+         quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+         quantize_(m, quant_config)
+
+         with tempfile.NamedTemporaryFile() as f:
+             torch.save(m.state_dict(), f)
+             f.seek(0)
+             state_dict = torch.load(f)
+
+         loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
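+         # PREPARE_FOR_LOADING initializes the weights in the final quantized format,
+         # so the checkpoint's quantized weights can be copied in with copy_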
+         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+         quantize_(loaded_model, quant_config)
+
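+         # copy the quantized checkpoint weights into the prepared model, as vllm does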
+         loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
+         loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
+         loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])
+
        m = torch.compile(m, fullgraph=True)
        loaded_model = torch.compile(loaded_model, fullgraph=True)

-     awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-     awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])
-
-     assert awq_out is not None
-     assert awq_save_load_out is not None
-     assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
-
-
- @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
- @skip_if_rocm("ROCm enablement in progress")
- def test_save_weights_only():
-     dataset_size = 100
-     l1, l2, l3 = 512, 256, 128
-     original_dtype = torch.bfloat16
-     quant_dtype = torch.uint4
-     device = "cuda"
-     group_size = 128
-     n_calibration_examples = 10
-     n_validation_examples = 10
-     sequence_length = 5
-
-     m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-     m2 = deepcopy(m)
-     dataset = m.example_inputs(
-         dataset_size,
-         sequence_length=sequence_length,
-         dtype=original_dtype,
-         device=device,
-     )
-     calibration_data = dataset[:n_calibration_examples]
-
-     # calibrate
-     insert_awq_observer_(
-         m,
-         n_validation_examples,
-         sequence_length,
-         quant_dtype=quant_dtype,
-         group_size=group_size,
-     )
-
-     for example in calibration_data:
-         m(example.to(device))
-
-     # quantize
-     is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-     quantize_(
-         m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-     )
-
-     model_save_path = "awq_model.pth"
-     torch.save(m.state_dict(), model_save_path)
-     m2.load_state_dict(
-         torch.load(model_save_path), assign=True
-     )  # load weights only
-     os.remove(model_save_path)
-
-     m = torch.compile(m, fullgraph=True)
-     m2 = torch.compile(m2, fullgraph=True)
-
-     awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-     awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])
-
-     assert awq_out is not None
-     assert awq_save_load_out is not None
-     assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+         awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+         awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+         assert awq_out is not None
+         assert awq_save_load_out is not None
+         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+
+ if __name__ == "__main__":
+     run_tests()