diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 568b2d964f..78cfa220e7 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import tempfile import unittest from copy import deepcopy @@ -13,44 +12,18 @@ from torchao.prototype.smoothquant import ( SmoothQuantConfig, SmoothQuantObservedLinear, - insert_smooth_quant_observer_, - load_smooth_quant_recipe, - save_smooth_quant_recipe, ) from torchao.quantization import quantize_ from torchao.quantization.utils import ( dequantize_per_channel, dynamically_quantize_per_channel, ) +from torchao.testing.model_architectures import ToyLinearModel from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, ) -class ToyLinearModel(torch.nn.Module): - def __init__(self, m=512, n=256, k=128): - super().__init__() - self.linear1 = torch.nn.Linear(m, n, bias=False) - self.linear2 = torch.nn.Linear(n, k, bias=False) - self.linear3 = torch.nn.Linear(k, 1, bias=False) - - def example_inputs( - self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda" - ): - return [ - torch.randn( - 1, sequence_length, self.linear1.in_features, dtype=dtype, device=device - ) - for j in range(batch_size) - ] - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - x = self.linear3(x) - return x - - @unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm") class TestSmoothQuant(unittest.TestCase): @classmethod @@ -86,14 +59,15 @@ def forward(self, x): test_data = torch.randn(2, 32, dtype=input_dtype, device=device) # Step 1: Setup quantized model with observer insertion and calibration - insert_smooth_quant_observer_(m, alpha, quant_mode) + config = SmoothQuantConfig(step="prepare", alpha=alpha, quant_mode=quant_mode) + quantize_(m, config) # Perform calibration with test data m(test_data) # Apply quantization configuration - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m, SmoothQuantConfig(), is_observed_linear) + config.step = "convert" + quantize_(m, config) # Apply compilation if supported if TORCH_VERSION_AT_LEAST_2_5: @@ -174,6 +148,43 @@ def forward(self, x): f"device={device}, dtype={input_dtype}", ) + def test_observer_insertion(self): + """Test that PREPARE step correctly inserts SmoothQuantObservedLinear.""" + + class SimpleLinear(torch.nn.Module): + def __init__(self, bias: bool): + super().__init__() + self.fc = torch.nn.Linear(32, 32, bias) + + def forward(self, x): + return self.fc(x) + + m = SimpleLinear(True).eval() + + # Before quantization - should be regular Linear + self.assertIsInstance(m.fc, torch.nn.Linear) + self.assertNotIsInstance(m.fc, SmoothQuantObservedLinear) + + # PREPARE step - should insert observers + config = SmoothQuantConfig(step="prepare", alpha=0.5, quant_mode="dynamic") + quantize_(m, config) + + # After PREPARE - should be SmoothQuantObservedLinear + self.assertIsInstance(m.fc, SmoothQuantObservedLinear) + self.assertTrue(hasattr(m.fc, "obs")) + + # Test calibration + test_data = torch.randn(2, 32) + m(test_data) + + # CONVERT step - should produce regular Linear with quantized weights + config.step = "convert" + quantize_(m, config) + + # After CONVERT - should be regular Linear again (but quantized) + self.assertIsInstance(m.fc, torch.nn.Linear) + self.assertNotIsInstance(m.fc, SmoothQuantObservedLinear) + 
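The observer-insertion test above exercises the new two-step flow that replaces `insert_smooth_quant_observer_`. A minimal usage sketch of that flow (a hypothetical standalone snippet, not part of this diff):

```python
import torch

from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.quantization import quantize_

# A toy eval-mode model to quantize (hypothetical example).
model = torch.nn.Sequential(torch.nn.Linear(32, 32, bias=False)).eval()

# PREPARE: replace each Linear with a SmoothQuantObservedLinear that
# records per-channel activation min/max statistics.
config = SmoothQuantConfig(step="prepare", alpha=0.5, quant_mode="dynamic")
quantize_(model, config)

# Calibrate by running representative inputs through the observed model.
for _ in range(8):
    model(torch.randn(2, 32))

# CONVERT: fold the smoothing factor into the weights and swap in
# int8-quantized linears; modules that were never observed are skipped,
# so no filter function is needed.
config.step = "convert"
quantize_(model, config)
```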
@unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it") @common_utils.parametrize("alpha", [None, 0.5, 0.75]) @common_utils.parametrize("quant_mode", ["static", "dynamic"]) @@ -181,19 +192,19 @@ def forward(self, x): "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) ) @common_utils.parametrize("input_dtype", [torch.float, torch.bfloat16, torch.half]) - def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype): - """Test save/load recipe functionality.""" + def test_two_step_quantization(self, alpha, quant_mode, device, input_dtype): + """Test two-step quantization process (PREPARE -> CONVERT).""" dataset_size = 20 layer_dims = (512, 256, 128) # Input, hidden, output dimensions n_calib_examples = 10 sequence_length = 5 # Create two identical models for comparison - m = ToyLinearModel(*layer_dims).eval().to(input_dtype).to(device) - m_save_load = deepcopy(m) + m1 = ToyLinearModel(*layer_dims).eval().to(input_dtype).to(device) + m2 = deepcopy(m1) # Generate calibration dataset - dataset = m.example_inputs( + dataset = m1.example_inputs( dataset_size, sequence_length=sequence_length, dtype=input_dtype, @@ -201,71 +212,43 @@ def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype): ) calibration_data = dataset[:n_calib_examples] - # Step 1: Setup first quantized model with observer insertion and calibration - insert_smooth_quant_observer_(m, alpha, quant_mode) + # Step 1: PREPARE - Insert observers + config = SmoothQuantConfig(step="prepare", alpha=alpha, quant_mode=quant_mode) + quantize_(m2, config) - # Perform calibration with calibration data + # Step 2: Calibration for data in calibration_data: - m(data) + m2(data) # Apply quantization configuration is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m, SmoothQuantConfig(), is_observed_linear) + quantize_(m2, SmoothQuantConfig(), is_observed_linear) # Apply compilation if supported if TORCH_VERSION_AT_LEAST_2_5: - m = torch.compile(m, fullgraph=True) - - # Step 2: Setup save/load model with recipe functionality - insert_smooth_quant_observer_(m_save_load, alpha, quant_mode) - for example in calibration_data: - m_save_load(example.to(device)) - - # Step 3: Test save/load recipe functionality - with tempfile.NamedTemporaryFile() as temp_file: - save_path = temp_file.name - save_smooth_quant_recipe(m_save_load, save_path) - load_smooth_quant_recipe(m_save_load, save_path) - - # Step 4: Complete quantization for save/load model - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear) + m2 = torch.compile(m2, fullgraph=True) - if TORCH_VERSION_AT_LEAST_2_5: - m_save_load = torch.compile(m_save_load, fullgraph=True) - - # Step 5: Validate outputs on full dataset - with torch.inference_mode(): - original_outputs = [] - save_load_outputs = [] - - for data in dataset: - # Remove batch dimension for model input - input_tensor = data.squeeze(0) - - original_output = m(input_tensor) - save_load_output = m_save_load(input_tensor) + # Step 4: Validate outputs on full dataset + with torch.inference_mode(): + m2_outputs = [] - original_outputs.append(original_output) - save_load_outputs.append(save_load_output) + for data in dataset: + # Remove batch dimension for model input + input_tensor = data.squeeze(0) + m2_output = m2(input_tensor) + m2_outputs.append(m2_output) - # Concatenate all outputs for comparison - original_result = torch.cat(original_outputs) - 
+            # Concatenate all outputs
+            m2_result = torch.cat(m2_outputs)

-            self.assertIsNotNone(
-                original_result, "Original model output should not be None"
-            )
-            self.assertIsNotNone(
-                save_load_out, "Save/load model output should not be None"
-            )
+            self.assertIsNotNone(m2_result, "Quantized model output should not be None")

-            torch.testing.assert_close(
-                original_result,
-                save_load_out,
-                msg=f"Save/load recipe should produce identical results for "
-                f"alpha={alpha}, quant_mode={quant_mode}, device={device}, dtype={input_dtype}",
-            )
+            # Check that the model produces reasonable outputs
+            self.assertFalse(
+                torch.isnan(m2_result).any(),
+                f"Quantized model should not produce NaN values for "
+                f"alpha={alpha}, quant_mode={quant_mode}, device={device}, dtype={input_dtype}",
+            )


 common_utils.instantiate_parametrized_tests(TestSmoothQuant)
diff --git a/torchao/prototype/smoothquant/README.md b/torchao/prototype/smoothquant/README.md
index c268a83504..21d2738c82 100644
--- a/torchao/prototype/smoothquant/README.md
+++ b/torchao/prototype/smoothquant/README.md
@@ -1,4 +1,4 @@
-# SmothQuant quantization
+# SmoothQuant quantization
 This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438).

 In this implementation, weights are smoothed (equalized) and quantized to int8 during quantization. Activations are smoothed and quantized to int8 at runtime. Quantization is done either dynamically or statically. For dynamic quantization, qparams (i.e., scales) are found at runtime and activations are quantized per token; for static quantization, qparams are found ahead of time, during quantization, and activations are quantized per tensor. Generally, dynamic quantization produces better accuracy, while static quantization gives better latency. In both cases, weights and activations are symmetrically quantized.

@@ -6,21 +6,21 @@ In this implementation, weights are smoothed (equalized) and quantized to int8 d
 ## Quick start
 Run the example code with
 ```bash
-python example.py -m MODLE_ID --device= --quant-mode=
+python example.py -m MODEL_ID --device= --quant-mode=
 # An example
 python example.py -m meta-llama/Llama-2-7b-hf --device=cuda --quant-mode=dynamic
 ```
 To use `torch.compile` for a speedup, add `--compile`. You may want to export `TORCHINDUCTOR_FREEZING=1` for even better performance.
 ```bash
-TORCHINDUCTOR_FREEZING=1 python example.py -m MODLE_ID --device= --quant-mode= --compile
+TORCHINDUCTOR_FREEZING=1 python example.py -m MODEL_ID --device= --quant-mode= --compile
 ```
 To save a quantized model for reuse, specify `--model-save-path`
 ```bash
-python example.py -m MODLE_ID --device= --quant-mode= --model-save-path ./quantized_model.pt
+python example.py -m MODEL_ID --device= --quant-mode= --model-save-path ./quantized_model.pt
 ```
 And load it with `--model-load-path`
 ```bash
-python example.py -m MODLE_ID --device= --quant-mode= --model-load-path ./quantized_model.pt
+python example.py -m MODEL_ID --device= --quant-mode= --model-load-path ./quantized_model.pt
 ```
diff --git a/torchao/prototype/smoothquant/__init__.py b/torchao/prototype/smoothquant/__init__.py
index 948a99c080..2ea8b5713a 100644
--- a/torchao/prototype/smoothquant/__init__.py
+++ b/torchao/prototype/smoothquant/__init__.py
@@ -1,15 +1,13 @@
-from .api import (
-    SmoothQuantConfig,
-    insert_smooth_quant_observer_,
-    load_smooth_quant_recipe,
-    save_smooth_quant_recipe,
+from .api import SmoothQuantConfig
+from .core import (
+    SmoothQuantObservedLinear,
+    SmoothQuantObserver,
+    SmoothQuantStep,
 )
-from .core import SmoothQuantObservedLinear

 __all__ = [
-    "insert_smooth_quant_observer_",
-    "load_smooth_quant_recipe",
-    "save_smooth_quant_recipe",
     "SmoothQuantConfig",
+    "SmoothQuantStep",
+    "SmoothQuantObserver",
     "SmoothQuantObservedLinear",
 ]
diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py
index 9397b340b3..627ca8eed3 100644
--- a/torchao/prototype/smoothquant/api.py
+++ b/torchao/prototype/smoothquant/api.py
@@ -5,18 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 import types
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Optional

 import torch

 import torchao
 from torchao.core.config import AOBaseConfig
 from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static
-from torchao.prototype.smoothquant.core import (
-    SmoothQuantObservedLinear,
-    SmoothQuantObserver,
-)
-from torchao.quantization import quantize_
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
 )
@@ -25,109 +20,52 @@
 )
 from torchao.quantization.quant_api import (
     _linear_extra_repr,
-    _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
 from torchao.quantization.utils import _get_per_token_block_size
-from torchao.quantization.weight_tensor_linear_activation_quantization import (
-    to_weight_tensor_with_linear_activation_quantization_metadata,
+
+from .core import (
+    SmoothQuantObservedLinear,
+    SmoothQuantObserver,
+    SmoothQuantStep,
 )


-def insert_smooth_quant_observer_(
-    model: torch.nn.Module, alpha: Optional[float] = 0.5, quant_mode: str = "dynamic"
-):
+@dataclass
+class SmoothQuantConfig(AOBaseConfig):
     """
-    Inserts SmoothQuantObserver into Linear layers of a given model.
+    Configuration for SmoothQuant quantization when passed into quantize_()

     Args:
-        model: The model to be modified (in place). Ensure model is on the desired device for calibration
-        alpha: The alpha value to determine smoothing factor. Factor = 1 if alpha is None, which means
-            falling back to conventional quantization.
+        step (SmoothQuantStep): The step of the SmoothQuant process
+            PREPARE: insert SmoothQuant observers into linear layers
+            CONVERT: convert the observed linear modules to quantized modules
+        alpha: The alpha value to determine the smoothing factor. If None, the factor is 1,
+            falling back to conventional quantization.
         quant_mode: dynamic or static quantization of activation
+        smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None.
+        act_scales: The activation scales for the layer. Acquired from the layer's observer if None.
+        wei_scales: The weight scales for the layer. Acquired from the layer's observer if None.
+        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
     """
-    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
-
-    quant_min, quant_max = -127, 127
-    eps = torch.finfo(torch.float32).eps
-
-    def replace_with_observer(layer):
-        # creates observer and replaces linear layers with observed linear layers
-        observer = SmoothQuantObserver(
-            layer.weight,
-            alpha,
-            quant_mode,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-        )
-        return SmoothQuantObservedLinear.from_float(layer, observer)
-
-    _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
-
-
-def save_smooth_quant_recipe(
-    model: torch.nn.Module, save_path: str
-) -> Dict[str, torch.Tensor]:
-    """
-    Save smoothing_factors, act_scales, and wei_scales for each SmoothQuantObservedLinear layer in the model.
-    """
-    result = {}
-
-    def recurse(module: torch.nn.Module, name: str = ""):
-        for child_name, child in module.named_children():
-            full_name = f"{name}.{child_name}" if name else child_name
-
-            # Apply the analysis function to this layer
-            if isinstance(child, SmoothQuantObservedLinear):
-                smoothing_factor, act_scales, wei_scales = child.obs.calculate_qparams()
-                result[full_name + ".smoothing_factor"] = smoothing_factor
-                result[full_name + ".act_scales"] = act_scales
-                result[full_name + ".wei_scales"] = wei_scales
-
-            # Recurse into child modules
-            recurse(child, full_name)
-
-    recurse(model)
-
-    torch.save(result, save_path)
-
-
-def load_smooth_quant_recipe(
-    model: torch.nn.Module, recipe_path: str, device=None
-) -> torch.nn.Module:
-    recipe = torch.load(recipe_path, weights_only=True)
-
-    def recurse(module: torch.nn.Module, name: str = ""):
-        if isinstance(module, SmoothQuantObservedLinear):
-            smoothing_factor = recipe.get(name + ".smoothing_factor", None)
-            act_scales = recipe.get(name + ".act_scales", None)
-            wei_scales = recipe.get(name + ".wei_scales", None)
-            if device is not None:
-                module.to(device=device)
-            # act_scales is None for dynamic quantization
-            if any(x is None for x in (smoothing_factor, wei_scales)):
-                return module
-            is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-            wrapper = torch.nn.Sequential(module)
-            quantize_(
-                wrapper,
-                SmoothQuantConfig(smoothing_factor, act_scales, wei_scales),
-                is_observed_linear,
-            )
-            return wrapper[0]
-
-        mod_new = module
-
-        for child_name, child in module.named_children():
-            full_name = f"{name}.{child_name}" if name else child_name
-            setattr(mod_new, child_name, recurse(child, full_name))
-        return mod_new
+    step: SmoothQuantStep
+    alpha: Optional[float] = 0.5
+    quant_mode: str = "dynamic"
+    smoothing_factor: Optional[torch.Tensor] = None
+    act_scales: Optional[torch.Tensor] = None
+    wei_scales: Optional[torch.Tensor] = None
+    set_inductor_config: bool = True

-    recurse(model)
+    def __post_init__(self):
+        self.step = self.step.lower() if isinstance(self.step, str) else self.step.value
+        all_step_values = [s.value for s in SmoothQuantStep]
+        if self.step not in all_step_values:
+            raise 
ValueError(f"{self.step} is not one of {all_step_values}") + assert self.quant_mode in ["static", "dynamic"] class _ActQuantizer: @@ -145,46 +83,66 @@ def dynamic_quantize(self, input): ) def static_quantize(self, input, scale, zero_point): + # Use tensor-wise quantization for static mode + # This matches the expected behavior for SmoothQuant static quantization return to_affine_quantized_intx_static( input, scale, zero_point, - list(input.shape), + (1,) + (1,) * (input.ndim - 1), self.target_dtype, self.quant_min, ) -@dataclass -class SmoothQuantConfig(AOBaseConfig): - """ - Configuration for quantizing linear layers when passed into quantize_() - - Args: - smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None. - act_scales: The activation scales for the layer. Acquired from the layer's observer if None. - wei_scales: The weight scales for the layer. Acquired from the layer's observer if None. - set_inductor_config: if True, adjusts `torchinductor` settings to recommended values. - """ - - smoothing_factor: Optional[torch.Tensor] = None - act_scales: Optional[torch.Tensor] = None - wei_scales: Optional[torch.Tensor] = None - set_inductor_config: bool = True - - @register_quantize_module_handler(SmoothQuantConfig) def _smooth_quant_transform( module: torch.nn.Module, config: SmoothQuantConfig, -): - smoothing_factor = config.smoothing_factor - act_scales = config.act_scales - wei_scales = config.wei_scales +) -> torch.nn.Module: + step = config.step + observed_linear = None + + if step == SmoothQuantStep.PREPARE: + observer = SmoothQuantObserver( + weight=module.weight, + alpha=config.alpha, + quant_mode=config.quant_mode, + quant_min=-127, + quant_max=127, + eps=torch.finfo(torch.float32).eps, + ) + return SmoothQuantObservedLinear.from_float(module, observer) + + elif step == SmoothQuantStep.CONVERT: + if not isinstance(module, SmoothQuantObservedLinear): + print( + f"convert: module is not SmoothQuantObservedLinear, skipping: {type(module)}" + ) + return module + observed_linear = module + else: + raise ValueError(f"Unexpected step: {step}") + + # Convert observed linear to quantized linear if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - observed_linear = module + # Get quantization parameters + if all(x is not None for x in (config.smoothing_factor, config.wei_scales)): + smoothing_factor, act_scales, wei_scales = ( + config.smoothing_factor, + config.act_scales, + config.wei_scales, + ) + weight = observed_linear.weight * smoothing_factor + else: + smoothing_factor, act_scales, wei_scales = ( + observed_linear.obs.calculate_qparams() + ) + weight = observed_linear.obs.weight * smoothing_factor + + # Create new linear layer linear = torch.nn.Linear( observed_linear.in_features, observed_linear.out_features, @@ -194,38 +152,26 @@ def _smooth_quant_transform( ) linear.bias = observed_linear.bias + # Quantize weights target_dtype = torch.int8 - # act_scales is None for dynamic quantization thus not checked - if any(x is None for x in (smoothing_factor, wei_scales)): - factor, x_scale, w_scales = observed_linear.obs.calculate_qparams() - weight = observed_linear.obs.weight * factor - else: - factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales - weight = observed_linear.weight * factor weight = weight.to(observed_linear.weight.dtype) block_size = (1, weight.size(1)) - wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64) + wei_zero_points = torch.zeros_like(wei_scales, 
dtype=torch.int64) + qw = to_affine_quantized_intx_static( - weight, - w_scales, - wei_zero_points, - block_size, - target_dtype, + weight, wei_scales, wei_zero_points, block_size, target_dtype ) - if x_scale is None: - # dynamic quant - qw = to_linear_activation_quantized( - qw, _ActQuantizer(target_dtype).dynamic_quantize - ) - else: - # static quant - x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64) - qw = to_weight_tensor_with_linear_activation_quantization_metadata( - qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point - ) + # Apply activation quantization + qw = to_linear_activation_quantized( + qw, _ActQuantizer(target_dtype).dynamic_quantize + ) - qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype)) + # Add smoothing factor metadata + qw = to_weight_tensor_with_linear_activation_scale_metadata( + qw, smoothing_factor.to(qw.dtype) + ) linear.weight = torch.nn.Parameter(qw, requires_grad=False) - linear.extra_repr = types.MethodType(_linear_extra_repr, module) + linear.extra_repr = types.MethodType(_linear_extra_repr, linear) + return linear diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index 3e6c6ea5d5..a631550db7 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -3,15 +3,19 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from enum import Enum from typing import Optional import torch import torch.nn.functional as F from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerAxis -from torchao.quantization.quant_primitives import ( - MappingType, -) +from torchao.quantization.quant_primitives import MappingType + + +class SmoothQuantStep(str, Enum): + PREPARE = "prepare" + CONVERT = "convert" class SmoothQuantObserver(torch.nn.Module): @@ -39,36 +43,28 @@ def __init__( super().__init__() assert weight.ndim == 2 self.weight = weight - self.inputs = [] self.device = self.weight.device self.alpha = alpha - assert quant_mode in ["static", "dynamic"] self.quant_mode = quant_mode self.quant_min = quant_min self.quant_max = quant_max - self.eps = eps + self.eps = eps or torch.finfo(torch.float32).eps # act.shape = [mb, ic] (reshape if needed), wei.shape = [oc, ic] # *_ic_obs are used to determine smoothing_factor # wei_oc_obs is used to find qparams for quantization self.act_ic_obs = AffineQuantizedMinMaxObserver( - MappingType.SYMMETRIC, - torch.int8, - PerAxis(-1), - eps=eps, + MappingType.SYMMETRIC, torch.int8, PerAxis(-1), eps=self.eps ) self.wei_ic_obs = AffineQuantizedMinMaxObserver( - MappingType.SYMMETRIC, - torch.int8, - PerAxis(-1), - eps=eps, + MappingType.SYMMETRIC, torch.int8, PerAxis(-1), eps=self.eps ) self.wei_oc_obs = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.int8, PerAxis(0), - quant_min=quant_min, - quant_max=quant_max, - eps=eps, + quant_min=self.quant_min, + quant_max=self.quant_max, + eps=self.eps, ) self.wei_ic_obs(self.weight) @@ -78,7 +74,7 @@ def forward(self, input: torch.Tensor): return input def calculate_qparams(self): - # 1 Get min/max per IC from observers + # Step 1: Get min/max per input channel (IC) from observers wei_min_per_ic = self.wei_ic_obs.min_val wei_max_per_ic = self.wei_ic_obs.max_val act_min_per_ic = self.act_ic_obs.min_val @@ -89,43 +85,44 @@ def calculate_qparams(self): w_abs_max_per_ic = ( torch.max(torch.abs(wei_min_per_ic), torch.abs(wei_max_per_ic)) + self.eps ) - 
# 2 calculate the smoothing factor + + # Step 2: Calculate smoothing factor if self.alpha is None: # fall back to conventional quantization if alpha is None - smoothing_factor = torch.ones_like( - x_abs_max_per_ic, - dtype=x_abs_max_per_ic.dtype, - device=x_abs_max_per_ic.device, - ) + smoothing_factor = torch.ones_like(x_abs_max_per_ic) else: smoothing_factor = torch.pow(x_abs_max_per_ic, self.alpha) / torch.pow( w_abs_max_per_ic.to(x_abs_max_per_ic.device), 1 - self.alpha ) - # 3 apply smoothing factor to activations and find scales for static quantization + + # Step 3: Calculate activation scales for static quantization act_scales = None if self.quant_mode == "static": - act_min_per_ic_new = act_min_per_ic / smoothing_factor.reshape( + act_min_new = act_min_per_ic / smoothing_factor.reshape( act_min_per_ic.shape ) - act_max_per_ic_new = act_max_per_ic / smoothing_factor.reshape( + act_max_new = act_max_per_ic / smoothing_factor.reshape( act_max_per_ic.shape ) - min_val_per_tensor = torch.min(act_min_per_ic_new) - max_val_per_tensor = torch.max(act_max_per_ic_new) - min_val_neg = torch.min( - min_val_per_tensor, torch.zeros_like(min_val_per_tensor) - ) - max_val_pos = torch.max( - max_val_per_tensor, torch.zeros_like(max_val_per_tensor) - ) - max_val_pos = torch.max(-min_val_neg, max_val_pos) - act_scale = max_val_pos / (float(self.quant_max - self.quant_min) / 2) - act_scales = act_scale.to(self.device) - # 4 update weight and find scales + + # Calculate global scale (scalar) + global_min = torch.min(act_min_new) + global_max = torch.max(act_max_new) + abs_max = torch.max(torch.abs(global_min), torch.abs(global_max)) + act_scale = abs_max / (float(self.quant_max - self.quant_min) / 2) + + # Create scalar tensor for tensor-wise quantization + act_scales = act_scale.reshape(()).to(self.device) # Ensure scalar shape + + # Step 4: Update weight and find scales self.wei_oc_obs(self.weight * smoothing_factor.to(self.device)) wei_scales, _ = self.wei_oc_obs.calculate_qparams() - # 5 return results - return smoothing_factor.to(self.device), act_scales, wei_scales.to(self.device) + + return ( + smoothing_factor.to(self.device), + act_scales, + wei_scales.to(self.device), + ) class SmoothQuantObservedLinear(torch.nn.Linear): @@ -133,27 +130,25 @@ def __init__( self, in_features: int, out_features: int, - bias: bool, obs: SmoothQuantObserver, + bias: bool = True, device=None, dtype=None, ): super().__init__(in_features, out_features, bias, device, dtype) - assert isinstance(obs, SmoothQuantObserver) self.obs = obs def forward(self, input: torch.Tensor): input = self.obs(input) - output = F.linear(input, self.weight, self.bias) - return output + return F.linear(input, self.weight, self.bias) @classmethod def from_float(cls, float_linear: torch.nn.Linear, obs: SmoothQuantObserver): observed_linear = cls( float_linear.in_features, float_linear.out_features, - float_linear.bias is not None, obs, + float_linear.bias is not None, device=float_linear.weight.device, dtype=float_linear.weight.dtype, ) diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index de1e4ed93e..b24740c707 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -15,8 +15,6 @@ from torchao.prototype.smoothquant import ( SmoothQuantConfig, - SmoothQuantObservedLinear, - insert_smooth_quant_observer_, ) from torchao.quantization import quantize_ @@ -137,8 +135,11 @@ def wikitext2_ppl( print(f"Time to load model: {time.time() - t0:.02f} 
seconds") print("running calibration") t0 = time.time() - # insert observers to find average magnitude and calculate scales - insert_smooth_quant_observer_(model, alpha, quant_mode) + # Step 1: Insert observers to find average magnitude and calculate scales + config = SmoothQuantConfig(step="prepare", alpha=alpha, quant_mode=quant_mode) + quantize_(model, config) + + # Step 2: Calibration calibration_data = get_calib_dataset( tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length ) @@ -147,10 +148,11 @@ def wikitext2_ppl( batch.to("cpu") print(f"time for calibration: {time.time() - t0:.02f} seconds") - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) + # Step 3: Convert to quantized model print(f"running SmoothQuant with {quant_mode} quantization") t0 = time.time() - quantize_(model, SmoothQuantConfig(), is_observed_linear) + config.step = "convert" + quantize_(model, config) print(f"time for quantization: {time.time() - t0:.02f} seconds") if model_save_path is not None: print(f"Saving quantized model to {model_save_path}") @@ -239,7 +241,7 @@ def wikitext2_ppl( args.quant_mode, args.calibration_samples, args.device, - args.precision, + precision_dtype, args.seq_len, args.compile, args.model_load_path,