Convert SmoothQuant test to unittest #2659
Changes from 2 commits
@@ -4,9 +4,9 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import tempfile
+import unittest
 from copy import deepcopy

-import pytest
 import torch

 from torchao.prototype.smoothquant import (
@@ -21,13 +21,11 @@
     dequantize_per_channel,
     dynamically_quantize_per_channel,
 )
+from torchao.testing import common_utils
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
 )

-if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
-

 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
@@ -53,143 +51,232 @@ def forward(self, x):
         return x


-bias_list = [True, False]
-alpha_list = [None, 0.5, 0.75]
-quant_mode_list = ["static", "dynamic"]
-devices = ["cpu"]
-if torch.cuda.is_available():
-    devices.append("cuda")
-idtypes = (torch.float, torch.bfloat16, torch.half)
-
-if TORCH_VERSION_AT_LEAST_2_5:
-    # This test case will trigger recompilation many times, so set a large cache_size_limit here
-    torch._dynamo.config.cache_size_limit = 128
-
-
-@pytest.mark.parametrize("bias", bias_list)
-@pytest.mark.parametrize("alpha", alpha_list)
-@pytest.mark.parametrize("quant_mode", quant_mode_list)
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("idtype", idtypes)
-@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it")
-def test_compute(bias, alpha, quant_mode, device, idtype):
-    class Linear(torch.nn.Module):
-        def __init__(self, bias: bool):
-            super().__init__()
-            self.fc = torch.nn.Linear(32, 32, bias)
-            self.fc.weight.data = torch.randn_like(self.fc.weight.data)
-
-        def forward(self, x):
-            return self.fc(x)
-
-    m = Linear(bias).eval().to(idtype).to(device)
-    m_ref = deepcopy(m)
-    data = torch.randn(2, 32, dtype=idtype, device=device)
-
-    # calibrate
-    insert_smooth_quant_observer_(m, alpha, quant_mode)
-    m(data)
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, SmoothQuantConfig(), is_observed_linear)
-    with torch.inference_mode():
+class TestSmoothQuant(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """Set up class-level configuration for tests."""
+        # Skip tests on ROCm (AMD GPU) due to compatibility issues
+        if torch.version.hip is not None:
+            raise unittest.SkipTest("Skipping the tests in ROCm")
+
+        if TORCH_VERSION_AT_LEAST_2_5:
+            # This test case will trigger recompilation many times, so set a large cache_size_limit here
+            torch._dynamo.config.cache_size_limit = 128
+
+    def _setup_quantized_model(self, model, alpha, quant_mode, calibration_data):
[Review comment] I feel it's probably fine to inline these as well; there are not many lines of code here.
+        """Setup for quantized models with observer insertion and calibration."""
+        insert_smooth_quant_observer_(model, alpha, quant_mode)
+
+        for data in calibration_data:
+            model(data)
+
+        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+        quantize_(model, SmoothQuantConfig(), is_observed_linear)
+
         if TORCH_VERSION_AT_LEAST_2_5:
-            m = torch.compile(m, fullgraph=True)
-        out = m(data)
+            model = torch.compile(model, fullgraph=True)
+
+        return model
+
+    @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("alpha", [None, 0.5, 0.75])
+    @common_utils.parametrize("quant_mode", ["static", "dynamic"])
+    @common_utils.parametrize(
+        "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    )
+    @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
+    def test_smoothquant_accuracy(self, bias, alpha, quant_mode, device, input_dtype):
+        """Test the margin error of SmoothQuant across bias, alpha, dtype, etc."""
+        self._run_compute_accuracy_test(bias, alpha, quant_mode, device, input_dtype)
+
+    def _run_compute_accuracy_test(self, bias, alpha, quant_mode, device, input_dtype):
[Review comment] nit: can also remove this function; there seems to be no need to define a function here.

[Author reply] Thanks, those functions were removed for brevity in 3ffdd10.
+        """Single compute accuracy test."""
+
+        class SimpleLinear(torch.nn.Module):
+            def __init__(self, bias: bool):
+                super().__init__()
+                self.fc = torch.nn.Linear(32, 32, bias)
+                self.fc.weight.data = torch.randn_like(self.fc.weight.data)
+
+            def forward(self, x):
+                return self.fc(x)
+
-        # reference
+        # Create model, reference, and test data
+        m = SimpleLinear(bias).eval().to(input_dtype).to(device)
+        m_ref = deepcopy(m)
+        test_data = torch.randn(2, 32, dtype=input_dtype, device=device)
+
+        # Step 1: Get calibration from observed SmoothQuant
+        m = self._setup_quantized_model(m, alpha, quant_mode, [test_data])
+
+        # Step 2: Inference quantized model
+        with torch.inference_mode():
+            q_out = m(test_data)
+
+        # Step 3: Compute reference
+        reference_out = self._compute_reference_out(
+            m_ref, test_data, alpha, quant_mode, bias, input_dtype
+        )
+
+        # Step 4: Validate numerical accuracy
+        tolerance = (
+            0.1
+            if input_dtype == torch.float
+            else (0.2 if input_dtype == torch.half else 0.3)
+        )
+        torch.testing.assert_close(
+            q_out,
+            reference_out.to(input_dtype),
+            atol=tolerance,
+            msg=f"Quantized output differs from reference for "
+            f"bias={bias}, alpha={alpha}, quant_mode={quant_mode}, "
+            f"device={device}, dtype={input_dtype}",
+        )
+
+    def _compute_reference_out(self, m_ref, data, alpha, quant_mode, bias, input_dtype):
[Review comment] If it's used only once, I'd suggest not defining a new function.
+        """Compute the expected SmoothQuant output."""
         weight = m_ref.fc.weight.data.float()
         b = m_ref.fc.bias if bias else None
         x_abs_max_per_ic = torch.abs(data).max(dim=0).values
         w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
-        smoothing_factor = (
-            1
-            if alpha is None
-            else (
-                torch.pow(x_abs_max_per_ic, alpha)
-                / torch.pow(w_abs_max_per_ic, 1 - alpha)
-            )
-        )
-        act = data / smoothing_factor
-        wei = weight * smoothing_factor
+        if alpha is not None:
+            # Apply SmoothQuant
+            smoothing_factor = torch.pow(x_abs_max_per_ic, alpha) / torch.pow(
+                w_abs_max_per_ic, 1 - alpha
+            )
+        else:
+            smoothing_factor = torch.ones_like(x_abs_max_per_ic)
+
+        # Apply smoothing to activations and weights
+        smoothed_activation = data / smoothing_factor
+        smoothed_weight = weight * smoothing_factor
+
+        # Quantize weights using per-channel quantization
         qw, w_scales, w_zps = dynamically_quantize_per_channel(
-            wei, -127, 127, torch.int8
+            smoothed_weight, -127, 127, torch.int8
         )
-        fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype)
+        fq_wei = dequantize_per_channel(qw, w_scales, w_zps, input_dtype)
+
+        # Handle activation quantization based on mode
         if quant_mode == "static":
             # activation is quantized per-tensor
-            act_min, act_max = torch.aminmax(act.float())
+            act_min, act_max = torch.aminmax(smoothed_activation.float())
             max_val_pos = torch.max(-act_min, act_max)
-            act_scale = max_val_pos / 127.0
+            activation_scale = max_val_pos / 127.0
+
             fq_act = (
                 torch.quantize_per_tensor(
-                    act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8
+                    smoothed_activation.float(),
+                    scale=activation_scale.item(),
+                    zero_point=0,
+                    dtype=torch.qint8,
                 )
                 .dequantize()
-                .to(idtype)
+                .to(input_dtype)
             )
-            out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
         else:
             # activation is quantized per-row (batch * sequence_length)
             qx, x_scales, x_zps = dynamically_quantize_per_channel(
-                act.float(), -127, 127, torch.int8
+                smoothed_activation.float(), -127, 127, torch.int8
             )
-            fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype)
-            out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
-
-        # BFloat16 and Float16 have larger errors
-        atol = 0.1 if idtype == torch.float else (0.2 if idtype == torch.half else 0.3)
-        assert torch.allclose(out, out_ref.to(idtype), atol=atol)
-
-
-@pytest.mark.parametrize("alpha", alpha_list)
-@pytest.mark.parametrize("quant_mode", quant_mode_list)
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("idtype", idtypes)
-@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it")
-def test_save_load_recipe(alpha, quant_mode, device, idtype):
-    dataset_size = 20
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = idtype
-    n_calib_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    m_save_load = deepcopy(m)
-
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
-    )
-    calibration_data = dataset[:n_calib_examples]
-
-    # calibrate
-    insert_smooth_quant_observer_(m, alpha, quant_mode)
-    insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
-    for example in calibration_data:
-        m(example.to(device))
-        m_save_load(example.to(device))
-
-    with tempfile.NamedTemporaryFile() as fp:
-        save_path = fp.name
-        save_smooth_quant_recipe(m_save_load, save_path)
-        load_smooth_quant_recipe(m_save_load, save_path)
-
-        # quantize
-        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-        quantize_(m, SmoothQuantConfig(), is_observed_linear)
-        if TORCH_VERSION_AT_LEAST_2_5:
-            # earlier versions are not compatible
-            m = torch.compile(m, fullgraph=True)
-            m_save_load = torch.compile(m_save_load, fullgraph=True)
-        out_list = [m(data.squeeze(0)) for data in dataset]
-        out = torch.cat(out_list)
-        save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset]
-        save_load_out = torch.cat(save_load_out_list)
-
-        assert out is not None
-        assert save_load_out is not None
-        assert torch.allclose(out, save_load_out)
+            fq_act = dequantize_per_channel(
+                qx,
+                x_scales,
+                x_zps,
+                input_dtype,
+            )
+
+        # Compute final linear operation
+        return torch.nn.functional.linear(fq_act, fq_wei, b)
+
+    @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
+    @common_utils.parametrize("alpha", [None, 0.5, 0.75])
+    @common_utils.parametrize("quant_mode", ["static", "dynamic"])
+    @common_utils.parametrize(
+        "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    )
+    @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
+    def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype):
+        """Test save/load recipe functionality."""
+        self._run_save_load_recipe_test(alpha, quant_mode, device, input_dtype)
+
+    def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
[Review comment] Also this one.
+        """Single save/load recipe test."""
+        dataset_size = 20
+        layer_dims = (512, 256, 128)  # Input, hidden, output dimensions
+        n_calib_examples = 10
+        sequence_length = 5
+
+        # Create two identical models for comparison
+        m = ToyLinearModel(*layer_dims).eval().to(input_dtype).to(device)
+        m_save_load = deepcopy(m)
+
+        # Generate calibration dataset
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=input_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calib_examples]
+
+        # Step 1: Setup quantized models
+        m = self._setup_quantized_model(m, alpha, quant_mode, calibration_data)
+
+        # Step 2: Setup save/load model with recipe functionality
+        insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
+        for example in calibration_data:
+            m_save_load(example.to(device))
+
+        # Step 3: Test save/load recipe functionality
+        with tempfile.NamedTemporaryFile() as temp_file:
+            save_path = temp_file.name
+            save_smooth_quant_recipe(m_save_load, save_path)
+            load_smooth_quant_recipe(m_save_load, save_path)
+
+        # Step 4: Complete quantization for save/load model
+        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+        quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear)
+
+        if TORCH_VERSION_AT_LEAST_2_5:
+            m_save_load = torch.compile(m_save_load, fullgraph=True)
+
+        # Step 5: Validate outputs on full dataset
+        with torch.inference_mode():
+            original_outputs = []
+            save_load_outputs = []
+
+            for data in dataset:
+                # Remove batch dimension for model input
+                input_tensor = data.squeeze(0)
+
+                original_output = m(input_tensor)
+                save_load_output = m_save_load(input_tensor)
+
+                original_outputs.append(original_output)
+                save_load_outputs.append(save_load_output)
+
+            # Concatenate all outputs for comparison
+            original_result = torch.cat(original_outputs)
+            save_load_out = torch.cat(save_load_outputs)
+
+            self.assertIsNotNone(
+                original_result, "Original model output should not be None"
+            )
+            self.assertIsNotNone(
+                save_load_out, "Save/load model output should not be None"
+            )
+
+            torch.testing.assert_close(
+                original_result,
+                save_load_out,
+                msg=f"Save/load recipe should produce identical results for "
+                f"alpha={alpha}, quant_mode={quant_mode}, device={device}, dtype={input_dtype}",
+            )
+
+
+common_utils.instantiate_parametrized_tests(TestSmoothQuant)
+
+
+if __name__ == "__main__":
+    unittest.main()
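For context on the common_utils helpers adopted above: unlike pytest, these parametrized methods only become runnable tests once instantiate_parametrized_tests is called on the class. A minimal self-contained sketch of the pattern, using the equivalent helpers from torch.testing._internal.common_utils (assuming torchao.testing.common_utils re-exports them, as the imports in this diff suggest):

import unittest

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)


class TestExample(unittest.TestCase):
    # Stacked parametrize decorators multiply the generated test cases,
    # mirroring the common_utils.parametrize usage in the diff above.
    @parametrize("bias", [True, False])
    @parametrize("quant_mode", ["static", "dynamic"])
    def test_combinations(self, bias, quant_mode):
        self.assertIn(quant_mode, ("static", "dynamic"))


# Without this call the parametrized methods are never expanded into
# concrete tests; it mirrors instantiate_parametrized_tests(TestSmoothQuant).
instantiate_parametrized_tests(TestExample)


if __name__ == "__main__":
    unittest.main()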
[Review comment] nit: does unittest.skip(...) work?
[Author reply] @jerryzh168 https://docs.pytorch.org/docs/stable/notes/hip.html shows how torch.version.hip works. It is None on CUDA builds, a str (the ROCm version) on ROCm builds, and None on CPU-only builds. That is, the tests are skipped only when ROCm is available.
[Review comment] Oh, I meant using unittest.skip instead of raising the SkipTest error.
[Author reply] I see, but unittest.skip takes no condition. How about using unittest.skipIf? See https://docs.python.org/3/library/unittest.html#unittest.skipIf for reference. I like this change because it makes the code more concise.
[Review comment] Yeah, sounds good, thanks.
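The skipIf approach agreed on above can be sketched as follows (a minimal sketch assuming class-level decoration; the actual follow-up commit may place it differently):

import unittest

import torch


# Sketch of the unittest.skipIf suggestion from the thread above.
# torch.version.hip is a str on ROCm builds and None otherwise, so the
# whole TestCase is skipped only under ROCm.
@unittest.skipIf(torch.version.hip is not None, "Skipping the tests in ROCm")
class TestSmoothQuant(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()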