Merged
Changes from 1 commit
103 changes: 44 additions & 59 deletions test/prototype/test_smoothquant.py
@@ -21,6 +21,7 @@
dequantize_per_channel,
dynamically_quantize_per_channel,
)
from torchao.testing import common_utils
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
)
@@ -62,33 +63,32 @@ def setUpClass(cls):
# This test case will trigger recompilation many times, so set a large cache_size_limit here
torch._dynamo.config.cache_size_limit = 128

# Define test parameter ranges
cls.bias_options = [True, False]
cls.alpha_options = [None, 0.5, 0.75] # None means conventional quantization
cls.quant_mode_options = ["static", "dynamic"]
cls.devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
cls.input_dtypes = (torch.float, torch.bfloat16, torch.half)
def _setup_quantized_model(self, model, alpha, quant_mode, calibration_data):
Contributor:

I feel it's probably fine to inline these as well, there are not many lines of code here.

"""Setup for quantized models with observer insertion and calibration."""
insert_smooth_quant_observer_(model, alpha, quant_mode)

for data in calibration_data:
model(data)

is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(model, SmoothQuantConfig(), is_observed_linear)

if TORCH_VERSION_AT_LEAST_2_5:
model = torch.compile(model, fullgraph=True)

return model

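For readers skimming the diff, here is a minimal, self-contained sketch of the observe → calibrate → quantize flow that _setup_quantized_model wraps. The import path and the toy model are assumptions; the calls mirror the helper above.

import torch

# Assumed import path, based on the symbols this test file uses.
from torchao.prototype.smoothquant import (
    SmoothQuantConfig,
    SmoothQuantObservedLinear,
    insert_smooth_quant_observer_,
)
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(32, 32)).eval()

# 1. Swap nn.Linear modules for observed linears that record activation stats.
insert_smooth_quant_observer_(model, 0.5, "dynamic")

# 2. Calibrate: plain forward passes feed the observers.
for _ in range(4):
    model(torch.randn(2, 32))

# 3. Quantize only the observed linears, then compile (the test gates this
#    on TORCH_VERSION_AT_LEAST_2_5; earlier versions are not compatible).
quantize_(model, SmoothQuantConfig(),
          lambda m, fqn: isinstance(m, SmoothQuantObservedLinear))
model = torch.compile(model, fullgraph=True)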
@unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
def test_smoothquant_accuracy(self):
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("alpha", [None, 0.5, 0.75])
@common_utils.parametrize("quant_mode", ["static", "dynamic"])
@common_utils.parametrize(
"device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
)
@common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
def test_smoothquant_accuracy(self, bias, alpha, quant_mode, device, input_dtype):
"""Test the margin error of SmoothQuant across bias, alpha, dtype, etc."""

# Test all parameter combinations using subTest for better isolation
for bias in self.bias_options:
for alpha in self.alpha_options:
for quant_mode in self.quant_mode_options:
for device in self.devices:
for input_dtype in self.input_dtypes:
with self.subTest(
bias=bias,
alpha=alpha,
quant_mode=quant_mode,
device=device,
input_dtype=input_dtype,
):
self._run_compute_accuracy_test(
bias, alpha, quant_mode, device, input_dtype
)
self._run_compute_accuracy_test(bias, alpha, quant_mode, device, input_dtype)

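As context for the change above: PyTorch's stock parametrization helpers expand one decorated test into one generated test per parameter value, replacing the old nested subTest loops. A minimal sketch with a hypothetical ToyTest, using the upstream torch.testing._internal module (this diff imports an equivalent common_utils from torchao.testing):

import unittest

from torch.testing._internal import common_utils


class ToyTest(unittest.TestCase):
    # Stacked parametrize decorators generate the cross product of values.
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("quant_mode", ["static", "dynamic"])
    def test_combo(self, bias, quant_mode):
        # Runs once per (bias, quant_mode) pair: 4 generated tests in total.
        self.assertIn(quant_mode, ("static", "dynamic"))


# Expands the decorated methods into concrete test_* methods on the class.
common_utils.instantiate_parametrized_tests(ToyTest)

if __name__ == "__main__":
    unittest.main()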
def _run_compute_accuracy_test(self, bias, alpha, quant_mode, device, input_dtype):
Contributor:

nit: can also remove this function, seems no need to define a function here

Contributor Author:

Thanks, those functions are inlined for brevity in 3ffdd10.

"""Single compute accuracy test"""
@@ -108,25 +108,18 @@ def forward(self, x):
test_data = torch.randn(2, 32, dtype=input_dtype, device=device)

# Step 1: Get calibration from observed SmoothQuant
insert_smooth_quant_observer_(m, alpha, quant_mode)
m(test_data)
m = self._setup_quantized_model(m, alpha, quant_mode, [test_data])

# Step 2: Quantize
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m, SmoothQuantConfig(), is_observed_linear)

# Step 3: Inference quantized model
# Step 2: Inference quantized model
with torch.inference_mode():
if TORCH_VERSION_AT_LEAST_2_5:
m = torch.compile(m, fullgraph=True)
q_out = m(test_data)

# Step 4: Compute reference
# Step 3: Compute reference
reference_out = self._compute_reference_out(
m_ref, test_data, alpha, quant_mode, bias, input_dtype
)

# Step 5: Validate numerical accuracy
# Step 4: Validate numerical accuracy
tolerance = (
0.1
if input_dtype == torch.float
@@ -198,21 +191,15 @@ def _compute_reference_out(self, m_ref, data, alpha, quant_mode, bias, input_dtype):
return torch.nn.functional.linear(fq_act, fq_wei, b)

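Background for the alpha parameter exercised above: per the SmoothQuant paper (arXiv:2211.10438), alpha balances how much quantization difficulty is migrated from activations to weights via a per-input-channel smoothing factor (alpha=None falls back to conventional quantization without smoothing, per the comment in the removed setUpClass). A sketch of the arithmetic, not the test's exact code:

import torch

def smoothing_factor(act_absmax, w_absmax, alpha):
    # Per input channel j: s_j = max|X_j|^alpha / max|W_j|^(1 - alpha).
    return act_absmax.pow(alpha) / w_absmax.pow(1.0 - alpha)

x = torch.randn(2, 32)
w = torch.randn(64, 32)          # nn.Linear weight layout: (out, in)
s = smoothing_factor(x.abs().amax(dim=0), w.abs().amax(dim=0), alpha=0.5)

# Dividing activations by s and multiplying weights by s leaves the
# matmul unchanged but moves activation outliers into the weights:
assert torch.allclose((x / s) @ (w * s).T, x @ w.T, atol=1e-4)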
@unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
def test_save_load_recipe(self):
"""Setup test for save/load recipe functionality."""
for alpha in self.alpha_options:
for quant_mode in self.quant_mode_options:
for device in self.devices:
for input_dtype in self.input_dtypes:
with self.subTest(
alpha=alpha,
quant_mode=quant_mode,
device=device,
input_dtype=input_dtype,
):
self._run_save_load_recipe_test(
alpha, quant_mode, device, input_dtype
)
@common_utils.parametrize("alpha", [None, 0.5, 0.75])
@common_utils.parametrize("quant_mode", ["static", "dynamic"])
@common_utils.parametrize(
"device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
)
@common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype):
"""Test save/load recipe functionality."""
self._run_save_load_recipe_test(alpha, quant_mode, device, input_dtype)

def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
Contributor:

also this one

"""Single save/load recipe test."""
@@ -234,13 +221,12 @@ def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
)
calibration_data = dataset[:n_calib_examples]

# Step 1: Insert observers in both models
insert_smooth_quant_observer_(m, alpha, quant_mode)
insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
# Step 1: Setup quantized models
m = self._setup_quantized_model(m, alpha, quant_mode, calibration_data)

# Step 2: Calibrate both models with identical data
# Step 2: Setup save/load model with recipe functionality
insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
for example in calibration_data:
m(example.to(device))
m_save_load(example.to(device))

# Step 3: Test save/load recipe functionality
@@ -249,14 +235,11 @@ def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
save_smooth_quant_recipe(m_save_load, save_path)
load_smooth_quant_recipe(m_save_load, save_path)

# Step 4: Quantize both models
# Step 4: Complete quantization for save/load model
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m, SmoothQuantConfig(), is_observed_linear)
quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear)

if TORCH_VERSION_AT_LEAST_2_5:
# earlier versions are not compatible
m = torch.compile(m, fullgraph=True)
m_save_load = torch.compile(m_save_load, fullgraph=True)

# Step 5: Validate outputs on full dataset
@@ -293,5 +276,7 @@ def _run_save_load_recipe_test(self, alpha, quant_mode, device, input_dtype):
)

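Finally, a minimal sketch of the recipe round-trip the save/load test drives, using the save_smooth_quant_recipe / load_smooth_quant_recipe signatures seen above. The import path and the recipe file name are assumptions.

import os
import tempfile

import torch
from torchao.prototype.smoothquant import (  # assumed path
    insert_smooth_quant_observer_,
    load_smooth_quant_recipe,
    save_smooth_quant_recipe,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32)).eval()
insert_smooth_quant_observer_(model, 0.5, "dynamic")
model(torch.randn(2, 32))  # one calibration pass populates the observers

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "smoothquant_recipe.json")
    save_smooth_quant_recipe(model, path)  # persist the calibrated scales
    load_smooth_quant_recipe(model, path)  # restore them into a matching model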

common_utils.instantiate_parametrized_tests(TestSmoothQuant)

if __name__ == "__main__":
unittest.main()