Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 209 additions & 122 deletions test/prototype/test_smoothquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import tempfile
import unittest
from copy import deepcopy

import pytest
import torch

from torchao.prototype.smoothquant import (
Expand All @@ -21,13 +21,11 @@
dequantize_per_channel,
dynamically_quantize_per_channel,
)
from torchao.testing import common_utils
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
)

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)


class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
Expand All @@ -53,143 +51,232 @@ def forward(self, x):
return x


bias_list = [True, False]
alpha_list = [None, 0.5, 0.75]
quant_mode_list = ["static", "dynamic"]
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
idtypes = (torch.float, torch.bfloat16, torch.half)

if TORCH_VERSION_AT_LEAST_2_5:
# This test case will trigger recompilation many times, so set a large cache_size_limit here
torch._dynamo.config.cache_size_limit = 128


@pytest.mark.parametrize("bias", bias_list)
@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it")
def test_compute(bias, alpha, quant_mode, device, idtype):
class Linear(torch.nn.Module):
def __init__(self, bias: bool):
super().__init__()
self.fc = torch.nn.Linear(32, 32, bias)
self.fc.weight.data = torch.randn_like(self.fc.weight.data)

def forward(self, x):
return self.fc(x)

m = Linear(bias).eval().to(idtype).to(device)
m_ref = deepcopy(m)
data = torch.randn(2, 32, dtype=idtype, device=device)

