Commit 93948a4 (parent: d7dbb84)

Commit message: up
File tree: 5 files changed (+37, -72 lines)

test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py (12 additions, 1 deletion)

@@ -29,9 +29,20 @@ def setUp(self):
         self.config = IntxWeightOnlyConfig(
             weight_dtype=torch.int4,
             granularity=PerGroup(32),
-            VERSION=2,
+            version=2,
         )
 
+    def test_embedding(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        input = torch.randint(low=0, high=128, size=(10,), device=device)
+        embedding = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
+        original = embedding(input)
+        quantize_(embedding, self.config)
+        quantized = embedding(input)
+        error = compute_error(original, quantized)
+        self.assertTrue(error > 20)
+
     def test_linear(self):
         dtype = torch.bfloat16
         device = "cpu"

torchao/experimental/tests/test_embedding_xbit_quantizer.py (1 addition, 1 deletion)

@@ -185,7 +185,7 @@ def test_shared_embedding(self):
         # Check the shared_embedding and linear ops use the same lifted weight
         expected_lines = [
             "torch.ops.torchao._shared_embedding_4bit.default",
-            "torch.ops.torchao._linear_8bit_act_4bit_weight.defaul",
+            "torch.ops.torchao._linear_8bit_act_4bit_weight.default",
         ]
         for line in expected_lines:
             FileCheck().check_count(line, 1, exactly=True).run(

torchao/quantization/quant_api.py (15 additions, 6 deletions)

@@ -564,6 +564,10 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"
 
 
+def _embedding_extra_repr(self):
+    return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}"
+
+
 def _get_linear_subclass_inserter(
     constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs
 ):

@@ -2061,8 +2065,8 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     mapping_type: MappingType = MappingType.SYMMETRIC
     scale_dtype: Optional[torch.dtype] = None
     layout: Layout = QDQLayout()
-    packing_format: PackingFormat = PackingFormat.UNPACKED
-    VERSION: int = 1
+    packing_format: PackingFormat = PackingFormat.UNPACKED_TO_INT8
+    version: int = 1
 
     def __post_init__(self):
         assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+"

@@ -2104,9 +2108,9 @@ def _intx_weight_only_quantize_tensor(weight, config):
 
     block_size = (1, group_size)
 
-    if config.VERSION == 2:
-        if config.packing_format == PackingFormat.UNPACKED:
-            new_weight = IntxUnpackedTensor.from_float(
+    if config.version == 2:
+        if config.packing_format == PackingFormat.UNPACKED_TO_INT8:
+            new_weight = IntxUnpackedTensor.from_hp(
                 weight,
                 block_size,
                 weight_dtype,

@@ -2146,7 +2150,12 @@ def _intx_weight_only_transform(
     )
     new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+
+    if isinstance(module, nn.Linear):
+        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    elif isinstance(module, nn.Embedding):
+        module.extra_repr = types.MethodType(_embedding_extra_repr, module)
+
     return module
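Since _intx_weight_only_transform now dispatches extra_repr on the module type, a quantized model prints embedding metadata too. A sketch of applying the config across a whole model (assuming quantize_'s filter_fn parameter, whose default targets nn.Linear, so embeddings need an explicit filter):

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

model = torch.nn.Sequential(
    torch.nn.Embedding(128, 64),
    torch.nn.Linear(64, 64),
)
config = IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32), version=2)

def is_linear_or_embedding(module, fqn):
    return isinstance(module, (torch.nn.Linear, torch.nn.Embedding))

quantize_(model, config, filter_fn=is_linear_or_embedding)
print(model)  # Linear/Embedding reprs now report weight=<quantization type>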

torchao/quantization/quantize_/common/packing_format.py (1 addition, 1 deletion)

@@ -34,4 +34,4 @@ class PackingFormat(str, Enum):
     """
     Unpacked means the subbyte quantized data is stored as int8
     """
-    UNPACKED = "unpacked"
+    UNPACKED_TO_INT8 = "unpacked_to_int8"
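Because PackingFormat subclasses str, renaming the member also changes the serialized value: the old "unpacked" string no longer parses. A quick behavioral sketch (module path taken from this file's location):

from torchao.quantization.quantize_.common.packing_format import PackingFormat

assert PackingFormat.UNPACKED_TO_INT8 == "unpacked_to_int8"  # str-enum equality
assert PackingFormat("unpacked_to_int8") is PackingFormat.UNPACKED_TO_INT8
# PackingFormat("unpacked") now raises ValueError: 'unpacked' is not a valid PackingFormat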

torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py (8 additions, 63 deletions)

@@ -55,8 +55,8 @@ class IntxUnpackedTensor(TorchAOBaseTensor):
     block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
     """
 
-    tensor_data_attrs = ["int_data", "scale", "zero_point"]
-    tensor_attributes = ["bit_width", "block_size"]
+    tensor_data_names = ["int_data", "scale", "zero_point"]
+    tensor_attribute_names = ["bit_width", "block_size"]
 
     def __new__(cls, int_data, scale, zero_point, bit_width, block_size=None):
         kwargs = {}

@@ -105,30 +105,10 @@ def __init__(
         self.bit_width = bit_width
         self.block_size = block_size
 
-    def __tensor_flatten__(self):
-        return self.tensor_data_attrs, [
-            getattr(self, attr) for attr in self.tensor_attributes
-        ]
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
-    ):
-        return cls(
-            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
-            *tensor_attributes,
-        )
-
-    def _apply_fn_to_data(self, fn):
-        return self.__class__(
-            *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
-            *[getattr(self, attr) for attr in self.tensor_attributes],
-        )
-
     def __repr__(self):
         repr_fields = (
-            self.tensor_data_attrs
-            + self.tensor_attributes
+            self.tensor_data_names
+            + self.tensor_attribute_names
             + ["shape", "device", "dtype", "require_grad"]
         )
         inner_repr = [f"{attr}={getattr(self, attr)}" for attr in repr_fields]

@@ -157,14 +137,17 @@ def to(self, *args, **kwargs):
         )
 
     @classmethod
-    def from_float(
+    def from_hp(
         cls,
         float_tensor: torch.Tensor,
         block_size: Tuple[int],
         dtype: torch.dtype,
         *,
         mapping_type: MappingType = MappingType.SYMMETRIC,
     ):
+        """
+        Create an IntxUnpackedTensor from a high-precision tensor
+        """
        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[dtype]
        bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
        scale, zero_point = choose_qparams_affine(
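The rename from from_float to from_hp ("high precision", per the new docstring) matches the constructor naming used elsewhere in torchao's quantize_ workflows. A direct-construction sketch (import path inferred from this file's location):

import torch
from torchao.quantization.quantize_.workflows.intx.intx_unpacked_tensor import (
    IntxUnpackedTensor,
)

weight = torch.randn(128, 256, dtype=torch.bfloat16)
qweight = IntxUnpackedTensor.from_hp(
    weight,
    (1, 32),     # groupwise along the last dim, group size 32
    torch.int4,  # subbyte values stored unpacked, one int8 per int4 value
)
# qweight.int_data / qweight.scale / qweight.zero_point hold the quantized pieces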
@@ -234,44 +217,6 @@ def _(func, types, args, kwargs):
     return torch.nn.functional.embedding(indices, weight_tensor, **kwargs)
 
 
-@implements([aten.detach.default, aten.alias.default])
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-    )
-
-
-@implements(aten.clone.default)
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-    )
-
-
-def _same_metadata(self: "IntxUnpackedTensor", src: "IntxUnpackedTensor") -> bool:
-    return (
-        isinstance(self, IntxUnpackedTensor)
-        and isinstance(src, IntxUnpackedTensor)
-        and all(
-            getattr(self, attr) == getattr(src, attr) for attr in self.tensor_attributes
-        )
-    )
-
-
-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if _same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
-    )
-
-
 @implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
     self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
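The bulk of the deletions is boilerplate that TorchAOBaseTensor can now derive from the declared tensor_data_names and tensor_attribute_names: __tensor_flatten__/__tensor_unflatten__, _apply_fn_to_data, and the generic detach/clone/copy_ op handlers. A simplified sketch of the declarative pattern (not torchao's actual base-class code):

import torch

class DeclarativeTensorBase(torch.Tensor):
    # Subclasses declare which attributes hold tensors vs. plain metadata.
    tensor_data_names: list = []
    tensor_attribute_names: list = []

    def __tensor_flatten__(self):
        # Flattening is derived entirely from the declared names.
        return self.tensor_data_names, [
            getattr(self, name) for name in self.tensor_attribute_names
        ]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        return cls(
            *[tensor_data_dict[name] for name in cls.tensor_data_names],
            *tensor_attributes,
        )

    def _apply_fn_to_data(self, fn):
        # Generic helper behind detach/clone: rebuild with fn mapped over tensor data.
        return self.__class__(
            *[fn(getattr(self, name)) for name in self.tensor_data_names],
            *[getattr(self, name) for name in self.tensor_attribute_names],
        )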
