@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import tempfile
from packaging import version

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Int4WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.utils import compute_error
from torchao.utils import (
    torch_version_at_least,
)

try:
    import torch_npu
except ImportError:
    torch_npu = None

PyTorch provides an Autoload mechanism, so we do not need to import it explicitly.
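For reference, a minimal sketch of how that autoloading works: the backend package registers a hook in the `torch.backends` entry-point group, which `import torch` invokes automatically (the group name and the `TORCH_DEVICE_BACKEND_AUTOLOAD` switch follow PyTorch's out-of-tree backend autoloading docs; the `_autoload` hook name here is illustrative, not necessarily torch_npu's actual symbol):

```python
# Illustrative setup.py fragment for a device backend such as torch_npu.
# PyTorch scans the "torch.backends" entry-point group at the end of
# `import torch` and calls each registered hook, so users never need an
# explicit `import torch_npu`. Autoloading can be disabled with
# TORCH_DEVICE_BACKEND_AUTOLOAD=0.
from setuptools import setup

setup(
    name="torch_npu",
    entry_points={
        "torch.backends": [
            "torch_npu = torch_npu:_autoload",  # hook name is illustrative
        ],
    },
)
```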


def get_config(group_size):
    return Int4WeightOnlyConfig(
        group_size=group_size,
        int4_packing_format="plain_int32",
    )


@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
@unittest.skipIf(torch_npu is None, "torch_npu is not available")
@unittest.skipIf(not torch_npu.npu.is_available(), "NPU not available")

@FFFrog Oct 14, 2025

Suggested change
@unittest.skipIf(torch_npu is None, "torch_npu is not available")
@unittest.skipIf(not torch_npu.npu.is_available(), "NPU not available")
@unittest.skipIf(torch.accelerator.current_accelerator(True).type == "npu" and torch.accelerator.is_available(), "NPU not available")

@unittest.skipIf(
    version.parse(torch_npu.__version__) < version.parse("2.7.1rc1"),
    "Need torch_npu 2.7.1rc1+",
)
Comment on lines 38 to 42

We can remove it because there is a strict version mapping between PyTorch and Torch_NPU.
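A sketch of what the class-level guards could look like with the torch_npu version check dropped (reusing this test file's imports; the `torch_npu is None or` guard is added here only so the condition stays safe when torch_npu is missing):

```python
# Relies on the strict torch <-> torch_npu version coupling, so only the
# PyTorch version is checked explicitly.
@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
@unittest.skipIf(torch_npu is None, "torch_npu is not available")
@unittest.skipIf(
    torch_npu is None or not torch_npu.npu.is_available(), "NPU not available"
)
class Int4PlainInt32TensorNPU(TestCase):
    ...
```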

class Int4PlainInt32TensorNPU(TestCase):

@parametrize("device", ["npu"])
@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 512, 128),
((2, 32, 128), 256, 128),
],
)
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("group_size", [32, 64])
def test_linear(self, device, sizes, dtype, group_size):
M, N, K = sizes
input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
orig_output = linear(input)
quantize_(linear, get_config(group_size))
quantized_output = linear(input)
self.assertTrue(compute_error(orig_output, quantized_output) > 10)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_module_path(self, device, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
quantize_(linear, get_config(group_size=64))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4PlainInt32TensorNPU'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4PlainInt32TensorNPU'>",
)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_activation_prescaling(self, device, dtype):
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(64))
qw = linear.weight
assert isinstance(
qw, SupportsActivationPreScaling
), "Expected int4 tensor supports activation prescaling"
assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
_ACT_PRE_SCALE = 2
qw.act_pre_scale = _ACT_PRE_SCALE
quantized = linear(input)

# making sure activation pre scaling is successfully applied to the activation
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)


instantiate_parametrized_tests(Int4PlainInt32TensorNPU)

if __name__ == "__main__":
    run_tests()
1 change: 1 addition & 0 deletions torchao/quantization/__init__.py
@@ -82,6 +82,7 @@
    Int4MarlinSparseTensor,
    Int4OpaqueTensor,
    Int4PlainInt32Tensor,
    Int4PlainInt32TensorNPU,
    Int4PreshuffledTensor,
    Int4Tensor,
    Int4TilePackedTo4dTensor,
15 changes: 11 additions & 4 deletions torchao/quantization/quant_api.py
@@ -78,6 +78,7 @@
    Int4OpaqueTensor,
    Int4PackingFormat,
    Int4PlainInt32Tensor,
    Int4PlainInt32TensorNPU,
    Int4PreshuffledTensor,
    Int4Tensor,
    Int4TilePackedTo4dTensor,
@@ -1163,10 +1164,16 @@ def _int4_weight_only_quantize_tensor(weight, config):
        )
        return new_weight
    elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
        new_weight = Int4PlainInt32Tensor.from_hp(
            weight,
            block_size,
        )
        if weight.device.type == "npu":
            new_weight = Int4PlainInt32TensorNPU.from_hp(
                weight,
                block_size,
            )
        else:
            new_weight = Int4PlainInt32Tensor.from_hp(
                weight,
                block_size,
            )
        return new_weight
    elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
        new_weight = Int4MarlinSparseTensor.from_hp(
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -13,6 +13,9 @@
from .int4.int4_plain_int32_tensor import (
    Int4PlainInt32Tensor,
)
from .int4.int4_plain_int32_tensor_npu import (
    Int4PlainInt32TensorNPU,
)
from .int4.int4_preshuffled_tensor import (
    Int4PreshuffledTensor,
)
@@ -36,6 +39,7 @@
    "Int4PreshuffledTensor",
    "Int4MarlinSparseTensor",
    "Int4PlainInt32Tensor",
    "Int4PlainInt32TensorNPU",
    "Int4TilePackedTo4dTensor",
    "Float8Tensor",
    "QuantizeTensorToFloat8Kwargs",