# calibrate
insert_smooth_quant_observer_(m, alpha, quant_mode)
m(data)
# quantize
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m, SmoothQuantConfig(), is_observed_linear)
with torch.inference_mode():
class TestSmoothQuant(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""Set up class-level configuration for tests."""
# Skip tests on ROCm (AMD GPU) due to compatibility issues
if torch.version.hip is not None:
raise unittest.SkipTest("Skipping the tests in ROCm")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: does unittest.skip(...) work?

Copy link
Contributor Author

@namgyu-youn namgyu-youn Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jerryzh168 https://docs.pytorch.org/docs/stable/notes/hip.html shows how torch.version.hip works. It is None in CUDA, str (ROCm version) in ROCm, and None in CPU-only. That is, this test is skipped only when ROCm available.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I mean using unittest.skip instead of raise the SkipTest error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, but there is no conditioner in the unittest.skip . How about using unittest.skipif? See https://docs.python.org/3/library/unittest.html#unittest.skipIf for the reference, and I like this change because the code can be more brevity

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah sounds good, thanks


if TORCH_VERSION_AT_LEAST_2_5:
# This test case will trigger recompilation many times, so set a large cache_size_limit here
torch._dynamo.config.cache_size_limit = 128

def _setup_quantized_model(self, model, alpha, quant_mode, calibration_data):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel it's probably fine to inline these as well, there are not many lines of code here

"""Setup for quantized models with observer insertion and calibration."""
insert_smooth_quant_observer_(model, alpha, quant_mode)

for data in calibration_data:
model(data)

is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(model, SmoothQuantConfig(), is_observed_linear)

if TORCH_VERSION_AT_LEAST_2_5:
m = torch.compile(m, fullgraph=True)
out = m(data)
model = torch.compile(model, fullgraph=True)

return model

@unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("alpha", [None, 0.5, 0.75])
@common_utils.parametrize("quant_mode", ["static", "dynamic"])
@common_utils.parametrize(
"device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
)
@common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
def test_smoothquant_accuracy(self, bias, alpha, quant_mode, device, input_dtype):
"""Test the margin error of SmoothQuant across bias, alpha, dtype, etc."""
self._run_compute_accuracy_test(bias, alpha, quant_mode, device, input_dtype)

def _run_compute_accuracy_test(self, bias, alpha, quant_mode, device, input_dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can also remove this function, seems no need to define a function here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, those functions are undefined for more brevity in 3ffdd10 .

"""Single compute accuracy test"""

class SimpleLinear(torch.nn.Module):
def __init__(self, bias: bool):
super().__init__()
self.fc = torch.nn.Linear(32, 32, bias)
self.fc.weight.data = torch.randn_like(self.fc.weight.data)

def forward(self, x):
return self.fc(x)

# reference
# Create model, reference, and test data
m = SimpleLinear(bias).eval().to(input_dtype).to(device)
m_ref = deepcopy(m)
test_data = torch.randn(2, 32, dtype=input_dtype, device=device)

# Step 1: Get calibration from observed SmoothQuant
m = self._setup_quantized_model(m, alpha, quant_mode, [test_data])

# Step 2: Inference quantized model
with torch.inference_mode():
q_out = m(test_data)

# Step 3: Compute reference
reference_out = self._compute_reference_out(
m_ref, test_data, alpha, quant_mode, bias, input_dtype
)

# Step 4: Validate numerical accuracy
tolerance = (
0.1
if input_dtype == torch.float
else (0.2 if input_dtype == torch.half else 0.3)
)
torch.testing.assert_close(
q_out,
reference_out.to(input_dtype),
atol=tolerance,
msg=f"Quantized output differs from reference for "
f"bias={bias}, alpha={alpha}, quant_mode={quant_mode}, "
f"device={device}, dtype={input_dtype}",
)

def _compute_reference_out(self, m_ref, data, alpha, quant_mode, bias, input_dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it's used only once, I'd suggest not to define new functions

"""Compute the expected SmoothQuant output."""
weight = m_ref.fc.weight.data.float()
b = m_ref.fc.bias if bias else None
x_abs_max_per_ic = torch.abs(data).max(dim=0).values
w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
smoothing_factor = (
1
if alpha is None
else (
torch.pow(x_abs_max_per_ic, alpha)
/ torch.pow(w_abs_max_per_ic, 1 - alpha)
if alpha is not None:
# Apply SmoothQuant
smoothing_factor = torch.pow(x_abs_max_per_ic, alpha) / torch.pow(
w_abs_max_per_ic, 1 - alpha
)
)
act = data / smoothing_factor
wei = weight * smoothing_factor
else:
smoothing_factor = torch.ones_like(x_abs_max_per_ic)

# Apply smoothing to activations and weights
smoothed_activation = data / smoothing_factor
smoothed_weight = weight * smoothing_factor

# Quantize weights using per-channel quantization
qw, w_scales, w_zps = dynamically_quantize_per_channel(
wei, -127, 127, torch.int8
smoothed_weight, -127, 127, torch.int8
)
fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype)
fq_wei = dequantize_per_channel(qw, w_scales, w_zps, input_dtype)

# Handle activation quantization based on mode
if quant_mode == "static":
# activation is quantized per-tensor
act_min, act_max = torch.aminmax(act.float())
act_min, act_max = torch.aminmax(smoothed_activation.float())
max_val_pos = torch.max(-act_min, act_max)
act_scale = max_val_pos / 127.0
activation_scale = max_val_pos / 127.0

fq_act = (
torch.quantize_per_tensor(
act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8
smoothed_activation.float(),
scale=activation_scale.item(),
zero_point=0,
dtype=torch.qint8,
)
.dequantize()
.to(idtype)
.to(input_dtype)
)
out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
else:
# activation is quantized per-row (batch * sequence_length)
qx, x_scales, x_zps = dynamically_quantize_per_channel(
act.float(), -127, 127, torch.int8
smoothed_activation.float(), -127, 127, torch.int8
)
fq_act = dequantize_per_channel(
qx,
x_scales,
x_zps,
input_dtype,
)
fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype)
out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)

# BFloat16 and Float16 have larger errors
atol = 0.1 if idtype == torch.float else (0.2 if idtype == torch.half else 0.3)
assert torch.allclose(out, out_ref.to(idtype), atol=atol)


@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it")
def test_save_load_recipe(alpha, quant_mode, device, idtype):
dataset_size = 20
l1, l2, l3 = 512, 256, 128
original_dtype = idtype
n_calib_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
m_save_load = deepcopy(m)

dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
dtype=original_dtype,
device=device,

# Compute final linear operation
return torch.nn.functional.linear(fq_act, fq_wei, b)

@unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
@common_utils.parametrize("alpha", [None, 0.5, 0.75])
@common_utils.parametrize("quant_mode", ["static", "dynamic"])
@common_utils.parametrize(
"device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
)
calibration_data = dataset[:n_calib_examples]

# calibrate
insert_smooth_quant_observer_(m, alpha, quant_mode)
insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)

for example in calibration_data:
m(example.to(device))
m_save_load(example.to(device))

with tempfile.NamedTemporaryFile() as fp:
save_path = fp.name
save_smooth_quant_recipe(m_save_load, save_path)
load_smooth_quant_recipe(m_save_load, save_path)

# quantize
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m, SmoothQuantConfig(), is_observed_linear)
if TORCH_VERSION_AT_LEAST_2_5:
# earlier versions are not compatible
m = torch.compile(m, fullgraph=True)
m_save_load = torch.compile(m_save_load, fullgraph=True)
out_list = [m(data.squeeze(0)) for data in dataset]
out = torch.cat(out_list)
save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset]
save_load_out = torch.cat(save_load_out_list)

assert out is not None
assert save_load_out is not None
assert torch.allclose(out, save_load_out)
@common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype):
"""Test save/load recipe functionality."""
self._run_save_load_recipe_test(alpha, quant_mode, device, input_dtype)

def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also this one

"""Single save/load recipe test."""
dataset_size = 20
layer_dims = (512, 256, 128) # Input, hidden, output dimensions
n_calib_examples = 10
sequence_length = 5

# Create two identical models for comparison
m = ToyLinearModel(*layer_dims).eval().to(input_dtype).to(device)
m_save_load = deepcopy(m)

# Generate calibration dataset
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
dtype=input_dtype,
device=device,
)
calibration_data = dataset[:n_calib_examples]

# Step 1: Setup quantized models
m = self._setup_quantized_model(m, alpha, quant_mode, calibration_data)

# Step 2: Setup save/load model with recipe functionality
insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
for example in calibration_data:
m_save_load(example.to(device))

# Step 3: Test save/load recipe functionality
with tempfile.NamedTemporaryFile() as temp_file:
save_path = temp_file.name
save_smooth_quant_recipe(m_save_load, save_path)
load_smooth_quant_recipe(m_save_load, save_path)

# Step 4: Complete quantization for save/load model
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear)

if TORCH_VERSION_AT_LEAST_2_5:
m_save_load = torch.compile(m_save_load, fullgraph=True)

# Step 5: Validate outputs on full dataset
with torch.inference_mode():
original_outputs = []
save_load_outputs = []

for data in dataset:
# Remove batch dimension for model input
input_tensor = data.squeeze(0)

original_output = m(input_tensor)
save_load_output = m_save_load(input_tensor)

original_outputs.append(original_output)
save_load_outputs.append(save_load_output)

# Concatenate all outputs for comparison
original_result = torch.cat(original_outputs)
save_load_out = torch.cat(save_load_outputs)

self.assertIsNotNone(
original_result, "Original model output should not be None"
)
self.assertIsNotNone(
save_load_out, "Save/load model output should not be None"
)

torch.testing.assert_close(
original_result,
save_load_out,
msg=f"Save/load recipe should produce identical results for "
f"alpha={alpha}, quant_mode={quant_mode}, device={device}, dtype={input_dtype}",
)


common_utils.instantiate_parametrized_tests(TestSmoothQuant)

if __name__ == "__main__":
unittest.main()