@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import tempfile
from packaging import version

import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torchao.quantization import (
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.utils import compute_error
from torchao.utils import (
torch_version_at_least,
)


def get_config(group_size):
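# plain_int32 packing: on npu devices this is backed by Int4PlainInt32TensorNPU (see the quant_api.py change below)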
return Int4WeightOnlyConfig(
group_size=group_size,
int4_packing_format="plain_int32",
)


@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
@unittest.skipIf(
not (
torch.accelerator.is_available()
and torch.accelerator.current_accelerator(True).type == "npu"
),
"NPU not available",
)
class Int4PlainInt32TensorNPU(TestCase):

@parametrize("device", ["npu"])
@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 512, 128),
((2, 32, 128), 256, 128),
],
)
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("group_size", [32, 64])
def test_linear(self, device, sizes, dtype, group_size):
M, N, K = sizes
input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
orig_output = linear(input)
quantize_(linear, get_config(group_size))
quantized_output = linear(input)
self.assertTrue(compute_error(orig_output, quantized_output) > 10)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_module_path(self, device, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
quantize_(linear, get_config(group_size=64))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4PlainInt32TensorNPU'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4PlainInt32TensorNPU'>",
)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_activation_prescaling(self, device, dtype):
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(64))
qw = linear.weight
assert isinstance(
qw, SupportsActivationPreScaling
), "Expected int4 tensor supports activation prescaling"
assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
_ACT_PRE_SCALE = 2
qw.act_pre_scale = _ACT_PRE_SCALE
quantized = linear(input)

# make sure activation pre-scaling was applied to the activation
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)


instantiate_parametrized_tests(Int4PlainInt32TensorNPU)

if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions torchao/quantization/__init__.py
@@ -82,6 +82,7 @@
Int4MarlinSparseTensor,
Int4OpaqueTensor,
Int4PlainInt32Tensor,
Int4PlainInt32TensorNPU,
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
15 changes: 11 additions & 4 deletions torchao/quantization/quant_api.py
@@ -78,6 +78,7 @@
Int4OpaqueTensor,
Int4PackingFormat,
Int4PlainInt32Tensor,
Int4PlainInt32TensorNPU,
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
@@ -1163,10 +1164,16 @@ def _int4_weight_only_quantize_tensor(weight, config):
)
return new_weight
elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
new_weight = Int4PlainInt32Tensor.from_hp(
weight,
block_size,
)
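# route by device: the Ascend NPU backend uses its own int4 packing and matmul kernels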
if weight.device.type == "npu":
new_weight = Int4PlainInt32TensorNPU.from_hp(
weight,
block_size,
)
else:
new_weight = Int4PlainInt32Tensor.from_hp(
weight,
block_size,
)
return new_weight
elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
new_weight = Int4MarlinSparseTensor.from_hp(
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -13,6 +13,9 @@
from .int4.int4_plain_int32_tensor import (
Int4PlainInt32Tensor,
)
from .int4.int4_plain_int32_tensor_npu import (
Int4PlainInt32TensorNPU,
)
from .int4.int4_preshuffled_tensor import (
Int4PreshuffledTensor,
)
@@ -36,6 +39,7 @@
"Int4PreshuffledTensor",
"Int4MarlinSparseTensor",
"Int4PlainInt32Tensor",
"Int4PlainInt32TensorNPU",
"Int4TilePackedTo4dTensor",
"Float8Tensor",
"QuantizeTensorToFloat8Kwargs",
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import torch

from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_affine,
quantize_affine,
)
from torchao.utils import (
TorchAOBaseTensor,
)

__all__ = ["Int4PlainInt32TensorNPU"]

aten = torch.ops.aten


class Int4PlainInt32TensorNPU(TorchAOBaseTensor):
"""
Int4 weight-only quantization for the Ascend NPU backend (groupwise quantization only)

Tensor Attributes:
qdata: (N, K/8) packed int4 weight, stored as int32 with 8 int4 values per element; the original weight dtype can be float16 or bfloat16
scale: (K/group_size, N), same dtype as the original Tensor (float16 or bfloat16)
zero_point: (K/group_size, N), same dtype as the original Tensor (float16 or bfloat16)

Non-Tensor Attributes:
block_size: the quantization block size, representing the granularity of the groupwise quantization
shape: shape of the original Tensor

Optional Tensor Data Attributes:
act_pre_scale (Optional[Tensor]): optional scale for the activation Tensor; if present, the activation
is multiplied by act_pre_scale before dynamic activation quantization (if any) or the quantized mm op
is applied

"""

tensor_data_names = ["qdata", "scale", "zero_point"]
tensor_attribute_names = ["block_size", "shape"]
optional_tensor_data_names = ["act_pre_scale"]

def __new__(
cls,
qdata,
scale,
zero_point,
block_size,
shape,
act_pre_scale: Optional[torch.Tensor] = None,
):
kwargs = {}
kwargs["device"] = qdata.device
kwargs["dtype"] = scale.dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
qdata,
scale,
zero_point,
block_size,
shape,
act_pre_scale: Optional[torch.Tensor] = None,
):
self.qdata = qdata
self.scale = scale
self.zero_point = zero_point
self.block_size = block_size
self.act_pre_scale = act_pre_scale

def _quantization_type(self):
s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
if self.act_pre_scale is not None:
s += f", act_pre_scale.shape={self.act_pre_scale.shape}"
return s

@classmethod
def from_hp(
cls,
w: torch.Tensor,
block_size: List[int],
):
assert w.ndim == 2 and w.device.type == "npu", (
f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}"
)
assert len(block_size) == w.ndim
assert w.dtype in [torch.float16, torch.bfloat16], (
f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
)

original_shape = w.shape
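# asymmetric groupwise quantization to signed int4 (range [-8, 7]); scale/zero_point keep the weight dtype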
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
quant_min = -8
quant_max = 7
eps = 1e-6
scale_dtype = w.dtype
zero_point_dtype = w.dtype

scale, zero_point = choose_qparams_affine(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)

int_data = quantize_affine(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)

assert int_data.dtype == torch.int32, (
f"torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype"
)

assert int_data.shape[-1] % 8 == 0, (
f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}"
)

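# pack the int32-encoded int4 values (8 values per int32 element) into the layout expected by the NPU matmul kernel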
packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack(
int_data.contiguous(), 0
)

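# reshape per-group qparams to (N, K/group_size) and store them transposed as (K/group_size, N), as described in the class docstring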
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)

return Int4PlainInt32TensorNPU(
packed_weight,
scale.transpose(0, 1).contiguous(),
zero_point.transpose(0, 1).contiguous(),
block_size,
original_shape,
act_pre_scale=None,
)


implements = Int4PlainInt32TensorNPU.implements
implements_torch_function = Int4PlainInt32TensorNPU.implements_torch_function


@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):

input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)

assert input_tensor.device.type == "npu", (
f"For NPU device only but got: {input_tensor.device.type}"
)
assert isinstance(weight_tensor, Int4PlainInt32TensorNPU), (
f"Expected weight_tensor to be Int4PlainInt32NPUTensor, got: {type(weight_tensor)}"
)
assert weight_tensor.block_size[0] == 1, (
f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
)
assert input_tensor.shape[-1] == weight_tensor.shape[1], (
f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
)

if weight_tensor.act_pre_scale is not None:
input_tensor = input_tensor * weight_tensor.act_pre_scale

act_mat = input_tensor
packed_weight = weight_tensor.qdata
scale = weight_tensor.scale
zero_point = weight_tensor.zero_point

orig_act_size = act_mat.shape
orig_dtype = act_mat.dtype

# dtype alignment
if act_mat.dtype == torch.float16:
scale = scale.to(torch.float16)
zero_point = zero_point.to(torch.float16)
if bias is not None:
bias = bias.to(torch.float16)
elif act_mat.dtype == torch.bfloat16:
scale = scale.to(torch.bfloat16)
zero_point = zero_point.to(torch.bfloat16)
if bias is not None:
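# note: float32 (not bfloat16) on purpose; npu_weight_quant_batchmatmul is understood to require a float32 bias when activations are bfloat16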
bias = bias.to(torch.float32)

# reshape to 2D
act_mat = act_mat.reshape(-1, act_mat.shape[-1])

# group size used for groupwise int4 dequantization in the fused matmul
groupsize = weight_tensor.block_size[1]

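# fused NPU kernel: dequantizes the packed int4 weight with the per-group scale/offset (antiquant) and runs the matmul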
y = torch.ops.npu.npu_weight_quant_batchmatmul(
x=act_mat,
weight=packed_weight.contiguous().transpose(-1, -2),
antiquant_scale=scale,
antiquant_offset=zero_point,
antiquant_group_size=groupsize,
bias=bias,
)

# remove out_feature padding
assert weight_tensor.ndim == 2
orig_out_features = weight_tensor.shape[-2]
y = y[:, :orig_out_features]
y = y.reshape(*orig_act_size[:-1], orig_out_features)

return y.to(orig_dtype)


Int4PlainInt32TensorNPU.__module__ = "torchao.quantization"

# Allow a model with Int4PlainInt32TensorNPU weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4PlainInt32TensorNPU])
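
For reference, a minimal end-to-end sketch of how this path is exercised (assuming torch_npu is installed so that the "npu" device and the torch.ops.npu kernels used above are available); it mirrors the tests above rather than defining new behavior:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# assumes an Ascend NPU is visible as device "npu" (provided by torch_npu)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="npu")
quantize_(linear, Int4WeightOnlyConfig(group_size=64, int4_packing_format="plain_int32"))
# on npu devices the quantized weight is an Int4PlainInt32TensorNPU
output = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="npu"))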