
Commit 6dfe202

Support optional_tensor_data_names in TorchAOBaseTensor
Summary: Allows subclasses inheriting from TorchAOBaseTensor to have optional tensor attributes. Updates all common util functions to support the `optional_tensor_data_names` list, including `__tensor_flatten__`, `__tensor_unflatten__`, and ops like aten._to_copy, contiguous, alias, etc.

Test Plan: python test/test_utils.py

stack-info: PR: #2710, branch: jerryzh168/stack/17
1 parent 820f264 commit 6dfe202
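
For orientation, a minimal sketch of what this enables, modeled directly on the MyTensorWithOptionalData class added in the test diff below (the class is a test fixture, not a public API): subclasses may now declare `optional_tensor_data_names`, and the inherited boilerplate handles entries that are None.

    import torch
    from torchao.utils import TorchAOBaseTensor

    class MyTensorWithOptionalData(TorchAOBaseTensor):
        tensor_data_names = ["qdata"]
        optional_tensor_data_names = ["zero_point"]  # may be None
        tensor_attribute_names = ["attr", "device"]

        def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
            device = device if device is not None else qdata.device
            return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=device)

        def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
            self.qdata = qdata
            self.zero_point = zero_point
            self.attr = attr

    # the inherited defaults work whether zero_point is a Tensor or None
    t = MyTensorWithOptionalData(torch.randn(3, 2), None, "attr")
    t = t.detach().clone().contiguous()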

File tree

2 files changed: +201 −42 lines

test/test_utils.py
torchao/utils.py


test/test_utils.py

Lines changed: 108 additions & 41 deletions
@@ -55,34 +55,11 @@ def __init__(self, data):
         with self.assertRaisesRegex(NotImplementedError, "arg_types"):
             l.weight = torch.nn.Parameter(MyTensor(l.weight))
 
-    @skip_if_no_cuda()
-    def test_default_impls(self):
-        """Making sure some common functions has default implementations, such as
-        __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
-        """
-
-        class MyTensor(TorchAOBaseTensor):
-            tensor_data_names = ["qdata"]
-            tensor_attribute_names = ["attr", "device"]
-
-            def __new__(cls, qdata, attr, device=None):
-                shape = qdata.shape
-                if device is None:
-                    device = qdata.device
-                kwargs = {"device": device}
-                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
-
-            def __init__(self, qdata, attr, device=None):
-                self.qdata = qdata
-                self.attr = attr
-
-        l = torch.nn.Linear(2, 3)
-        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
-        lp_tensor = l.weight
+    def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
         # test __tensor_flatten__ and __tensor_unflatten__
-        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+        tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__()
         tensor_data_dict = {
-            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+            name: getattr(lp_tensor, name) for name in tensor_data_names
         }
         outer_size = lp_tensor.size()
         outer_stride = lp_tensor.stride()
@@ -100,31 +77,121 @@ def __init__(self, qdata, attr, device=None):
         self.assertEqual(lp_tensor.device, original_device)
 
         # __repr__
-        print(lp_tensor)
+        _ = str(lp_tensor)
 
         # other ops
         lp_tensor = lp_tensor.detach()
         # explicitly testing aten.alias
         lp_tensor = torch.ops.aten.alias(lp_tensor)
         lp_tensor = lp_tensor.clone()
-        # making qdata not contiguous
-        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1).contiguous()
-        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1)
-        self.assertFalse(lp_tensor.qdata.is_contiguous())
-        lp_tensor = lp_tensor.contiguous()
-        # making sure contiguous call works
-        self.assertTrue(lp_tensor.qdata.is_contiguous())
+        # get all tensor_data_names for both
+        # non optional and valid optional tensors
+        tensor_data_names = lp_tensor.tensor_data_names.copy()
+        if hasattr(lp_tensor, "optional_tensor_data_names"):
+            for tensor_data_name in lp_tensor.optional_tensor_data_names:
+                if getattr(lp_tensor, tensor_data_name) is not None:
+                    tensor_data_names.append(tensor_data_name)
+
+        # for each of the tensor data, we try to
+        # make it non-contiguous and then use
+        # lp_tensor.contiguous() call to make sure
+        # contiguous() works
+        for tensor_data_name in tensor_data_names:
+            tensor = getattr(lp_tensor, tensor_data_name)
+            # making qdata not contiguous
+            tensor = tensor.transpose(0, 1).contiguous()
+            tensor = tensor.transpose(0, 1)
+            setattr(lp_tensor, tensor_data_name, tensor)
+            self.assertFalse(getattr(lp_tensor, tensor_data_name).is_contiguous())
+            lp_tensor = lp_tensor.contiguous()
+            # making sure contiguous call works
+            self.assertTrue(getattr(lp_tensor, tensor_data_name).is_contiguous())
 
         # copy_
+        # making sure that initially tensor values are not the same so we can test copy_
+        self.assertNotEqual(lp_tensor.qdata[0][0], lp_tensor_for_copy.qdata[0][0])
+        # copy_ requires the attributes to be the same
+        for tensor_attr_name in lp_tensor.tensor_attribute_names:
+            self.assertEqual(
+                getattr(lp_tensor, tensor_attr_name),
+                getattr(lp_tensor_for_copy, tensor_attr_name),
+            )
+        lp_tensor.copy_(lp_tensor_for_copy)
+        # after copy_, the tensor values should match
+        for tensor_data_name in tensor_data_names:
+            self.assertTrue(
+                torch.equal(
+                    getattr(lp_tensor, tensor_data_name),
+                    getattr(lp_tensor_for_copy, tensor_data_name),
+                )
+            )
+
+    @skip_if_no_cuda()
+    def test_default_impls(self):
+        """Making sure some common functions has default implementations, such as
+        __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
+        """
+
+        class MyTensor(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata, attr, device=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+            def __init__(self, qdata, attr, device=None):
+                self.qdata = qdata
+                self.attr = attr
+
+        l = torch.nn.Linear(2, 3)
+        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
+        lp_tensor = l.weight
+
         another_tensor = torch.nn.Linear(2, 3).weight
         # attribute has to be the same
-        another_lp_tensor = MyTensor(another_tensor, "attr")
-        # initially tensor values are not the same
-        self.assertNotEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
-        lp_tensor.copy_(another_lp_tensor)
-        self.assertEqual(lp_tensor.attr, "attr")
-        # after copy_, the tensor values should match
-        self.assertEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
+        lp_tensor_for_copy = MyTensor(another_tensor, "attr")
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+    @skip_if_no_cuda()
+    def test_default_impls_with_optional_data(self):
+        class MyTensorWithOptionalData(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            optional_tensor_data_names = ["zero_point"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+            def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
+                self.qdata = qdata
+                self.zero_point = zero_point
+                self.attr = attr
+
+        # test both the optional Tensor is None
+        # and not None
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(
+            l.weight, torch.zeros_like(l.weight), "attr"
+        )
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(
+            l.weight, torch.zeros_like(l.weight), "attr"
+        )
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
 
 if __name__ == "__main__":
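
The helper above boils down to the following round trip (a sketch of the visible part of `_test_default_impls_helper`, using the `lp_tensor` constructed in either test): a None optional tensor is simply absent from the flattened dict and is recreated as None by `__tensor_unflatten__`.

    tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__()
    tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_names}
    reconstructed = type(lp_tensor).__tensor_unflatten__(
        tensor_data_dict, tensor_attributes, lp_tensor.size(), lp_tensor.stride()
    )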

torchao/utils.py

Lines changed: 93 additions & 1 deletion
@@ -463,6 +463,16 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
         getattr(self, t_name).shape == getattr(src, t_name).shape
         for t_name in self.tensor_data_names
     )
+    _optional_tensor_shape_match = True
+    if hasattr(self, "optional_tensor_data_names"):
+        # either both are None or both are Tensors and the shapes match
+        _optional_tensor_shape_match = all(
+            getattr(self, t_name).shape == getattr(src, t_name).shape
+            if getattr(self, t_name) is not None
+            else getattr(src, t_name) is None
+            for t_name in self.optional_tensor_data_names
+        )
+
     _attr_match = all(
         getattr(self, a_name) == getattr(src, a_name)
         for a_name in self.tensor_attribute_names
@@ -471,6 +481,7 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
         type(self) == type(src)
         and self.shape == src.shape
         and _tensor_shape_match
+        and _optional_tensor_shape_match
         and _attr_match
     )

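Per optional name, the shape check above amounts to this predicate (an illustrative restatement, not library code; unlike the committed generator expression, it also explicitly guards the case where only `src`'s tensor is None):

    from typing import Optional
    import torch

    def optional_shapes_match(a: Optional[torch.Tensor], b: Optional[torch.Tensor]) -> bool:
        # both None, or both present with equal shapes
        if a is None:
            return b is None
        return b is not None and a.shape == b.shape
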
@@ -498,6 +509,14 @@ def _(func, types, args, kwargs):
         tensors = [
             getattr(self, name).to(device) for name in self.tensor_data_names
         ]
+        if hasattr(self, "optional_tensor_data_names"):
+            for tensor_data_name in self.optional_tensor_data_names:
+                maybe_tensor = getattr(self, tensor_data_name)
+                if maybe_tensor is not None:
+                    tensors.append(maybe_tensor.to(device))
+                else:
+                    tensors.append(None)
+
         # change device
         tensor_attributes = [
             getattr(self, attr_name) if attr_name != "device" else device
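
With this branch, the default aten._to_copy implementation (and therefore `t.to(...)`) moves every tensor that is present and carries a None optional through unchanged. A small sketch reusing the MyTensorWithOptionalData fixture from the test diff above:

    t = MyTensorWithOptionalData(torch.randn(3, 2), None, "attr")
    moved = t.to("cpu")  # a real device change (e.g. "cuda") takes the same path
    assert moved.zero_point is None  # None survives the copy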
@@ -665,6 +684,52 @@ class PlainAQTTensorImpl(...):
         tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
         tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
 
+    class variables to define to simplify implementation of tensor subclasses:
+      `tensor_data_names` (List[str]): list of names of all required tensor data; the order
+         should match the `__init__` argument list of the tensor subclass
+      `optional_tensor_data_names` (List[str]): optional to define; needed when some Tensor
+         attributes can be None. When defined, it is the list of names of Tensors that can be
+         optional, and the boilerplate functions below handle them as well
+      `tensor_attribute_names` (List[str]): list of names of non-Tensor attributes; the order
+         should match the `__init__` argument list of the tensor subclass, following all the
+         `tensor_data_names` and `optional_tensor_data_names` arguments
+
+    If `tensor_data_names` and `tensor_attribute_names` are defined, some additional
+    functions are provided, including:
+      `__tensor_flatten__`: flattens a subclassed tensor instance; returns a tuple whose first
+         element is the list of tensor data names for valid (non-None) tensor data and whose
+         second element is the list of non-Tensor attributes
+      `__tensor_unflatten__`: takes a tensor_data_dict (a map from tensor name to Tensor) and a
+         list of non-Tensor attributes; returns a new instance of the subclassed tensor
+      `_apply_fn_to_data`: takes a function (Tensor -> Tensor), applies it to all tensor data,
+         and recreates a new subclassed Tensor with the transformed tensor data
+      `__repr__`: the string representation of the subclassed tensor instance
+      torch ops: torch.Tensor.contiguous
+      aten ops: aten.detach.default, aten.clone.default, aten.alias.default,
+        aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to)
+
+    Example:
+        class MyTensor(torch.Tensor):
+            tensor_data_names = ["a", "b"]
+            optional_tensor_data_names = ["c", "d"]
+            tensor_attribute_names = ["e", "f"]
+
+            def __new__(
+                cls,
+                a: Tensor,
+                b: Tensor,
+                c: Optional[Tensor],
+                d: Optional[Tensor],
+                e: int,
+                f: str,
+            ):
+                pass
+
+            def __init__(
+                self,
+                a: Tensor,
+                b: Tensor,
+                c: Optional[Tensor],
+                d: Optional[Tensor],
+                e: int,
+                f: str,
+            ):
+                pass
+
     """
 
     @classmethod
@@ -699,7 +764,14 @@ def __tensor_flatten__(self):
         if hasattr(self, "tensor_data_names") and hasattr(
             self, "tensor_attribute_names"
         ):
-            return self.tensor_data_names, [
+            tensor_data_names = self.tensor_data_names.copy()
+            if hasattr(self, "optional_tensor_data_names"):
+                for tensor_data_name in self.optional_tensor_data_names:
+                    maybe_tensor = getattr(self, tensor_data_name)
+                    if maybe_tensor is not None:
+                        tensor_data_names.append(tensor_data_name)
+
+            return tensor_data_names, [
                 getattr(self, attr) for attr in self.tensor_attribute_names
             ]
         raise NotImplementedError(
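
For the test fixture above, this means the flatten result depends on whether the optional tensor is set (a sketch; the comments show the expected values):

    t = MyTensorWithOptionalData(torch.randn(3, 2), None, "attr")
    names, _ = t.__tensor_flatten__()
    # names == ["qdata"]: a None optional is dropped from the list

    t2 = MyTensorWithOptionalData(torch.randn(3, 2), torch.zeros(3, 2), "attr")
    names2, _ = t2.__tensor_flatten__()
    # names2 == ["qdata", "zero_point"]: present optionals are appended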
@@ -711,13 +783,27 @@ def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         tensors = [tensor_data_dict[name] for name in cls.tensor_data_names]
+        if hasattr(cls, "optional_tensor_data_names"):
+            for tensor_data_name in cls.optional_tensor_data_names:
+                if tensor_data_name in tensor_data_dict:
+                    tensors.append(tensor_data_dict[tensor_data_name])
+                else:
+                    tensors.append(None)
         return cls(*tensors, *tensor_attributes)
 
     def _apply_fn_to_data(self, fn):
         if hasattr(self, "tensor_data_names") and hasattr(
             self, "tensor_attribute_names"
         ):
             tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
+            if hasattr(self, "optional_tensor_data_names"):
+                for tensor_data_name in self.optional_tensor_data_names:
+                    maybe_tensor = getattr(self, tensor_data_name)
+                    if maybe_tensor is not None:
+                        tensors.append(fn(maybe_tensor))
+                    else:
+                        tensors.append(None)
+
             tensor_attributes = [
                 getattr(self, attr) for attr in self.tensor_attribute_names
             ]
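
`_apply_fn_to_data` thus applies `fn` only to tensors that are present and forwards None for the rest, so (assuming, as in the surrounding code, that the method rebuilds the instance from `tensors` and `tensor_attributes`) a None optional comes out as None:

    t = MyTensorWithOptionalData(torch.randn(3, 2), None, "attr")
    detached = t._apply_fn_to_data(torch.Tensor.detach)
    # fn ran on qdata; zero_point stays None because fn is skipped for None entries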
@@ -738,6 +824,12 @@ def __repr__(self):
             repr_str += f"{self.tensor_data_names[0]}={getattr(self, self.tensor_data_names[0])}"
             for tensor_data_name in self.tensor_data_names[1:]:
                 repr_str += f", {tensor_data_name}={getattr(self, tensor_data_name)}"
+            if hasattr(self, "optional_tensor_data_names"):
+                for tensor_data_name in self.optional_tensor_data_names:
+                    repr_str += (
+                        f", {tensor_data_name}={getattr(self, tensor_data_name)}"
+                    )
+
             for tensor_attribute_name in self.tensor_attribute_names:
                 repr_str += (
                     f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}"