Convert SmoothQuant test to unittest #2659
Changes shown from 1 commit (PR commits: e44599e, e4a6b02, a6dd8bd, 3ffdd10, 7217f8f)
@@ -21,6 +21,7 @@
     dequantize_per_channel,
     dynamically_quantize_per_channel,
 )
+from torchao.testing import common_utils
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
 )
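For readers unfamiliar with these helpers: `common_utils.parametrize` and `common_utils.instantiate_parametrized_tests` are what replace the hand-rolled `subTest` loops below. A minimal sketch of the pattern (the standalone file is hypothetical; it assumes `torchao.testing.common_utils` exposes the same parametrize helpers this PR relies on):

    import unittest

    from torchao.testing import common_utils


    class ExampleTest(unittest.TestCase):
        @common_utils.parametrize("bias", [True, False])
        def test_example(self, bias):
            # Expands into one generated test per value, with names
            # along the lines of test_example_bias_True / _bias_False.
            self.assertIsInstance(bias, bool)


    # Required: without this call the parametrized variants are never created.
    common_utils.instantiate_parametrized_tests(ExampleTest)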
@@ -62,33 +63,32 @@ def setUpClass(cls):
         # This test case will trigger recompilation many times, so set a large cache_size_limit here
         torch._dynamo.config.cache_size_limit = 128

-        # Define test parameter ranges
-        cls.bias_options = [True, False]
-        cls.alpha_options = [None, 0.5, 0.75]  # None means conventional quantization
-        cls.quant_mode_options = ["static", "dynamic"]
-        cls.devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
-        cls.input_dtypes = (torch.float, torch.bfloat16, torch.half)
+    def _setup_quantized_model(self, model, alpha, quant_mode, calibration_data):
+        """Setup for quantized models with observer insertion and calibration."""
+        insert_smooth_quant_observer_(model, alpha, quant_mode)
+
+        for data in calibration_data:
+            model(data)
+
+        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+        quantize_(model, SmoothQuantConfig(), is_observed_linear)
+
+        if TORCH_VERSION_AT_LEAST_2_5:
+            model = torch.compile(model, fullgraph=True)
+
+        return model

-    @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
-    def test_smoothquant_accuracy(self):
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("alpha", [None, 0.5, 0.75])
+    @common_utils.parametrize("quant_mode", ["static", "dynamic"])
+    @common_utils.parametrize(
+        "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    )
+    @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
+    def test_smoothquant_accuracy(self, bias, alpha, quant_mode, device, input_dtype):
         """Test the margin error of SmoothQuant across bias, alpha, dtype, etc."""
-
-        # Test all parameter combinations using subTest for better isolation
-        for bias in self.bias_options:
-            for alpha in self.alpha_options:
-                for quant_mode in self.quant_mode_options:
-                    for device in self.devices:
-                        for input_dtype in self.input_dtypes:
-                            with self.subTest(
-                                bias=bias,
-                                alpha=alpha,
-                                quant_mode=quant_mode,
-                                device=device,
-                                input_dtype=input_dtype,
-                            ):
-                                self._run_compute_accuracy_test(
-                                    bias, alpha, quant_mode, device, input_dtype
-                                )
+        self._run_compute_accuracy_test(bias, alpha, quant_mode, device, input_dtype)

     def _run_compute_accuracy_test(self, bias, alpha, quant_mode, device, input_dtype):
Review comment: nit: can also remove this function, seems no need to define a function here
Reply: Thanks, those functions were removed (inlined) for brevity in 3ffdd10.
         """Single compute accuracy test"""
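One consequence of stacking the decorators this way: they take the Cartesian product of their values, so `test_smoothquant_accuracy` expands to 2 bias x 3 alpha x 2 quant_mode x 3 dtype = 36 generated cases per device (72 when CUDA is available). That is the same space the removed nested loops enumerated, but each combination is now reported, skipped, and filterable as its own test.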
@@ -108,25 +108,18 @@ def forward(self, x):
         test_data = torch.randn(2, 32, dtype=input_dtype, device=device)

         # Step 1: Get calibration from observed SmoothQuant
-        insert_smooth_quant_observer_(m, alpha, quant_mode)
-        m(test_data)
+        m = self._setup_quantized_model(m, alpha, quant_mode, [test_data])

-        # Step 2: Quantize
-        is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-        quantize_(m, SmoothQuantConfig(), is_observed_linear)
-
-        # Step 3: Inference quantized model
+        # Step 2: Inference quantized model
         with torch.inference_mode():
-            if TORCH_VERSION_AT_LEAST_2_5:
-                m = torch.compile(m, fullgraph=True)
             q_out = m(test_data)

-        # Step 4: Compute reference
+        # Step 3: Compute reference
         reference_out = self._compute_reference_out(
             m_ref, test_data, alpha, quant_mode, bias, input_dtype
         )

-        # Step 5: Validate numerical accuracy
+        # Step 4: Validate numerical accuracy
         tolerance = (
             0.1
             if input_dtype == torch.float
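Steps 1 and 2 above now compress the standard torchao SmoothQuant flow into the `_setup_quantized_model` helper. Spelled out in isolation it looks roughly like this (a sketch assembled from the calls visible in this diff; the `torchao.prototype.smoothquant` import path and the placeholder model/data are assumptions, not shown in the hunks):

    import torch

    from torchao.prototype.smoothquant import (  # assumed import path
        SmoothQuantConfig,
        SmoothQuantObservedLinear,
        insert_smooth_quant_observer_,
    )
    from torchao.quantization import quantize_

    # Placeholder model and calibration set for illustration only.
    model = torch.nn.Sequential(torch.nn.Linear(32, 32)).eval()
    calibration_data = [torch.randn(2, 32) for _ in range(4)]

    # 1. Swap eligible Linear modules for observed variants that record
    #    activation statistics (alpha=0.5, quant_mode="dynamic" here).
    insert_smooth_quant_observer_(model, 0.5, "dynamic")

    # 2. Calibrate by running representative inputs through the observed model.
    for data in calibration_data:
        model(data)

    # 3. Quantize only the observed Linear modules.
    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
    quantize_(model, SmoothQuantConfig(), is_observed_linear)

    # 4. Optionally compile; the test guards this on TORCH_VERSION_AT_LEAST_2_5.
    model = torch.compile(model, fullgraph=True)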
@@ -198,21 +191,15 @@ def _compute_reference_out(self, m_ref, data, alpha, quant_mode, bias, input_dty
         return torch.nn.functional.linear(fq_act, fq_wei, b)

-    @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
-    def test_save_load_recipe(self):
-        """Setup test for save/load recipe functionality."""
-        for alpha in self.alpha_options:
-            for quant_mode in self.quant_mode_options:
-                for device in self.devices:
-                    for input_dtype in self.input_dtypes:
-                        with self.subTest(
-                            alpha=alpha,
-                            quant_mode=quant_mode,
-                            device=device,
-                            input_dtype=input_dtype,
-                        ):
-                            self._run_save_load_recipe_test(
-                                alpha, quant_mode, device, input_dtype
-                            )
+    @common_utils.parametrize("alpha", [None, 0.5, 0.75])
+    @common_utils.parametrize("quant_mode", ["static", "dynamic"])
+    @common_utils.parametrize(
+        "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    )
+    @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
+    def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype):
+        """Test save/load recipe functionality."""
+        self._run_save_load_recipe_test(alpha, quant_mode, device, input_dtype)

     def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
Review comment: also this one
         """Single save/load recipe test."""
@@ -234,13 +221,12 @@ def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
         )
         calibration_data = dataset[:n_calib_examples]

-        # Step 1: Insert observers in both models
-        insert_smooth_quant_observer_(m, alpha, quant_mode)
-        insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
+        # Step 1: Setup quantized models
+        m = self._setup_quantized_model(m, alpha, quant_mode, calibration_data)

-        # Step 2: Calibrate both models with identical data
+        # Step 2: Setup save/load model with recipe functionality
+        insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
         for example in calibration_data:
-            m(example.to(device))
             m_save_load(example.to(device))

         # Step 3: Test save/load recipe functionality
@@ -249,14 +235,11 @@ def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
         save_smooth_quant_recipe(m_save_load, save_path)
         load_smooth_quant_recipe(m_save_load, save_path)

-        # Step 4: Quantize both models
+        # Step 4: Complete quantization for save/load model
         is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-        quantize_(m, SmoothQuantConfig(), is_observed_linear)
         quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear)

         if TORCH_VERSION_AT_LEAST_2_5:
             # earlier versions are not compatible
-            m = torch.compile(m, fullgraph=True)
             m_save_load = torch.compile(m_save_load, fullgraph=True)

         # Step 5: Validate outputs on full dataset
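Pulled out of the test scaffolding, the recipe round-trip exercised on `m_save_load` across the two hunks above amounts to the following (a sketch stitched together from the calls shown in this diff; variable names come from the test, and the import path for `save_smooth_quant_recipe`/`load_smooth_quant_recipe` is assumed to match the other smoothquant helpers):

    # Observe and calibrate, then persist the calibration recipe to disk
    # and restore it in place before converting to a quantized model.
    insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
    for example in calibration_data:
        m_save_load(example.to(device))

    save_smooth_quant_recipe(m_save_load, save_path)  # persist calibration recipe
    load_smooth_quant_recipe(m_save_load, save_path)  # restore it before quantizing

    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
    quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear)

Note the asymmetry: `m` is quantized directly through `_setup_quantized_model`, while `m_save_load` additionally goes through the save/load path, so the final output comparison is what actually validates the recipe.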
@@ -293,5 +276,7 @@ def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
         )


+common_utils.instantiate_parametrized_tests(TestSmoothQuant)
+
 if __name__ == "__main__":
     unittest.main()
Review comment: I feel it's probably fine to inline these as well, there are not many lines of code here
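Worth noting for anyone running this file: `common_utils.instantiate_parametrized_tests(TestSmoothQuant)` has to execute at import time, after the class body, or none of the `@common_utils.parametrize` variants exist for unittest to discover. Once generated, a subset of combinations can be selected by name with unittest's `-k` filter, e.g. `python test_smoothquant.py -k test_save_load_recipe` (the file name and the exact generated suffixes, along the lines of `_alpha_0_5_quant_mode_static`, are assumptions based on PyTorch's parametrized naming convention, not taken from the diff).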