Commit ca1c7f7 (parent: 4efaf26)

[Tests] Combine quantization and dequantization tests (#443)

Signed-off-by: Kyle Sayers <[email protected]>


tests/test_quantization/lifecycle/test_forward.py

Lines changed: 18 additions & 30 deletions
@@ -19,9 +19,8 @@
 import torch
 from compressed_tensors.quantization.lifecycle.forward import (
     _process_quantization,
-    dequantize,
+    fake_quantize,
     forward_quantize,
-    quantize,
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.lifecycle.initialize import (
@@ -96,7 +95,7 @@ def test_forward_quantize(


 @pytest.mark.parametrize(
-    "num_bits,type,strategy,group_size,scale,zero_point,g_idx",
+    "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale",
     [
         (
             4,
@@ -106,6 +105,7 @@ def test_forward_quantize(
             torch.rand((1,)) * 0.01,
             torch.zeros((1,)),
             None,
+            None,
         ),
         (
             4,
@@ -115,6 +115,7 @@ def test_forward_quantize(
             torch.rand((512, 8)) * 0.01,
             torch.zeros((512, 8)),
             None,
+            None,
         ),
         (
             4,
@@ -124,6 +125,7 @@ def test_forward_quantize(
             torch.rand((512, 8)) * 0.01,
             torch.zeros((512, 8)),
             make_dummy_g_idx(1024, 128),
+            None,
         ),
         (
             8,
@@ -133,6 +135,7 @@ def test_forward_quantize(
             torch.rand((1,)) * 0.01,
             torch.zeros((1,)),
             None,
+            None,
         ),
         (
             8,
@@ -142,6 +145,7 @@ def test_forward_quantize(
             torch.rand((512, 8)) * 0.01,
             torch.zeros((512, 8)),
             None,
+            None,
         ),
         (
             8,
@@ -151,28 +155,8 @@ def test_forward_quantize(
             torch.rand((512, 8)) * 0.01,
             torch.zeros((512, 8)),
             make_dummy_g_idx(1024, 128),
+            None,
         ),
-    ],
-)
-def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx):
-    args = QuantizationArgs(
-        num_bits=num_bits, type=type, strategy=strategy, group_size=group_size
-    )
-
-    x = torch.rand((512, 1024))
-    quantize(
-        x=x,
-        scale=scale,
-        zero_point=zero_point,
-        args=args,
-        dtype=args.pytorch_dtype(),
-        g_idx=g_idx,
-    )
-
-
-@pytest.mark.parametrize(
-    "num_bits,type,strategy,group_size,scale,zero_point,g_idx",
-    [
         (
             8,
             "int",
@@ -181,6 +165,7 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx
             torch.rand((512, 8)) * 0.01,
             torch.zeros((512, 8)),
             None,
+            None,
         ),
         (
             8,
@@ -190,23 +175,26 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx
             torch.rand((512, 8)) * 0.01,
             torch.zeros((512, 8)),
             make_dummy_g_idx(1024, 128),
+            None,
         ),
     ],
 )
-def test_dequantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx):
+def test_fake_quantize_2d(
+    num_bits, type, strategy, group_size, scale, zero_point, g_idx, global_scale
+):
     args = QuantizationArgs(
         num_bits=num_bits, type=type, strategy=strategy, group_size=group_size
     )

-    x_q = torch.rand((512, 1024)).to(dtype=args.pytorch_dtype())
-    dequantize(
-        x_q=x_q,
+    x = torch.rand((512, 1024))
+    fake_quantize(
+        x=x,
         scale=scale,
         zero_point=zero_point,
         args=args,
-        dtype=None,
         g_idx=g_idx,
-    )
+        global_scale=global_scale,
+    )  # note that reconstruction loss is bad for uncalibrated scales


 def test_process_quantization_block_static():
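Why the two tests could be merged: fake quantization performs the quantize and dequantize steps back to back, so a single fake_quantize call exercises both code paths that the deleted test_quantize and test_dequantize covered separately. Below is a minimal, self-contained sketch of that round trip; the helper quant_dequant_roundtrip and its affine integer math are illustrative stand-ins, not the compressed-tensors implementation.

import torch


def quant_dequant_roundtrip(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    num_bits: int = 8,
) -> torch.Tensor:
    # Quantize: project onto the signed integer grid and clamp to range.
    q_min = -(2 ** (num_bits - 1))
    q_max = 2 ** (num_bits - 1) - 1
    x_q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
    # Dequantize: map the integer grid back to (lossy) float values.
    return (x_q - zero_point) * scale


x = torch.rand((512, 1024))
scale = torch.rand((1,)) * 0.01  # uncalibrated per-tensor scale, as in the test params
zero_point = torch.zeros((1,))
x_hat = quant_dequant_roundtrip(x, scale, zero_point, num_bits=8)
print(torch.mean((x - x_hat) ** 2).item())  # large error: the scale was never calibrated

With a random, uncalibrated scale like the one above, many values saturate at the integer bounds, which is what the in-diff comment about poor reconstruction loss refers to; accordingly, the merged test only checks that the call completes, not that the reconstruction is accurate.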
