19
19
import torch
20
20
from compressed_tensors .quantization .lifecycle .forward import (
21
21
_process_quantization ,
22
- dequantize ,
22
+ fake_quantize ,
23
23
forward_quantize ,
24
- quantize ,
25
24
wrap_module_forward_quantized ,
26
25
)
27
26
from compressed_tensors .quantization .lifecycle .initialize import (
@@ -96,7 +95,7 @@ def test_forward_quantize(
96
95
97
96
98
97
@pytest .mark .parametrize (
99
- "num_bits,type,strategy,group_size,scale,zero_point,g_idx" ,
98
+ "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale " ,
100
99
[
101
100
(
102
101
4 ,
@@ -106,6 +105,7 @@ def test_forward_quantize(
106
105
torch .rand ((1 ,)) * 0.01 ,
107
106
torch .zeros ((1 ,)),
108
107
None ,
108
+ None ,
109
109
),
110
110
(
111
111
4 ,
@@ -115,6 +115,7 @@ def test_forward_quantize(
115
115
torch .rand ((512 , 8 )) * 0.01 ,
116
116
torch .zeros ((512 , 8 )),
117
117
None ,
118
+ None ,
118
119
),
119
120
(
120
121
4 ,
@@ -124,6 +125,7 @@ def test_forward_quantize(
124
125
torch .rand ((512 , 8 )) * 0.01 ,
125
126
torch .zeros ((512 , 8 )),
126
127
make_dummy_g_idx (1024 , 128 ),
128
+ None ,
127
129
),
128
130
(
129
131
8 ,
@@ -133,6 +135,7 @@ def test_forward_quantize(
133
135
torch .rand ((1 ,)) * 0.01 ,
134
136
torch .zeros ((1 ,)),
135
137
None ,
138
+ None ,
136
139
),
137
140
(
138
141
8 ,
@@ -142,6 +145,7 @@ def test_forward_quantize(
142
145
torch .rand ((512 , 8 )) * 0.01 ,
143
146
torch .zeros ((512 , 8 )),
144
147
None ,
148
+ None ,
145
149
),
146
150
(
147
151
8 ,
@@ -151,28 +155,8 @@ def test_forward_quantize(
151
155
torch .rand ((512 , 8 )) * 0.01 ,
152
156
torch .zeros ((512 , 8 )),
153
157
make_dummy_g_idx (1024 , 128 ),
158
+ None ,
154
159
),
155
- ],
156
- )
157
- def test_quantize (num_bits , type , strategy , group_size , scale , zero_point , g_idx ):
158
- args = QuantizationArgs (
159
- num_bits = num_bits , type = type , strategy = strategy , group_size = group_size
160
- )
161
-
162
- x = torch .rand ((512 , 1024 ))
163
- quantize (
164
- x = x ,
165
- scale = scale ,
166
- zero_point = zero_point ,
167
- args = args ,
168
- dtype = args .pytorch_dtype (),
169
- g_idx = g_idx ,
170
- )
171
-
172
-
173
- @pytest .mark .parametrize (
174
- "num_bits,type,strategy,group_size,scale,zero_point,g_idx" ,
175
- [
176
160
(
177
161
8 ,
178
162
"int" ,
@@ -181,6 +165,7 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx
181
165
torch .rand ((512 , 8 )) * 0.01 ,
182
166
torch .zeros ((512 , 8 )),
183
167
None ,
168
+ None ,
184
169
),
185
170
(
186
171
8 ,
@@ -190,23 +175,26 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx
190
175
torch .rand ((512 , 8 )) * 0.01 ,
191
176
torch .zeros ((512 , 8 )),
192
177
make_dummy_g_idx (1024 , 128 ),
178
+ None ,
193
179
),
194
180
],
195
181
)
196
- def test_dequantize (num_bits , type , strategy , group_size , scale , zero_point , g_idx ):
182
+ def test_fake_quantize_2d (
183
+ num_bits , type , strategy , group_size , scale , zero_point , g_idx , global_scale
184
+ ):
197
185
args = QuantizationArgs (
198
186
num_bits = num_bits , type = type , strategy = strategy , group_size = group_size
199
187
)
200
188
201
- x_q = torch .rand ((512 , 1024 )). to ( dtype = args . pytorch_dtype ( ))
202
- dequantize (
203
- x_q = x_q ,
189
+ x = torch .rand ((512 , 1024 ))
190
+ fake_quantize (
191
+ x = x ,
204
192
scale = scale ,
205
193
zero_point = zero_point ,
206
194
args = args ,
207
- dtype = None ,
208
195
g_idx = g_idx ,
209
- )
196
+ global_scale = global_scale ,
197
+ ) # note that reconstruction loss is bad for uncalibrated scales
210
198
211
199
212
200
def test_process_quantization_block_static ():
0 commit comments