Support optional_tensor_names in TorchAOBaseTensor #2710

Merged · 1 commit · Aug 12, 2025
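For orientation, a minimal sketch of the kind of subclass this change supports (illustrative only; the class name here is hypothetical, and the actual test classes appear in the diff below). A subclass declares an optional Tensor attribute via `optional_tensor_data_names` and inherits the default `__tensor_flatten__`/`__tensor_unflatten__`, `_apply_fn_to_data`, `__repr__`, `to`, `contiguous`, and `copy_` implementations, assuming `TorchAOBaseTensor` is imported from `torchao.utils`:

import torch
from torchao.utils import TorchAOBaseTensor

class MyQuantTensor(TorchAOBaseTensor):
    # `zero_point` may be None; it is listed under optional_tensor_data_names
    tensor_data_names = ["qdata"]
    optional_tensor_data_names = ["zero_point"]
    tensor_attribute_names = ["attr", "device"]

    def __new__(cls, qdata, zero_point=None, attr="", device=None):
        if device is None:
            device = qdata.device
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=device)

    def __init__(self, qdata, zero_point=None, attr="", device=None):
        self.qdata = qdata
        self.zero_point = zero_point
        self.attr = attr

# both forms work with the inherited default implementations
t1 = MyQuantTensor(torch.randn(3, 2), None, "attr")
t2 = MyQuantTensor(torch.randn(3, 2), torch.zeros(3, 2), "attr")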
149 changes: 108 additions & 41 deletions test/test_utils.py
@@ -105,34 +105,11 @@ def __init__(self, data):
with self.assertRaisesRegex(NotImplementedError, "arg_types"):
l.weight = torch.nn.Parameter(MyTensor(l.weight))

@skip_if_no_cuda()
def test_default_impls(self):
"""Making sure some common functions has default implementations, such as
__tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
"""

class MyTensor(TorchAOBaseTensor):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr", "device"]

def __new__(cls, qdata, attr, device=None):
shape = qdata.shape
if device is None:
device = qdata.device
kwargs = {"device": device}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, qdata, attr, device=None):
self.qdata = qdata
self.attr = attr

l = torch.nn.Linear(2, 3)
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
lp_tensor = l.weight
def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
# test __tensor_flatten__ and __tensor_unflatten__
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {
name: getattr(lp_tensor, name) for name in tensor_data_name_dict
name: getattr(lp_tensor, name) for name in tensor_data_names
}
outer_size = lp_tensor.size()
outer_stride = lp_tensor.stride()
@@ -150,31 +127,121 @@ def __init__(self, qdata, attr, device=None):
self.assertEqual(lp_tensor.device, original_device)

# __repr__
print(lp_tensor)
_ = str(lp_tensor)

# other ops
lp_tensor = lp_tensor.detach()
# explicitly testing aten.alias
lp_tensor = torch.ops.aten.alias(lp_tensor)
lp_tensor = lp_tensor.clone()
# making qdata not contiguous
lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1).contiguous()
lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1)
self.assertFalse(lp_tensor.qdata.is_contiguous())
lp_tensor = lp_tensor.contiguous()
# making sure contiguous call works
self.assertTrue(lp_tensor.qdata.is_contiguous())
# get all tensor_data_names for both
# non-optional and valid optional tensors
tensor_data_names = lp_tensor.tensor_data_names.copy()
if hasattr(lp_tensor, "optional_tensor_data_names"):
for tensor_data_name in lp_tensor.optional_tensor_data_names:
if getattr(lp_tensor, tensor_data_name) is not None:
tensor_data_names.append(tensor_data_name)

# for each of the tensor data, we try to
# make it non-contiguous and then use
# lp_tensor.contiguous() call to make sure
# contiguous() works
for tensor_data_name in tensor_data_names:
tensor = getattr(lp_tensor, tensor_data_name)
# making qdata not contiguous
tensor = tensor.transpose(0, 1).contiguous()
tensor = tensor.transpose(0, 1)
setattr(lp_tensor, tensor_data_name, tensor)
self.assertFalse(getattr(lp_tensor, tensor_data_name).is_contiguous())
lp_tensor = lp_tensor.contiguous()
# making sure contiguous call works
self.assertTrue(getattr(lp_tensor, tensor_data_name).is_contiguous())

# copy_
# making sure that initially tensor values are not the same so we can test copy_
self.assertNotEqual(lp_tensor.qdata[0][0], lp_tensor_for_copy.qdata[0][0])
# copy_ requires the attributes to be the same
for tensor_attr_name in lp_tensor.tensor_attribute_names:
self.assertEqual(
getattr(lp_tensor, tensor_attr_name),
getattr(lp_tensor_for_copy, tensor_attr_name),
)
lp_tensor.copy_(lp_tensor_for_copy)
# after copy_, the tensor values should match
for tensor_data_name in tensor_data_names:
self.assertTrue(
torch.equal(
getattr(lp_tensor, tensor_data_name),
getattr(lp_tensor_for_copy, tensor_data_name),
)
)

@skip_if_no_cuda()
def test_default_impls(self):
"""Making sure some common functions has default implementations, such as
__tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
"""

class MyTensor(TorchAOBaseTensor):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr", "device"]

def __new__(cls, qdata, attr, device=None):
shape = qdata.shape
if device is None:
device = qdata.device
kwargs = {"device": device}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, qdata, attr, device=None):
self.qdata = qdata
self.attr = attr

l = torch.nn.Linear(2, 3)
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
lp_tensor = l.weight

another_tensor = torch.nn.Linear(2, 3).weight
# attribute has to be the same
another_lp_tensor = MyTensor(another_tensor, "attr")
# initially tensor values are not the same
self.assertNotEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
lp_tensor.copy_(another_lp_tensor)
self.assertEqual(lp_tensor.attr, "attr")
# after copy_, the tensor values should match
self.assertEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
lp_tensor_for_copy = MyTensor(another_tensor, "attr")
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

@skip_if_no_cuda()
def test_default_impls_with_optional_data(self):
class MyTensorWithOptionalData(TorchAOBaseTensor):
tensor_data_names = ["qdata"]
optional_tensor_data_names = ["zero_point"]
tensor_attribute_names = ["attr", "device"]

def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
shape = qdata.shape
if device is None:
device = qdata.device
kwargs = {"device": device}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
self.qdata = qdata
self.zero_point = zero_point
self.attr = attr

# test both the case where the optional Tensor is None
# and the case where it is not None
l = torch.nn.Linear(2, 3)
lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
l = torch.nn.Linear(2, 3)
lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

l = torch.nn.Linear(2, 3)
lp_tensor = MyTensorWithOptionalData(
l.weight, torch.zeros_like(l.weight), "attr"
)
l = torch.nn.Linear(2, 3)
lp_tensor_for_copy = MyTensorWithOptionalData(
l.weight, torch.zeros_like(l.weight), "attr"
)
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)


if __name__ == "__main__":
94 changes: 93 additions & 1 deletion torchao/utils.py
@@ -510,6 +510,16 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
getattr(self, t_name).shape == getattr(src, t_name).shape
for t_name in self.tensor_data_names
)
_optional_tensor_shape_match = True
if hasattr(self, "optional_tensor_data_names"):
# either both are None, or both are Tensors and the shapes match
_optional_tensor_shape_match = all(
getattr(self, t_name).shape == getattr(src, t_name).shape
if getattr(self, t_name) is not None
else getattr(src, t_name) is None
for t_name in self.optional_tensor_data_names
)

_attr_match = all(
getattr(self, a_name) == getattr(src, a_name)
for a_name in self.tensor_attribute_names
@@ -518,6 +528,7 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
type(self) == type(src)
and self.shape == src.shape
and _tensor_shape_match
and _optional_tensor_shape_match
and _attr_match
)
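As a hedged illustration of the optional-tensor check above, using the `MyTensorWithOptionalData` class from the test file and assuming `_same_metadata` is importable from `torchao.utils`: an optional entry matches when both sides are None, or both are Tensors with the same shape.

import torch
from torchao.utils import _same_metadata

x = MyTensorWithOptionalData(torch.randn(3, 2), None, "attr")
y = MyTensorWithOptionalData(torch.randn(3, 2), None, "attr")
z = MyTensorWithOptionalData(torch.randn(3, 2), torch.zeros(3, 2), "attr")

_same_metadata(x, y)  # True: shapes and attributes match, both zero_point are None
_same_metadata(x, z)  # False: x.zero_point is None while z.zero_point is a Tensor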

@@ -545,6 +556,14 @@ def _(func, types, args, kwargs):
tensors = [
getattr(self, name).to(device) for name in self.tensor_data_names
]
if hasattr(self, "optional_tensor_data_names"):
for tensor_data_name in self.optional_tensor_data_names:
maybe_tensor = getattr(self, tensor_data_name)
if maybe_tensor is not None:
tensors.append(maybe_tensor.to(device))
else:
tensors.append(None)

# change device
tensor_attributes = [
getattr(self, attr_name) if attr_name != "device" else device
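A hedged sketch of what the change to the default `to` path above enables, assuming a CUDA device is available and reusing `MyTensorWithOptionalData` from the test: the missing optional tensor stays None after the move, while real tensor data is copied to the target device.

lp = MyTensorWithOptionalData(torch.randn(3, 2), None, "attr")
lp_cuda = lp.to("cuda")

assert lp_cuda.qdata.device.type == "cuda"  # tensor data is moved
assert lp_cuda.zero_point is None           # the None placeholder is preserved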
@@ -712,6 +731,52 @@ class PlainAQTTensorImpl(...):
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)

Class variables to define to simplify the implementation of tensor subclasses:
`tensor_data_names` (List[str]): list of names of all required Tensor data; the order should match
the `__init__` argument list of the tensor subclass
`optional_tensor_data_names` (List[str]): optional to define; it is only needed when the subclass has
optional Tensor attributes. When defined, it is a list of names of Tensors that can be None, and the
additional boilerplate functions below will handle them as well
`tensor_attribute_names` (List[str]): list of names of non-Tensor attributes; the order should match
the `__init__` argument list of the tensor subclass, following all `tensor_data_names` and
`optional_tensor_data_names` arguments

If `tensor_data_names` and `tensor_attribute_names` are defined, the following additional
functions are provided:
`__tensor_flatten__`: flattens a subclassed tensor instance; returns a tuple whose first element is
the list of tensor data names with valid (non-None) tensor data, and whose second element is the
list of non-Tensor attributes
`__tensor_unflatten__`: takes a tensor_data_dict (a map from tensor name to Tensor) and a list of
non-Tensor attributes, and returns a new instance of the subclassed tensor
`_apply_fn_to_data`: takes a function (Tensor -> Tensor), applies it to all tensor data and
recreates a new subclassed Tensor with the transformed tensor data
`__repr__`: the string representation of the subclassed tensor instance
torch ops: torch.Tensor.contiguous
aten ops: aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous.default,
aten.copy_.default, aten._to_copy.default (enables t.to)

Example:
class MyTensor(torch.Tensor):
tensor_data_names = ["a", "b"]
optional_tensor_data_names = ["c", "d"]
tensor_attribute_names = ["e", "f"]

def __new__(
cls,
a: Tensor,
b: Tensor,
c: Optional[Tensor],
d: Optional[Tensor],
e: int,
f: str
):
pass

def __init__(
self,
a: Tensor,
b: Tensor,
c: Optional[Tensor],
d: Optional[Tensor],
e: int,
f: str
):
pass

"""

@classmethod
@@ -746,7 +811,14 @@ def __tensor_flatten__(self):
if hasattr(self, "tensor_data_names") and hasattr(
self, "tensor_attribute_names"
):
return self.tensor_data_names, [
tensor_data_names = self.tensor_data_names.copy()
if hasattr(self, "optional_tensor_data_names"):
for tensor_data_name in self.optional_tensor_data_names:
maybe_tensor = getattr(self, tensor_data_name)
if maybe_tensor is not None:
tensor_data_names.append(tensor_data_name)

return tensor_data_names, [
getattr(self, attr) for attr in self.tensor_attribute_names
]
raise NotImplementedError(
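For illustration, a sketch of the behavior above using `MyTensorWithOptionalData` from the test: only optional tensors that are not None are included in the flattened tensor data names.

lp = MyTensorWithOptionalData(torch.randn(3, 2), None, "attr")
names, attrs = lp.__tensor_flatten__()
# names == ["qdata"]; attrs == ["attr", device]

lp2 = MyTensorWithOptionalData(torch.randn(3, 2), torch.zeros(3, 2), "attr")
names2, _ = lp2.__tensor_flatten__()
# names2 == ["qdata", "zero_point"]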
@@ -758,13 +830,27 @@ def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
tensors = [tensor_data_dict[name] for name in cls.tensor_data_names]
if hasattr(cls, "optional_tensor_data_names"):
for tensor_data_name in cls.optional_tensor_data_names:
if tensor_data_name in tensor_data_dict:
tensors.append(tensor_data_dict[tensor_data_name])
else:
tensors.append(None)
return cls(*tensors, *tensor_attributes)

def _apply_fn_to_data(self, fn):
if hasattr(self, "tensor_data_names") and hasattr(
self, "tensor_attribute_names"
):
tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
if hasattr(self, "optional_tensor_data_names"):
for tensor_data_name in self.optional_tensor_data_names:
maybe_tensor = getattr(self, tensor_data_name)
if maybe_tensor is not None:
tensors.append(fn(maybe_tensor))
else:
tensors.append(None)

tensor_attributes = [
getattr(self, attr) for attr in self.tensor_attribute_names
]
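Continuing the sketch from above: `__tensor_unflatten__` reconstructs optional tensors that are absent from `tensor_data_dict` as None, and `_apply_fn_to_data` applies the function only to tensors that are present, so a round trip preserves the missing field.

tensor_data_dict = {name: getattr(lp, name) for name in names}
rebuilt = MyTensorWithOptionalData.__tensor_unflatten__(
    tensor_data_dict, attrs, lp.size(), lp.stride()
)
assert rebuilt.zero_point is None

detached = lp._apply_fn_to_data(torch.Tensor.detach)
assert detached.zero_point is None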
Expand All @@ -785,6 +871,12 @@ def __repr__(self):
repr_str += f"{self.tensor_data_names[0]}={getattr(self, self.tensor_data_names[0])}"
for tensor_data_name in self.tensor_data_names[1:]:
repr_str += f", {tensor_data_name}={getattr(self, tensor_data_name)}"
if hasattr(self, "optional_tensor_data_names"):
for tensor_data_name in self.optional_tensor_data_names:
repr_str += (
f", {tensor_data_name}={getattr(self, tensor_data_name)}"
)

for tensor_attribute_name in self.tensor_attribute_names:
repr_str += (
f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}"