Make SmoothQuant more General #2728

Open · wants to merge 4 commits into main
159 changes: 71 additions & 88 deletions test/prototype/test_smoothquant.py
@@ -3,7 +3,6 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import tempfile
import unittest
from copy import deepcopy

@@ -13,44 +12,18 @@
from torchao.prototype.smoothquant import (
SmoothQuantConfig,
SmoothQuantObservedLinear,
insert_smooth_quant_observer_,
load_smooth_quant_recipe,
save_smooth_quant_recipe,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import (
dequantize_per_channel,
dynamically_quantize_per_channel,
)
from torchao.testing.model_architectures import ToyLinearModel
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
)


class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)
self.linear3 = torch.nn.Linear(k, 1, bias=False)

def example_inputs(
self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
):
return [
torch.randn(
1, sequence_length, self.linear1.in_features, dtype=dtype, device=device
)
for j in range(batch_size)
]

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x


@unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
class TestSmoothQuant(unittest.TestCase):
@classmethod
@@ -86,14 +59,15 @@ def forward(self, x):
test_data = torch.randn(2, 32, dtype=input_dtype, device=device)

# Step 1: Setup quantized model with observer insertion and calibration
insert_smooth_quant_observer_(m, alpha, quant_mode)
config = SmoothQuantConfig(step="prepare", alpha=alpha, quant_mode=quant_mode)
quantize_(m, config)

# Perform calibration with test data
m(test_data)

# Apply quantization configuration
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m, SmoothQuantConfig(), is_observed_linear)
config.step = "convert"
quantize_(m, config)

# Apply compilation if supported
if TORCH_VERSION_AT_LEAST_2_5:
@@ -174,98 +148,107 @@ def forward(self, x):
f"device={device}, dtype={input_dtype}",
)

def test_observer_insertion(self):
"""Test that PREPARE step correctly inserts SmoothQuantObservedLinear."""

class SimpleLinear(torch.nn.Module):
def __init__(self, bias: bool):
super().__init__()
self.fc = torch.nn.Linear(32, 32, bias)

def forward(self, x):
return self.fc(x)

m = SimpleLinear(True).eval()

# Before quantization - should be regular Linear
self.assertIsInstance(m.fc, torch.nn.Linear)
self.assertNotIsInstance(m.fc, SmoothQuantObservedLinear)

# PREPARE step - should insert observers
config = SmoothQuantConfig(step="prepare", alpha=0.5, quant_mode="dynamic")
quantize_(m, config)

# After PREPARE - should be SmoothQuantObservedLinear
self.assertIsInstance(m.fc, SmoothQuantObservedLinear)
self.assertTrue(hasattr(m.fc, "obs"))

# Test calibration
test_data = torch.randn(2, 32)
m(test_data)

# CONVERT step - should produce regular Linear with quantized weights
config.step = "convert"
quantize_(m, config)

# After CONVERT - should be regular Linear again (but quantized)
self.assertIsInstance(m.fc, torch.nn.Linear)
self.assertNotIsInstance(m.fc, SmoothQuantObservedLinear)

@unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it")
@common_utils.parametrize("alpha", [None, 0.5, 0.75])
@common_utils.parametrize("quant_mode", ["static", "dynamic"])
@common_utils.parametrize(
"device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
)
@common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half])
def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype):
"""Test save/load recipe functionality."""
def test_two_step_quantization(self, alpha, quant_mode, device, input_dtype):
"""Test two-step quantization process (PREPARE -> CONVERT)."""
dataset_size = 20
layer_dims = (512, 256, 128) # Input, hidden, output dimensions
n_calib_examples = 10
sequence_length = 5

# Create two identical models for comparison
m = ToyLinearModel(*layer_dims).eval().to(input_dtype).to(device)
m_save_load = deepcopy(m)
m1 = ToyLinearModel(*layer_dims).eval().to(input_dtype).to(device)
m2 = deepcopy(m1)

# Generate calibration dataset
dataset = m.example_inputs(
dataset = m1.example_inputs(
dataset_size,
sequence_length=sequence_length,
dtype=input_dtype,
device=device,
)
calibration_data = dataset[:n_calib_examples]

# Step 1: Setup first quantized model with observer insertion and calibration
insert_smooth_quant_observer_(m, alpha, quant_mode)
# Step 1: PREPARE - Insert observers
config = SmoothQuantConfig(step="prepare", alpha=alpha, quant_mode=quant_mode)
quantize_(m2, config)

# Perform calibration with calibration data
# Step 2: Calibration
for data in calibration_data:
m(data)
m2(data)

# Apply quantization configuration
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m, SmoothQuantConfig(), is_observed_linear)
quantize_(m2, SmoothQuantConfig(), is_observed_linear)

# Apply compilation if supported
if TORCH_VERSION_AT_LEAST_2_5:
m = torch.compile(m, fullgraph=True)

# Step 2: Setup save/load model with recipe functionality
insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)
for example in calibration_data:
m_save_load(example.to(device))

# Step 3: Test save/load recipe functionality
with tempfile.NamedTemporaryFile() as temp_file:
save_path = temp_file.name
save_smooth_quant_recipe(m_save_load, save_path)
load_smooth_quant_recipe(m_save_load, save_path)

# Step 4: Complete quantization for save/load model
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear)
m2 = torch.compile(m2, fullgraph=True)

if TORCH_VERSION_AT_LEAST_2_5:
m_save_load = torch.compile(m_save_load, fullgraph=True)

# Step 5: Validate outputs on full dataset
with torch.inference_mode():
original_outputs = []
save_load_outputs = []

for data in dataset:
# Remove batch dimension for model input
input_tensor = data.squeeze(0)

original_output = m(input_tensor)
save_load_output = m_save_load(input_tensor)
# Step 4: Validate outputs on full dataset
with torch.inference_mode():
m2_outputs = []

original_outputs.append(original_output)
save_load_outputs.append(save_load_output)
for data in dataset:
# Remove batch dimension for model input
input_tensor = data.squeeze(0)
m2_output = m2(input_tensor)
m2_outputs.append(m2_output)

# Concatenate all outputs for comparison
original_result = torch.cat(original_outputs)
save_load_out = torch.cat(save_load_outputs)
# Concatenate all outputs
m2_result = torch.cat(m2_outputs)

self.assertIsNotNone(
original_result, "Original model output should not be None"
)
self.assertIsNotNone(
save_load_out, "Save/load model output should not be None"
)
self.assertIsNotNone(m2_result, "Quantized model output should not be None")

torch.testing.assert_close(
original_result,
save_load_out,
msg=f"Save/load recipe should produce identical results for "
f"alpha={alpha}, quant_mode={quant_mode}, device={device}, dtype={input_dtype}",
)
# Check that model produces reasonable outputs
self.assertFalse(
torch.isnan(m2_result).any(),
f"Quantized model should not produce NaN values for "
f"alpha={alpha}, quant_mode={quant_mode}, device={device}, dtype={input_dtype}",
)


common_utils.instantiate_parametrized_tests(TestSmoothQuant)
10 changes: 5 additions & 5 deletions torchao/prototype/smoothquant/README.md
@@ -1,26 +1,26 @@
# SmothQuant quantization
# SmoothQuant quantization
This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438).

In this implementation, weights are smoothed (equalized) and quantized to int8 ahead of time, while activations are smoothed and quantized to int8 at runtime. Activation quantization is either dynamic or static: dynamic quantization computes qparams (i.e., scales) at runtime and quantizes activations per token, whereas static quantization computes qparams during calibration and quantizes activations per tensor. Dynamic quantization generally produces better accuracy, while static quantization gives better latency. In both cases, weights and activations are quantized symmetrically.
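
To make the smoothing step concrete, here is a minimal sketch of how per-channel smoothing factors are derived in the SmoothQuant paper; `smooth_linear` and its argument names are illustrative and not part of this API:

```python
import torch

def smooth_linear(weight: torch.Tensor, act_abs_max: torch.Tensor, alpha: float = 0.5):
    """Illustrative per-channel smoothing following the SmoothQuant paper.

    weight:      (out_features, in_features) linear weight
    act_abs_max: (in_features,) running max of |activation| per input channel,
                 gathered during calibration
    """
    w_abs_max = weight.abs().amax(dim=0)  # per input channel
    # s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
    scales = act_abs_max.pow(alpha) / w_abs_max.pow(1.0 - alpha)
    scales = scales.clamp(min=1e-5)
    # Activations are divided by s at runtime; weights absorb s ahead of time,
    # so y = (x / s) @ (W * s)^T is numerically equivalent to x @ W^T.
    smoothed_weight = weight * scales  # broadcasts over in_features
    return smoothed_weight, scales
```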

## Quick start
Run the example code with
```bash
python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or static>
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static>
# An example
python example.py -m meta-llama/Llama-2-7b-hf --device=cuda --quant-mode=dynamic
```
To use `torch.compile` for speedup, add `--compile`. You may want to export `TORCHINDUCTOR_FREEZING=1` for even better performance.
```bash
TORCHINDUCTOR_FREEZING=1 python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --compile
TORCHINDUCTOR_FREEZING=1 python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --compile
```
To save a quantized model for reuse, specify `--model-save-path`
```bash
python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-save-path ./quantized_model.pt
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-save-path ./quantized_model.pt
```
And load it with `--model-load-path`
```bash
python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-load-path ./quantized_model.pt
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-load-path ./quantized_model.pt
```


16 changes: 7 additions & 9 deletions torchao/prototype/smoothquant/__init__.py
@@ -1,15 +1,13 @@
from .api import (
SmoothQuantConfig,
insert_smooth_quant_observer_,
load_smooth_quant_recipe,
save_smooth_quant_recipe,
from .api import SmoothQuantConfig
from .core import (
SmoothQuantObservedLinear,
SmoothQuantObserver,
SmoothQuantStep,
)
from .core import SmoothQuantObservedLinear

__all__ = [
"insert_smooth_quant_observer_",
"load_smooth_quant_recipe",
"save_smooth_quant_recipe",
"SmoothQuantConfig",
"SmoothQuantStep",
"SmoothQuantObserver",
"SmoothQuantObservedLinear",
]
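
Since `SmoothQuantObservedLinear` stays in the public API, a quick sanity check after the prepare step (mirroring the assertions in `test_observer_insertion` above) might look like this; `model` is assumed to be a module tree that has already gone through `quantize_` with `step="prepare"`:

```python
from torchao.prototype.smoothquant import SmoothQuantObservedLinear

# After "prepare", each eligible nn.Linear should have been swapped for a
# SmoothQuantObservedLinear that carries its calibration observer in `obs`.
for name, module in model.named_modules():
    if isinstance(module, SmoothQuantObservedLinear):
        assert hasattr(module, "obs"), f"{name} is missing its observer"
```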