Skip to content

Add IntxUnpackedTensor #2732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao.quantization import (
IntxWeightOnlyConfig,
quantize_,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
class TestIntxUnpackedTensor(TestCase):
def setUp(self):
self.config = IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerGroup(32),
VERSION=2,
)

def test_linear(self):
dtype = torch.bfloat16
device = "cpu"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, self.config)
quantized = linear(input)
error = compute_error(original, quantized)
self.assertTrue(error > 20)

def test_slice(self):
dtype = torch.bfloat16
device = "cpu"
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)

dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
dummy1.weight = torch.nn.Parameter(
dummy.weight.narrow(0, 0, 64), requires_grad=False
)

dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
dummy2.weight = torch.nn.Parameter(
dummy.weight.narrow(1, 0, 128), requires_grad=False
)

quantize_(dummy, self.config)
weight1 = dummy.weight.narrow(0, 0, 64)
weight2 = dummy.weight.narrow(1, 0, 128)

self.assertEqual(weight1.int_data, dummy.weight.int_data.narrow(0, 0, 64))
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))

self.assertEqual(weight2.int_data, dummy.weight.int_data.narrow(1, 0, 128))
self.assertEqual(weight2.scale, dummy.weight.scale.narrow(1, 0, 4))

# check for sliced weight, before and after float8 quantization
# does not differ too much
input = torch.randn(2, 256, dtype=dtype, device=device)
res_ref = dummy1(input)
dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
res = dummy(input)
assert compute_error(res, res_ref) > 20

input = torch.randn(2, 128, dtype=dtype, device=device)
res_ref = dummy2(input)
dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
res = dummy(input)
assert compute_error(res, res_ref) > 15

def test_slice_and_copy_(self):
device = "cpu"
l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
l.weight = torch.nn.Parameter(
torch.zeros(1024, 1024, dtype=torch.bfloat16, device=device)
)
quantize_(l, self.config)
param = l.weight
param_data = param.data
param_data = param_data.narrow(0, 0, 512)
assert param.data.int_data.data_ptr() == param_data.int_data.data_ptr()
assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
orig_value = param.data.int_data[0][0].item()

# dummy_l has random input (shouldn't be 0)
dummy_l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
quantize_(dummy_l, self.config)
quantized = dummy_l.weight
quantized = quantized.narrow(0, 0, 512)

param_data.copy_(quantized)

# making sure param.data is updated
assert param.data.int_data[0][0] != orig_value

def test_to_dtype(self):
activations_bf16 = torch.randn(1, 128, dtype=torch.bfloat16)
activations_fp32 = torch.randn(1, 128, dtype=torch.float32)
activations_fp16 = torch.randn(1, 128, dtype=torch.float16)

linear = torch.nn.Linear(128, 256)
quantize_(linear, self.config)

linear.to(dtype=torch.float16)
linear(activations_fp16)

linear.to(dtype=torch.float32)
linear(activations_fp32)

linear.to(dtype=torch.bfloat16)
linear(activations_bf16)

def test_export(self):
linear = torch.nn.Linear(128, 256)
quantize_(linear, self.config)
ep = torch.export.export(linear, (torch.randn(1, 128),))
assert "torch.ops.torchao.dequantize_affine.default" in ep.graph_module.code


if __name__ == "__main__":
run_tests()
43 changes: 36 additions & 7 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
Float8Tensor,
Int4PreshuffledTensor,
Int4Tensor,
IntxUnpackedTensor,
QuantizeTensorToFloat8Kwargs,
)
from torchao.quantization.transform_module import (
Expand Down Expand Up @@ -2060,6 +2061,8 @@ class IntxWeightOnlyConfig(AOBaseConfig):
mapping_type: MappingType = MappingType.SYMMETRIC
scale_dtype: Optional[torch.dtype] = None
layout: Layout = QDQLayout()
packing_format: PackingFormat = PackingFormat.UNPACKED
VERSION: int = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we updated the name to version


def __post_init__(self):
assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+"
Expand All @@ -2078,16 +2081,13 @@ def __post_init__(self):
)


@register_quantize_module_handler(IntxWeightOnlyConfig)
def _intx_weight_only_transform(
module: torch.nn.Module, config: IntxWeightOnlyConfig
) -> torch.nn.Module:
weight = module.weight
def _intx_weight_only_quantize_tensor(weight, config):
weight_dtype = config.weight_dtype
granularity = config.granularity
mapping_type = config.mapping_type
scale_dtype = config.scale_dtype
layout = config.layout
packing_format = config.packing_format

assert weight.dim() == 2, (
f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
Expand All @@ -2102,11 +2102,28 @@ def _intx_weight_only_transform(
else:
raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")

block_size = (1, group_size)

if config.VERSION == 2:
if config.packing_format == PackingFormat.UNPACKED:
new_weight = IntxUnpackedTensor.from_float(
weight,
block_size,
weight_dtype,
mapping_type=mapping_type,
)
if scale_dtype is not None and scale_dtype != weight.dtype:
new_weight.scale = new_weight.scale.to(scale_dtype).to(weight.dtype)
return new_weight
else:
raise ValueError(f"Unsupported packing format: {packing_format}")

# Version 1
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
weight = to_affine_quantized_intx(
input_float=weight,
mapping_type=mapping_type,
block_size=(1, group_size),
block_size=block_size,
target_dtype=torch.int8,
quant_min=quant_min,
quant_max=quant_max,
Expand All @@ -2116,7 +2133,19 @@ def _intx_weight_only_transform(
zero_point_domain=ZeroPointDomain.INT,
_layout=layout,
)
module.weight = torch.nn.Parameter(weight, requires_grad=False)


@register_quantize_module_handler(IntxWeightOnlyConfig)
def _intx_weight_only_transform(
module: torch.nn.Module, config: IntxWeightOnlyConfig
) -> torch.nn.Module:
assert hasattr(module, "weight"), (
"applying intx weight only quant requires module to have weight attribute"
+ " but {module} does not have one"
)
new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


Expand Down
5 changes: 5 additions & 0 deletions torchao/quantization/quantize_/common/packing_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ class PackingFormat(str, Enum):
preshuffled is referring to the preshuffled format used by fbgemm kernels
"""
PRESHUFFLED = "preshuffled"

"""
Unpacked means the subbyte quantized data is stored as int8
Copy link
Contributor

@jerryzh168 jerryzh168 Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this int only? we could be more specific and say UnpackedToInt8

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can make the format UNPACKED_TO_INT8

"""
UNPACKED = "unpacked"
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
from .int4.int4_tensor import (
Int4Tensor,
)
from .intx.intx_unpacked_tensor import (
IntxUnpackedTensor,
)

__all__ = [
"Int4Tensor",
"Int4PreshuffledTensor",
"Float8Tensor",
"QuantizeTensorToFloat8Kwargs",
"IntxUnpackedTensor",
]
5 changes: 5 additions & 0 deletions torchao/quantization/quantize_/workflows/intx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .intx_unpacked_tensor import IntxUnpackedTensor

__all__ = [
"IntxUnpackedTensor",
]
Loading
Loading