Commit 785f3dd
Convert SmoothQuant test to unittest (#2659)
* Convert SmoothQuant test to unittest
* Refactor using the `common_utils.parametrize` decorator
* Inline the quantization setup function
* Remove functions used only once
  - Incorrect `common_utils` API usage is fixed
* Replace unittest.SkipTest with unittest.skipIf
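
For readers unfamiliar with the pattern, the conversion follows PyTorch's parametrized-unittest idiom: test methods on a `unittest.TestCase` are decorated with `common_utils.parametrize`, and `common_utils.instantiate_parametrized_tests` expands them into one concrete test per parameter combination. A minimal sketch of that idiom (the class and test names below are illustrative, not part of this commit):

import unittest

import torch
from torch.testing._internal import common_utils


# Illustrative sketch of the structure used in this commit: a unittest.TestCase
# skipped as a whole on ROCm, with parametrized test methods.
@unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
class ExampleParametrizedTest(unittest.TestCase):
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize(
        "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    )
    def test_linear_forward(self, bias, device):
        # Trivial stand-in body; the real tests exercise SmoothQuant quantization.
        layer = torch.nn.Linear(4, 4, bias=bias).to(device)
        out = layer(torch.randn(2, 4, device=device))
        self.assertEqual(out.shape, (2, 4))


# Expands each parameter combination into its own named test method.
common_utils.instantiate_parametrized_tests(ExampleParametrizedTest)

if __name__ == "__main__":
    unittest.main()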
1 parent 77b2127 commit 785f3dd

File tree

1 file changed (+216 -137 lines)


test/prototype/test_smoothquant.py

Lines changed: 216 additions & 137 deletions
@@ -4,10 +4,11 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import tempfile
+import unittest
 from copy import deepcopy

-import pytest
 import torch
+from torch.testing._internal import common_utils

 from torchao.prototype.smoothquant import (
     SmoothQuantConfig,
@@ -25,9 +26,6 @@
     TORCH_VERSION_AT_LEAST_2_5,
 )

-if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
-

 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
@@ -53,143 +51,224 @@ def forward(self, x):
         return x


-bias_list = [True, False]
-alpha_list = [None, 0.5, 0.75]
-quant_mode_list = ["static", "dynamic"]
-devices = ["cpu"]
-if torch.cuda.is_available():
-    devices.append("cuda")
-idtypes = (torch.float, torch.bfloat16, torch.half)
-
-if TORCH_VERSION_AT_LEAST_2_5:
-    # This test case will trigger recompilation many times, so set a large cache_size_limit here
-    torch._dynamo.config.cache_size_limit = 128
-
-
-@pytest.mark.parametrize("bias", bias_list)
-@pytest.mark.parametrize("alpha", alpha_list)
-@pytest.mark.parametrize("quant_mode", quant_mode_list)
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("idtype", idtypes)
-@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it")
-def test_compute(bias, alpha, quant_mode, device, idtype):
-    class Linear(torch.nn.Module):
-        def __init__(self, bias: bool):
-            super().__init__()
-            self.fc = torch.nn.Linear(32, 32, bias)
-            self.fc.weight.data = torch.randn_like(self.fc.weight.data)
-
-        def forward(self, x):
-            return self.fc(x)
-
-    m = Linear(bias).eval().to(idtype).to(device)
-    m_ref = deepcopy(m)
-    data = torch.randn(2, 32, dtype=idtype, device=device)
-
-    # calibrate
-    insert_smooth_quant_observer_(m, alpha, quant_mode)
-    m(data)
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, SmoothQuantConfig(), is_observed_linear)
-    with torch.inference_mode():
+@unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
+class TestSmoothQuant(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """Set up class-level configuration for tests."""
+        if TORCH_VERSION_AT_LEAST_2_5:
+            # This test case will trigger recompilation many times, so set a large cache_size_limit here
+            torch._dynamo.config.cache_size_limit = 128
+
+    @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("alpha", [None, 0.5, 0.75])
+    @common_utils.parametrize("quant_mode", ["static", "dynamic"])
+    @common_utils.parametrize(
+        "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    )
+    @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
+    def test_smoothquant_accuracy(self, bias, alpha, quant_mode, device, input_dtype):
+        """Test the margin error of SmoothQuant across bias, alpha, dtype, etc."""
+
+        class SimpleLinear(torch.nn.Module):
+            def __init__(self, bias: bool):
+                super().__init__()
+                self.fc = torch.nn.Linear(32, 32, bias)
+                self.fc.weight.data = torch.randn_like(self.fc.weight.data)
+
+            def forward(self, x):
+                return self.fc(x)
+
+        # Create model, reference, and test data
+        m = SimpleLinear(bias).eval().to(input_dtype).to(device)
+        m_ref = deepcopy(m)
+        test_data = torch.randn(2, 32, dtype=input_dtype, device=device)
+
+        # Step 1: Setup quantized model with observer insertion and calibration
+        insert_smooth_quant_observer_(m, alpha, quant_mode)
+
+        # Perform calibration with test data
+        m(test_data)
+
+        # Apply quantization configuration
+        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+        quantize_(m, SmoothQuantConfig(), is_observed_linear)
+
+        # Apply compilation if supported
         if TORCH_VERSION_AT_LEAST_2_5:
             m = torch.compile(m, fullgraph=True)
-        out = m(data)
-
-    # reference
-    weight = m_ref.fc.weight.data.float()
-    b = m_ref.fc.bias if bias else None
-    x_abs_max_per_ic = torch.abs(data).max(dim=0).values
-    w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
-    smoothing_factor = (
-        1
-        if alpha is None
-        else (
-            torch.pow(x_abs_max_per_ic, alpha)
-            / torch.pow(w_abs_max_per_ic, 1 - alpha)
+
+        # Step 2: Inference quantized model
+        with torch.inference_mode():
+            q_out = m(test_data)
+
+        # Step 3: Compute reference
+        weight = m_ref.fc.weight.data.float()
+        b = m_ref.fc.bias if bias else None
+        x_abs_max_per_ic = torch.abs(test_data).max(dim=0).values
+        w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
+
+        if alpha is not None:
+            # Apply SmoothQuant
+            smoothing_factor = torch.pow(x_abs_max_per_ic, alpha) / torch.pow(
+                w_abs_max_per_ic, 1 - alpha
+            )
+        else:
+            smoothing_factor = torch.ones_like(x_abs_max_per_ic)
+
+        # Apply smoothing to activations and weights
+        smoothed_activation = test_data / smoothing_factor
+        smoothed_weight = weight * smoothing_factor
+
+        # Quantize weights using per-channel quantization
+        qw, w_scales, w_zps = dynamically_quantize_per_channel(
+            smoothed_weight, -127, 127, torch.int8
         )
-    )
-    act = data / smoothing_factor
-    wei = weight * smoothing_factor
-    qw, w_scales, w_zps = dynamically_quantize_per_channel(
-        wei, -127, 127, torch.int8
-    )
-    fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype)
-    if quant_mode == "static":
-        # activation is quantized per-tensor
-        act_min, act_max = torch.aminmax(act.float())
-        max_val_pos = torch.max(-act_min, act_max)
-        act_scale = max_val_pos / 127.0
-        fq_act = (
-            torch.quantize_per_tensor(
-                act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8
+        fq_wei = dequantize_per_channel(qw, w_scales, w_zps, input_dtype)
+
+        # Handle activation quantization based on mode
+        if quant_mode == "static":
+            # activation is quantized per-tensor
+            act_min, act_max = torch.aminmax(smoothed_activation.float())
+            max_val_pos = torch.max(-act_min, act_max)
+            activation_scale = max_val_pos / 127.0
+
+            fq_act = (
+                torch.quantize_per_tensor(
+                    smoothed_activation.float(),
+                    scale=activation_scale.item(),
+                    zero_point=0,
+                    dtype=torch.qint8,
+                )
+                .dequantize()
+                .to(input_dtype)
+            )
+        else:
+            # activation is quantized per-row (batch * sequence_length)
+            qx, x_scales, x_zps = dynamically_quantize_per_channel(
+                smoothed_activation.float(), -127, 127, torch.int8
+            )
+            fq_act = dequantize_per_channel(
+                qx,
+                x_scales,
+                x_zps,
+                input_dtype,
             )
-            .dequantize()
-            .to(idtype)
+
+        # Compute final linear operation
+        reference_out = torch.nn.functional.linear(fq_act, fq_wei, b)
+
+        # Step 4: Validate numerical accuracy
+        tolerance = (
+            0.1
+            if input_dtype == torch.float
+            else (0.2 if input_dtype == torch.half else 0.3)
         )
-        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
-    else:
-        # activation is quantized per-row (batch * sequence_length)
-        qx, x_scales, x_zps = dynamically_quantize_per_channel(
-            act.float(), -127, 127, torch.int8
+        torch.testing.assert_close(
+            q_out,
+            reference_out.to(input_dtype),
+            atol=tolerance,
+            msg=f"Quantized output differs from reference for "
+            f"bias={bias}, alpha={alpha}, quant_mode={quant_mode}, "
+            f"device={device}, dtype={input_dtype}",
        )
-        fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype)
-        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
-
-    # BFloat16 and Float16 have larger errors
-    atol = 0.1 if idtype == torch.float else (0.2 if idtype == torch.half else 0.3)
-    assert torch.allclose(out, out_ref.to(idtype), atol=atol)
-
-
-@pytest.mark.parametrize("alpha", alpha_list)
-@pytest.mark.parametrize("quant_mode", quant_mode_list)
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("idtype", idtypes)
-@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it")
-def test_save_load_recipe(alpha, quant_mode, device, idtype):
-    dataset_size = 20
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = idtype
-    n_calib_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    m_save_load = deepcopy(m)
-
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
+
+    @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
+    @common_utils.parametrize("alpha", [None, 0.5, 0.75])
+    @common_utils.parametrize("quant_mode", ["static", "dynamic"])
+    @common_utils.parametrize(
+        "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     )
-    calibration_data = dataset[:n_calib_examples]
-
-    # calibrate
-    insert_smooth_quant_observer_(m, alpha, quant_mode)
-    insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
-
-    for example in calibration_data:
-        m(example.to(device))
-        m_save_load(example.to(device))
-
-    with tempfile.NamedTemporaryFile() as fp:
-        save_path = fp.name
-        save_smooth_quant_recipe(m_save_load, save_path)
-        load_smooth_quant_recipe(m_save_load, save_path)
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, SmoothQuantConfig(), is_observed_linear)
-    if TORCH_VERSION_AT_LEAST_2_5:
-        # earlier versions are not compatible
-        m = torch.compile(m, fullgraph=True)
-        m_save_load = torch.compile(m_save_load, fullgraph=True)
-    out_list = [m(data.squeeze(0)) for data in dataset]
-    out = torch.cat(out_list)
-    save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset]
-    save_load_out = torch.cat(save_load_out_list)
-
-    assert out is not None
-    assert save_load_out is not None
-    assert torch.allclose(out, save_load_out)
+    @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
+    def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype):
+        """Test save/load recipe functionality."""
+        dataset_size = 20
+        layer_dims = (512, 256, 128)  # Input, hidden, output dimensions
+        n_calib_examples = 10
+        sequence_length = 5
+
+        # Create two identical models for comparison
+        m = ToyLinearModel(*layer_dims).eval().to(input_dtype).to(device)
+        m_save_load = deepcopy(m)
+
+        # Generate calibration dataset
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=input_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calib_examples]
+
+        # Step 1: Setup first quantized model with observer insertion and calibration
+        insert_smooth_quant_observer_(m, alpha, quant_mode)
+
+        # Perform calibration with calibration data
+        for data in calibration_data:
+            m(data)
+
+        # Apply quantization configuration
+        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+        quantize_(m, SmoothQuantConfig(), is_observed_linear)
+
+        # Apply compilation if supported
+        if TORCH_VERSION_AT_LEAST_2_5:
+            m = torch.compile(m, fullgraph=True)
+
+        # Step 2: Setup save/load model with recipe functionality
+        insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
+        for example in calibration_data:
+            m_save_load(example.to(device))
+
+        # Step 3: Test save/load recipe functionality
+        with tempfile.NamedTemporaryFile() as temp_file:
+            save_path = temp_file.name
+            save_smooth_quant_recipe(m_save_load, save_path)
+            load_smooth_quant_recipe(m_save_load, save_path)
+
+        # Step 4: Complete quantization for save/load model
+        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+        quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear)
+
+        if TORCH_VERSION_AT_LEAST_2_5:
+            m_save_load = torch.compile(m_save_load, fullgraph=True)
+
+        # Step 5: Validate outputs on full dataset
+        with torch.inference_mode():
+            original_outputs = []
+            save_load_outputs = []
+
+            for data in dataset:
+                # Remove batch dimension for model input
+                input_tensor = data.squeeze(0)
+
+                original_output = m(input_tensor)
+                save_load_output = m_save_load(input_tensor)
+
+                original_outputs.append(original_output)
+                save_load_outputs.append(save_load_output)
+
+            # Concatenate all outputs for comparison
+            original_result = torch.cat(original_outputs)
+            save_load_out = torch.cat(save_load_outputs)
+
+            self.assertIsNotNone(
+                original_result, "Original model output should not be None"
+            )
+            self.assertIsNotNone(
+                save_load_out, "Save/load model output should not be None"
+            )
+
+            torch.testing.assert_close(
+                original_result,
+                save_load_out,
+                msg=f"Save/load recipe should produce identical results for "
+                f"alpha={alpha}, quant_mode={quant_mode}, device={device}, dtype={input_dtype}",
+            )
+
+
+common_utils.instantiate_parametrized_tests(TestSmoothQuant)
+
+if __name__ == "__main__":
+    unittest.main()

0 commit comments
