torchao/prototype/smoothquant/api.py (3 changes: 2 additions & 1 deletion)
@@ -117,8 +117,9 @@ def _smooth_quant_transform(
     qw = quant_mod.weight

     # Add smoothing factor metadata
+    use_inv_scale = qw.device.type == "cpu"
     qw = to_weight_tensor_with_linear_activation_scale_metadata(
-        qw, smoothing_factor.to(qw.dtype)
+        qw, smoothing_factor.to(qw.dtype), use_inv_scale
     )
     linear.weight = torch.nn.Parameter(qw, requires_grad=False)
     linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
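
For context, a minimal standalone sketch of the wrapping step changed above. The weight and smoothing factor below are made-up stand-ins for quant_mod.weight and the calibrated smoothing_factor, and the sketch assumes the module-level helper forwards the new positional flag to from_float as the hunks further down suggest:

import torch
from torchao.quantization.linear_activation_scale import (
    to_weight_tensor_with_linear_activation_scale_metadata,
)

# Stand-ins for quant_mod.weight and the calibrated smoothing factor.
weight = torch.randn(64, 128)
smoothing_factor = torch.rand(128) + 0.5

# Only the CPU path asks the wrapper to precompute 1/scale, as in the diff.
use_inv_scale = weight.device.type == "cpu"
wrapped = to_weight_tensor_with_linear_activation_scale_metadata(
    weight, smoothing_factor.to(weight.dtype), use_inv_scale
)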
torchao/quantization/linear_activation_scale.py (24 changes: 18 additions & 6 deletions)
@@ -20,8 +20,9 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor):
     """
     Tensor subclass that wraps a weight tensor and provides metadata for linear activation scaling.
     Right now we hardcode how we apply the scale:
-    scaled_linear_act = input_act / scale
-    out = F.linear(scaled_linear_act, weight, ...)
+    scaled_linear_act = input_act / scale
+    # or scaled_linear_act = input_act * inv_scale
+    out = F.linear(scaled_linear_act, weight, ...)

     We can generalize this to accept a function as well if needed.
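
A quick sanity check, separate from the PR, that the two forms in the docstring agree up to rounding of the precomputed reciprocal:

import torch

x = torch.randn(4, 128)
scale = torch.rand(128) + 0.5

# Dividing by scale and multiplying by a precomputed 1/scale differ only by
# one extra rounding step on the reciprocal.
assert torch.allclose(x / scale, x * (1.0 / scale), rtol=1e-5, atol=1e-6)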

@@ -31,12 +32,13 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor):
     """

     tensor_data_names = ["original_weight_tensor", "scale"]
-    tensor_attribute_names = []
+    tensor_attribute_names = ["use_inv_scale"]

     def __new__(
         cls,
         original_weight_tensor: torch.Tensor,
         scale: torch.Tensor,
+        use_inv_scale: bool = False,
     ):
         kwargs = {}
         dtype = original_weight_tensor.dtype
@@ -50,9 +52,12 @@ def __init__(
         self,
         original_weight_tensor: torch.Tensor,
         scale: torch.Tensor,
+        use_inv_scale: bool = False,
     ):
         self.original_weight_tensor = original_weight_tensor
         self.scale = scale
+        self.use_inv_scale = use_inv_scale
+        self.inv_scale = 1.0 / scale if use_inv_scale else None

     def _quantization_type(self):
         return f"{self.__class__}"
@@ -63,8 +68,12 @@ def _quantized_linear_op(
     ):
         original_weight_tensor = weight_tensor.original_weight_tensor
         scale = weight_tensor.scale
+        inv_scale = weight_tensor.inv_scale
+        use_inv_scale = weight_tensor.use_inv_scale
         # Note: we can make this function configurable as well
-        scaled_input_act = input_tensor / scale
+        scaled_input_act = (
+            input_tensor * inv_scale if use_inv_scale else input_tensor / scale
+        )
         return torch.nn.functional.linear(
             scaled_input_act, original_weight_tensor, bias
         )
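
A hypothetical end-to-end check of this dispatch path, assuming F.linear on the wrapped weight routes through _quantized_linear_op and that from_float accepts the new flag as shown in the next hunk:

import torch
import torch.nn.functional as F
from torchao.quantization.linear_activation_scale import (
    WeightTensorWithLinearActivationScaleMetadata,
)

w = torch.randn(64, 128)
scale = torch.rand(128) + 0.5
wrapped = WeightTensorWithLinearActivationScaleMetadata.from_float(
    w, scale, use_inv_scale=True
)

x = torch.randn(2, 128)
out = F.linear(x, wrapped)            # multiply-by-inv_scale path via the wrapper
ref = F.linear(x * (1.0 / scale), w)  # manual reference
assert torch.allclose(out, ref, rtol=1e-4, atol=1e-5)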
@@ -74,8 +83,9 @@ def from_float(
         cls,
         input_float: torch.Tensor,
         scale: torch.Tensor,
+        use_inv_scale: bool = False,
     ):
-        return cls(input_float, scale)
+        return cls(input_float, scale, use_inv_scale)


 implements = WeightTensorWithLinearActivationScaleMetadata.implements
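
And a small hypothetical check that from_float simply threads the flag through to __init__, where the reciprocal is precomputed:

import torch
from torchao.quantization.linear_activation_scale import (
    WeightTensorWithLinearActivationScaleMetadata,
)

w = WeightTensorWithLinearActivationScaleMetadata.from_float(
    torch.randn(16, 32), torch.ones(32), use_inv_scale=True
)
assert w.use_inv_scale
assert w.inv_scale is not None  # 1.0 / scale computed eagerly in __init__

default = WeightTensorWithLinearActivationScaleMetadata.from_float(
    torch.randn(16, 32), torch.ones(32)
)
assert default.inv_scale is None  # default keeps the divide path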
@@ -103,7 +113,9 @@ def _(func, types, args, kwargs):
 def _(func, types, args, kwargs):
     self = args[0]
     new = self.__class__(
-        func(self.original_weight_tensor, *args[1:], **kwargs), self.scale
+        func(self.original_weight_tensor, *args[1:], **kwargs),
+        self.scale,
+        self.use_inv_scale,
     )
     return return_and_correct_aliasing(func, args, kwargs, new)
