 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import os
-from copy import deepcopy
+import copy
+import tempfile
+import unittest
 
-import pytest
 import torch
-
-from torchao.quantization import quantize_
-from torchao.testing.utils import skip_if_rocm
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_3,
-    TORCH_VERSION_AT_LEAST_2_5,
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
 )
 
-if TORCH_VERSION_AT_LEAST_2_3:
-    from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
+from torchao.prototype.awq import AWQConfig, AWQStep
+from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
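+
+# AWQ flow exercised by the tests below (a sketch of the three steps, drawn
+# from the test bodies themselves, not an authoritative API reference):
+#
+#   quantize_(m, AWQConfig(base_config, step=AWQStep.PREPARE))   # insert activation observers
+#   for x in calibration_data:
+#       m(x)                                                     # record activation statistics
+#   quantize_(m, AWQConfig(base_config, step=AWQStep.CONVERT))   # compute scales, quantize weights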
 
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
         super().__init__()
         self.linear1 = torch.nn.Linear(m, n, bias=False)
         self.linear2 = torch.nn.Linear(n, k, bias=False)
-        self.linear3 = torch.nn.Linear(k, 1, bias=False)
+        self.linear3 = torch.nn.Linear(k, 64, bias=False)
 
     def example_inputs(
         self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
@@ -44,137 +41,189 @@ def forward(self, x):
         return x
 
 
-devices = ["cpu", "cuda"]
-# torch.uintx dtypes are introduced in 2.3
-if TORCH_VERSION_AT_LEAST_2_3:
-    qdtypes = (torch.uint4, torch.uint7)
-else:
-    qdtypes = ()
-
-
-@pytest.fixture(autouse=True)
-def run_before_and_after_tests():
-    yield
-    torch._dynamo.reset()  # reset cache between tests
-
-
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("qdtype", qdtypes)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skip("Temporarily skipping to unpin nightlies")
-def test_awq_loading(device, qdtype):
-    if qdtype == torch.uint4 and device == "cpu":
-        pytest.skip("uint4 not supported on cpu")
-
-    dataset_size = 100
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-    quant_dtype = qdtype
-    group_size = 128
-    n_calibration_examples = 10
-    n_validation_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
-    )
-    calibration_data = dataset[:n_calibration_examples]
-
-    # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
-    )
-
-    for example in calibration_data:
-        m(example.to(device))
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
-
-    model_save_path = "awq_model.pth"
-    torch.save(m, model_save_path)
-    loaded_model = torch.load(model_save_path)
-    os.remove(model_save_path)
-
-    if torch.cuda.is_available():
+@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+class TestAWQ(TestCase):
+    def test_awq_config(self):
+        base_config = Int4WeightOnlyConfig()
+        AWQConfig(base_config, step=AWQStep.PREPARE)
+        AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+        AWQConfig(base_config, step=AWQStep.CONVERT)
+
+        AWQConfig(base_config, step="prepare")
+        AWQConfig(base_config, step="prepare_for_loading")
+        AWQConfig(base_config, step="convert")
+
+        with self.assertRaisesRegex(ValueError, "is not one of"):
+            AWQConfig(base_config, step="not_supported")
+
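+    # end-to-end check: AWQ-scaled int4 quantization should track the bf16
+    # model at least as closely as plain int4 weight-only quantization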
+    def test_awq_functionality(self):
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+
+        # baseline quantization
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        m_baseline = copy.deepcopy(m)
+        quantize_(m_baseline, base_config)
+
+        # awq quantization
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        ref_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+
+        calibration_data = dataset[:n_calibration_examples]
+
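+        # PREPARE: wrap the linears with observers that record activation
+        # statistics on the calibration batches below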
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
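+        # CONVERT: use the recorded statistics to pick per-channel equalization
+        # scales, then quantize the weights with base_config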
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        baseline_out = torch.cat([m_baseline(d.squeeze(0)) for d in dataset])
+
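+        # a lower MSE against the bf16 reference than the non-AWQ baseline
+        # indicates the activation-aware scaling helped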
+        loss_awq = (ref_out - awq_out).pow(2).mean().item()
+        loss_base = (ref_out - baseline_out).pow(2).mean().item()
+        assert loss_awq < loss_base
+
+    def test_awq_loading(self):
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calibration_examples]
+
+        # calibrate
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
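+        # round-trip the quantized state_dict through torch.save / torch.load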
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
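+        # assign=True makes load_state_dict adopt the loaded (quantized) tensors
+        # directly instead of copying values into the existing parameters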
+        loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        loaded_model.load_state_dict(state_dict, assign=True)
+
         m = torch.compile(m, fullgraph=True)
         loaded_model = torch.compile(loaded_model, fullgraph=True)
 
-        awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-        awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])
-
-        assert awq_out is not None
-        assert awq_save_load_out is not None
-        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
-
-
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@skip_if_rocm("ROCm enablement in progress")
-def test_save_weights_only():
-    dataset_size = 100
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16
-    quant_dtype = torch.uint4
-    device = "cuda"
-    group_size = 128
-    n_calibration_examples = 10
-    n_validation_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    m2 = deepcopy(m)
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
-    )
-    calibration_data = dataset[:n_calibration_examples]
-
-    # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
-    )
-
-    for example in calibration_data:
-        m(example.to(device))
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
-
-    model_save_path = "awq_model.pth"
-    torch.save(m.state_dict(), model_save_path)
-    m2.load_state_dict(
-        torch.load(model_save_path), assign=True
-    )  # load weights only
-    os.remove(model_save_path)
-
-    m = torch.compile(m, fullgraph=True)
-    m2 = torch.compile(m2, fullgraph=True)
-
-    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-    awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])
-
-    assert awq_out is not None
-    assert awq_save_load_out is not None
-    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+    def test_awq_loading_vllm(self):
+        """Simulate weight loading in vllm:
+        * prepare the model weights in the same format (awq weights)
+        * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from the checkpoint
+
+        A slicing op is omitted here; the overall end-to-end flow is covered by tests in the vllm repo.
+        """
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calibration_examples]
+
+        # calibrate
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
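+        # PREPARE_FOR_LOADING builds a quantized model in the same weight format
+        # as the checkpoint, so the quantized tensors can be copied in directly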
+        loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+        quantize_(loaded_model, quant_config)
+
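+        # vllm-style loading: both sides already share the quantized format, so
+        # an in-place copy_ per linear weight is sufficient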
+        loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
+        loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
+        loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])
+
+        m = torch.compile(m, fullgraph=True)
+        loaded_model = torch.compile(loaded_model, fullgraph=True)
+
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+
+if __name__ == "__main__":
+    run_tests()