diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py
index 9f78c49fb8..12b9644438 100644
--- a/torchao/prototype/smoothquant/api.py
+++ b/torchao/prototype/smoothquant/api.py
@@ -117,8 +117,9 @@ def _smooth_quant_transform(
     qw = quant_mod.weight
 
     # Add smoothing factor metadata
+    use_inv_scale = qw.device.type == "cpu"
     qw = to_weight_tensor_with_linear_activation_scale_metadata(
-        qw, smoothing_factor.to(qw.dtype)
+        qw, smoothing_factor.to(qw.dtype), use_inv_scale
     )
     linear.weight = torch.nn.Parameter(qw, requires_grad=False)
     linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py
index fd61d07d9a..f7878cd295 100644
--- a/torchao/quantization/linear_activation_scale.py
+++ b/torchao/quantization/linear_activation_scale.py
@@ -20,8 +20,9 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor):
     """
     Tensor subclass that wraps a weight tensor and provides metadata for linear activation scaling.
     Right now we hardcode how we apply the scale:
-        scaled_linear_act = input_act / scale
-        out = F.linear(scaled_linear_act, weight, ...)
+        scaled_linear_act = input_act / scale
+        # or scaled_linear_act = input_act * inv_scale
+        out = F.linear(scaled_linear_act, weight, ...)
 
     We can generalize this to accept a function as well if needed.
 
@@ -31,12 +32,13 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor):
     """
 
     tensor_data_names = ["original_weight_tensor", "scale"]
-    tensor_attribute_names = []
+    tensor_attribute_names = ["use_inv_scale"]
 
     def __new__(
         cls,
         original_weight_tensor: torch.Tensor,
         scale: torch.Tensor,
+        use_inv_scale: bool = False,
     ):
         kwargs = {}
         dtype = original_weight_tensor.dtype
@@ -50,9 +52,12 @@ def __init__(
         self,
         original_weight_tensor: torch.Tensor,
         scale: torch.Tensor,
+        use_inv_scale: bool = False,
     ):
         self.original_weight_tensor = original_weight_tensor
         self.scale = scale
+        self.use_inv_scale = use_inv_scale
+        self.inv_scale = 1.0 / scale if use_inv_scale else None
 
     def _quantization_type(self):
         return f"{self.__class__}"
@@ -63,8 +68,12 @@ def _quantized_linear_op(
     ):
         original_weight_tensor = weight_tensor.original_weight_tensor
         scale = weight_tensor.scale
+        inv_scale = weight_tensor.inv_scale
+        use_inv_scale = weight_tensor.use_inv_scale
         # Note: we can make this function configurable as well
-        scaled_input_act = input_tensor / scale
+        scaled_input_act = (
+            input_tensor * inv_scale if use_inv_scale else input_tensor / scale
+        )
         return torch.nn.functional.linear(
             scaled_input_act, original_weight_tensor, bias
         )
@@ -74,8 +83,9 @@ def from_float(
         cls,
         input_float: torch.Tensor,
         scale: torch.Tensor,
+        use_inv_scale: bool = False,
     ):
-        return cls(input_float, scale)
+        return cls(input_float, scale, use_inv_scale)
 
 
 implements = WeightTensorWithLinearActivationScaleMetadata.implements
@@ -103,7 +113,9 @@ def _(func, types, args, kwargs):
 def _(func, types, args, kwargs):
     self = args[0]
     new = self.__class__(
-        func(self.original_weight_tensor, *args[1:], **kwargs), self.scale
+        func(self.original_weight_tensor, *args[1:], **kwargs),
+        self.scale,
+        self.use_inv_scale,
     )
     return return_and_correct_aliasing(func, args, kwargs, new)
 
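Illustrative sketch (not part of the patch): assuming the patched torchao classes above are importable, and that the torch.nn.functional.linear handler registered via implements dispatches on a standalone call, this checks that the precomputed-reciprocal path matches the default divide path up to floating-point rounding.

import torch
import torch.nn.functional as F

from torchao.quantization.linear_activation_scale import (
    WeightTensorWithLinearActivationScaleMetadata,
)

weight = torch.randn(64, 32)
scale = torch.rand(32) + 0.5  # per-input-channel smoothing factor, kept positive
x = torch.randn(8, 32)

# Default path: the activation is divided by the scale inside the linear op.
w_div = WeightTensorWithLinearActivationScaleMetadata.from_float(weight, scale)
# Patched path: the reciprocal is precomputed once and multiplied in.
w_mul = WeightTensorWithLinearActivationScaleMetadata.from_float(
    weight, scale, use_inv_scale=True
)

out_div = F.linear(x, w_div)
out_mul = F.linear(x, w_mul)
torch.testing.assert_close(out_div, out_mul)  # equal up to fp rounding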