 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import os
-from copy import deepcopy
+import copy
+import tempfile
 
 import pytest
 import torch
 
-from torchao.quantization import quantize_
-from torchao.testing.utils import skip_if_rocm
+from torchao.quantization import FbgemmConfig, quantize_
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if TORCH_VERSION_AT_LEAST_2_3:
-    from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
+    from torchao.prototype.awq import AWQConfig, AWQStep
 
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
         super().__init__()
         self.linear1 = torch.nn.Linear(m, n, bias=False)
         self.linear2 = torch.nn.Linear(n, k, bias=False)
-        self.linear3 = torch.nn.Linear(k, 1, bias=False)
+        self.linear3 = torch.nn.Linear(k, 64, bias=False)
 
     def example_inputs(
         self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
@@ -44,36 +43,74 @@ def forward(self, x):
         return x
 
 
-devices = ["cpu", "cuda"]
-# torch.uintx dtypes are introduced in 2.3
-if TORCH_VERSION_AT_LEAST_2_3:
-    qdtypes = (torch.uint4, torch.uint7)
-else:
-    qdtypes = ()
-
-
 @pytest.fixture(autouse=True)
 def run_before_and_after_tests():
     yield
     torch._dynamo.reset()  # reset cache between tests
 
 
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("qdtype", qdtypes)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skip("Temporarily skipping to unpin nightiles")
-def test_awq_loading(device, qdtype):
-    if qdtype == torch.uint4 and device == "cpu":
-        pytest.skip("uint4 not supported on cpu")
+def test_awq_functionality():
+    device = "cuda"
+    dataset_size = 100
+    l1, l2, l3 = 512, 256, 128
+    original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+    group_size = 128
+    n_calibration_examples = 10
+    sequence_length = 5
+
+    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+
+    # baseline quantization
+    base_config = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, group_size],
+        preshuffle=False,
+    )
+    m_baseline = copy.deepcopy(m)
+    quantize_(m_baseline, base_config)
+
+    # awq quantization
+    dataset = m.example_inputs(
+        dataset_size,
+        sequence_length=sequence_length,
+        dtype=original_dtype,
+        device=device,
+    )
+    ref_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+
+    calibration_data = dataset[:n_calibration_examples]
 
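+    # AWQ step 1 (PREPARE): wrap the linears with observers so the calibration
+    # forwards below can record activation statistics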
+    quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+    quantize_(m, quant_config)
+
+    for example in calibration_data:
+        print("device:", example.device)
+        m(example)
+
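+    # AWQ step 2 (CONVERT): use the collected statistics to compute the
+    # equalization scales and quantize the weights with the base config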
+    quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+    quantize_(m, quant_config)
+
+    awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+    baseline_out = torch.cat([m_baseline(d.squeeze(0)) for d in dataset])
+
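+    # AWQ should track the bfloat16 reference more closely (lower MSE) than
+    # the plain int4 baseline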
+    loss_awq = (ref_out - awq_out).pow(2).mean().item()
+    loss_base = (ref_out - baseline_out).pow(2).mean().item()
+    assert loss_awq < loss_base
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
+def test_awq_loading():
+    device = "cuda"
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128
     original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-    quant_dtype = qdtype
     group_size = 128
     n_calibration_examples = 10
-    n_validation_examples = 10
     sequence_length = 5
 
     m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
@@ -86,56 +123,60 @@ def test_awq_loading(device, qdtype):
     calibration_data = dataset[:n_calibration_examples]
 
     # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
+    base_config = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, group_size],
+        preshuffle=False,
     )
+    quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+    quantize_(m, quant_config)
 
     for example in calibration_data:
-        m(example.to(device))
+        m(example)
 
     # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
+    quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+    quantize_(m, quant_config)
 
-    model_save_path = "awq_model.pth"
-    torch.save(m, model_save_path)
-    loaded_model = torch.load(model_save_path)
-    os.remove(model_save_path)
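+    # round-trip the quantized state_dict through a temporary file and rebuild
+    # the model from it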
+    with tempfile.NamedTemporaryFile() as f:
+        torch.save(m.state_dict(), f)
+        f.seek(0)
+        state_dict = torch.load(f)
 
-    if torch.cuda.is_available():
-        m = torch.compile(m, fullgraph=True)
-        loaded_model = torch.compile(loaded_model, fullgraph=True)
+    loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+    loaded_model.load_state_dict(state_dict, assign=True)
 
-    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-    awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])
+    m = torch.compile(m, fullgraph=True)
+    loaded_model = torch.compile(loaded_model, fullgraph=True)
+
+    awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+    awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
 
     assert awq_out is not None
     assert awq_save_load_out is not None
     assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
 
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@skip_if_rocm("ROCm enablement in progress")
-def test_save_weights_only():
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
+def test_awq_loading_vllm():
+    """Simulate weight loading in vllm:
+    * prepare the model weights in the same format (AWQ weight)
+    * use weight.copy_(state_dict["weight"]) to copy the quantized weights over from the checkpoint
+
+    A slicing op is omitted here; the full end-to-end flow is covered by tests in the vllm repo.
+    """
+    device = "cuda"
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16
-    quant_dtype = torch.uint4
-    device = "cuda"
+    original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
     group_size = 128
     n_calibration_examples = 10
-    n_validation_examples = 10
     sequence_length = 5
 
     m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    m2 = deepcopy(m)
     dataset = m.example_inputs(
         dataset_size,
         sequence_length=sequence_length,
@@ -145,35 +186,41 @@ def test_save_weights_only():
     calibration_data = dataset[:n_calibration_examples]
 
     # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
+    base_config = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, group_size],
+        preshuffle=False,
     )
+    quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+    quantize_(m, quant_config)
 
     for example in calibration_data:
-        m(example.to(device))
+        m(example)
 
     # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
+    quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+    quantize_(m, quant_config)
+
+    with tempfile.NamedTemporaryFile() as f:
+        torch.save(m.state_dict(), f)
+        f.seek(0)
+        state_dict = torch.load(f)
+
+    loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
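+    # PREPARE_FOR_LOADING: materialize fresh weights already in the AWQ-quantized
+    # format so the copy_ calls below can accept the checkpoint tensors in place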
+    quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+    quantize_(loaded_model, quant_config)
 
-    model_save_path = "awq_model.pth"
-    torch.save(m.state_dict(), model_save_path)
-    m2.load_state_dict(
-        torch.load(model_save_path), assign=True
-    )  # load weights only.torch.load(model_save_path)
-    os.remove(model_save_path)
+    loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
+    loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
+    loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])
 
     m = torch.compile(m, fullgraph=True)
-    m2 = torch.compile(m2, fullgraph=True)
+    loaded_model = torch.compile(loaded_model, fullgraph=True)
 
-    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-    awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])
+    awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+    awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
 
     assert awq_out is not None
     assert awq_save_load_out is not None