Commit 7c13cde

Support optional_tensor_names in TorchAOBaseTensor (#2710)
Summary: Allows subclasses of TorchAOBaseTensor to have optional tensor attributes. Updates all common util functions to support the `optional_tensor_data_names` list, including `__tensor_flatten__`, `__tensor_unflatten__`, and ops such as aten._to_copy, contiguous, and alias.

Test Plan: python test/test_utils.py
1 parent 0b88286 commit 7c13cde
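
For context, here is a minimal sketch of the pattern this commit enables, adapted from the new test_default_impls_with_optional_data test in the diff below (the class and attribute names come from that test; treat this as an illustration, not part of the commit):

    import torch
    from torchao.utils import TorchAOBaseTensor

    class MyTensorWithOptionalData(TorchAOBaseTensor):
        tensor_data_names = ["qdata"]                # required Tensor attributes
        optional_tensor_data_names = ["zero_point"]  # Tensor attributes that may be None
        tensor_attribute_names = ["attr", "device"]  # non-Tensor attributes

        def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
            shape = qdata.shape
            if device is None:
                device = qdata.device
            return torch.Tensor._make_wrapper_subclass(cls, shape, device=device)

        def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
            self.qdata = qdata
            self.zero_point = zero_point
            self.attr = attr

    # The inherited defaults (detach, clone, contiguous, copy_, to, repr, ...)
    # now work whether zero_point is a Tensor or None:
    w = torch.nn.Linear(2, 3).weight
    t = MyTensorWithOptionalData(w, None, "attr")
    t = t.detach().clone().contiguous()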

File tree: 2 files changed (+201 −42 lines)

test/test_utils.py
torchao/utils.py
test/test_utils.py

Lines changed: 108 additions & 41 deletions
@@ -105,34 +105,11 @@ def __init__(self, data):
         with self.assertRaisesRegex(NotImplementedError, "arg_types"):
             l.weight = torch.nn.Parameter(MyTensor(l.weight))
 
-    @skip_if_no_cuda()
-    def test_default_impls(self):
-        """Making sure some common functions has default implementations, such as
-        __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
-        """
-
-        class MyTensor(TorchAOBaseTensor):
-            tensor_data_names = ["qdata"]
-            tensor_attribute_names = ["attr", "device"]
-
-            def __new__(cls, qdata, attr, device=None):
-                shape = qdata.shape
-                if device is None:
-                    device = qdata.device
-                kwargs = {"device": device}
-                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
-
-            def __init__(self, qdata, attr, device=None):
-                self.qdata = qdata
-                self.attr = attr
-
-        l = torch.nn.Linear(2, 3)
-        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
-        lp_tensor = l.weight
+    def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
         # test __tensor_flatten__ and __tensor_unflatten__
-        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+        tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__()
         tensor_data_dict = {
-            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+            name: getattr(lp_tensor, name) for name in tensor_data_names
         }
         outer_size = lp_tensor.size()
         outer_stride = lp_tensor.stride()
@@ -150,31 +127,121 @@ def __init__(self, qdata, attr, device=None):
         self.assertEqual(lp_tensor.device, original_device)
 
         # __repr__
-        print(lp_tensor)
+        _ = str(lp_tensor)
 
         # other ops
         lp_tensor = lp_tensor.detach()
         # explicitly testing aten.alias
         lp_tensor = torch.ops.aten.alias(lp_tensor)
         lp_tensor = lp_tensor.clone()
-        # making qdata not contiguous
-        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1).contiguous()
-        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1)
-        self.assertFalse(lp_tensor.qdata.is_contiguous())
-        lp_tensor = lp_tensor.contiguous()
-        # making sure contiguous call works
-        self.assertTrue(lp_tensor.qdata.is_contiguous())
+        # get all tensor_data_names for both
+        # non optional and valid optional tensors
+        tensor_data_names = lp_tensor.tensor_data_names.copy()
+        if hasattr(lp_tensor, "optional_tensor_data_names"):
+            for tensor_data_name in lp_tensor.optional_tensor_data_names:
+                if getattr(lp_tensor, tensor_data_name) is not None:
+                    tensor_data_names.append(tensor_data_name)
+
+        # for each of the tensor data, we try to
+        # make it non-contiguous and then use
+        # lp_tensor.contiguous() call to make sure
+        # contiguous() works
+        for tensor_data_name in tensor_data_names:
+            tensor = getattr(lp_tensor, tensor_data_name)
+            # making qdata not contiguous
+            tensor = tensor.transpose(0, 1).contiguous()
+            tensor = tensor.transpose(0, 1)
+            setattr(lp_tensor, tensor_data_name, tensor)
+            self.assertFalse(getattr(lp_tensor, tensor_data_name).is_contiguous())
+            lp_tensor = lp_tensor.contiguous()
+            # making sure contiguous call works
+            self.assertTrue(getattr(lp_tensor, tensor_data_name).is_contiguous())
 
         # copy_
+        # making sure that initially tensor values are not the same so we can test copy_
+        self.assertNotEqual(lp_tensor.qdata[0][0], lp_tensor_for_copy.qdata[0][0])
+        # copy_ requires the attributes to be the same
+        for tensor_attr_name in lp_tensor.tensor_attribute_names:
+            self.assertEqual(
+                getattr(lp_tensor, tensor_attr_name),
+                getattr(lp_tensor_for_copy, tensor_attr_name),
+            )
+        lp_tensor.copy_(lp_tensor_for_copy)
+        # after copy_, the tensor values should match
+        for tensor_data_name in tensor_data_names:
+            self.assertTrue(
+                torch.equal(
+                    getattr(lp_tensor, tensor_data_name),
+                    getattr(lp_tensor_for_copy, tensor_data_name),
+                )
+            )
+
+    @skip_if_no_cuda()
+    def test_default_impls(self):
+        """Making sure some common functions has default implementations, such as
+        __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
+        """
+
+        class MyTensor(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata, attr, device=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+            def __init__(self, qdata, attr, device=None):
+                self.qdata = qdata
+                self.attr = attr
+
+        l = torch.nn.Linear(2, 3)
+        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
+        lp_tensor = l.weight
+
         another_tensor = torch.nn.Linear(2, 3).weight
         # attribute has to be the same
-        another_lp_tensor = MyTensor(another_tensor, "attr")
-        # initially tensor values are not the same
-        self.assertNotEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
-        lp_tensor.copy_(another_lp_tensor)
-        self.assertEqual(lp_tensor.attr, "attr")
-        # after copy_, the tensor values should match
-        self.assertEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
+        lp_tensor_for_copy = MyTensor(another_tensor, "attr")
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+    @skip_if_no_cuda()
+    def test_default_impls_with_optional_data(self):
+        class MyTensorWithOptionalData(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            optional_tensor_data_names = ["zero_point"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+            def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
+                self.qdata = qdata
+                self.zero_point = zero_point
+                self.attr = attr
+
+        # test both the optional Tensor is None
+        # and not None
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(
+            l.weight, torch.zeros_like(l.weight), "attr"
+        )
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(
+            l.weight, torch.zeros_like(l.weight), "attr"
+        )
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
 
 if __name__ == "__main__":

torchao/utils.py

Lines changed: 93 additions & 1 deletion
@@ -510,6 +510,16 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
         getattr(self, t_name).shape == getattr(src, t_name).shape
         for t_name in self.tensor_data_names
     )
+    _optional_tensor_shape_match = True
+    if hasattr(self, "optional_tensor_data_names"):
+        # either both are None, or both are Tensors and the shapes match
+        _optional_tensor_shape_match = all(
+            getattr(self, t_name).shape == getattr(src, t_name).shape
+            if getattr(self, t_name) is not None
+            else getattr(src, t_name) is None
+            for t_name in self.optional_tensor_data_names
+        )
+
    _attr_match = all(
         getattr(self, a_name) == getattr(src, a_name)
         for a_name in self.tensor_attribute_names
@@ -518,6 +528,7 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
         type(self) == type(src)
         and self.shape == src.shape
         and _tensor_shape_match
+        and _optional_tensor_shape_match
         and _attr_match
     )
 
@@ -545,6 +556,14 @@ def _(func, types, args, kwargs):
         tensors = [
             getattr(self, name).to(device) for name in self.tensor_data_names
         ]
+        if hasattr(self, "optional_tensor_data_names"):
+            for tensor_data_name in self.optional_tensor_data_names:
+                maybe_tensor = getattr(self, tensor_data_name)
+                if maybe_tensor is not None:
+                    tensors.append(maybe_tensor.to(device))
+                else:
+                    tensors.append(None)
+
         # change device
         tensor_attributes = [
             getattr(self, attr_name) if attr_name != "device" else device
@@ -712,6 +731,52 @@ class PlainAQTTensorImpl(...):
         tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
         tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
 
+    class variables to define to simplify implementation of tensor subclasses:
+      `tensor_data_names` (List[str]): list of names of all required tensor data; the order should match
+        the `__init__` argument list of the tensor subclass
+      `optional_tensor_data_names` (List[str]): optional to define, but needed when some Tensor attributes can be None; when defined, this is the list of names of the Tensors that can be optional
+      `tensor_attribute_names` (List[str]): list of names of non-Tensor attributes;
+        the order should match the `__init__` argument list of the tensor subclass, following all the `tensor_data_names` and `optional_tensor_data_names` arguments
+
+    If `tensor_data_names` and `tensor_attribute_names` are defined, some additional
+    functions will be added, including:
+      `__tensor_flatten__`: flattens a subclassed tensor instance and returns a tuple; the first element is the list of tensor data names for valid (non-None) tensor data,
+        the second element is a list of non-Tensor attributes
+      `__tensor_unflatten__`: takes a tensor_data_dict (a map from tensor name to Tensor) and a list of non-Tensor attributes, and returns a new instance of the subclassed tensor
+      `_apply_fn_to_data`: takes a function (Tensor -> Tensor), applies it to all tensor data and
+        recreates a new subclassed Tensor with the transformed tensor data
+      `__repr__`: the string representation of the subclassed tensor instance
+      torch ops: torch.Tensor.contiguous
+      aten ops: aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to)
+
+    Example:
+        class MyTensor(torch.Tensor):
+            tensor_data_names = ["a", "b"]
+            optional_tensor_data_names = ["c", "d"]
+            tensor_attribute_names = ["e", "f"]
+
+            def __new__(
+                cls,
+                a: Tensor,
+                b: Tensor,
+                c: Optional[Tensor],
+                d: Optional[Tensor],
+                e: int,
+                f: str
+            ):
+                pass
+
+            def __init__(
+                self,
+                a: Tensor,
+                b: Tensor,
+                c: Optional[Tensor],
+                d: Optional[Tensor],
+                e: int,
+                f: str
+            ):
+                pass
+
     """
 
     @classmethod
@@ -746,7 +811,14 @@ def __tensor_flatten__(self):
         if hasattr(self, "tensor_data_names") and hasattr(
             self, "tensor_attribute_names"
         ):
-            return self.tensor_data_names, [
+            tensor_data_names = self.tensor_data_names.copy()
+            if hasattr(self, "optional_tensor_data_names"):
+                for tensor_data_name in self.optional_tensor_data_names:
+                    maybe_tensor = getattr(self, tensor_data_name)
+                    if maybe_tensor is not None:
+                        tensor_data_names.append(tensor_data_name)
+
+            return tensor_data_names, [
                 getattr(self, attr) for attr in self.tensor_attribute_names
             ]
         raise NotImplementedError(
@@ -758,13 +830,27 @@ def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         tensors = [tensor_data_dict[name] for name in cls.tensor_data_names]
+        if hasattr(cls, "optional_tensor_data_names"):
+            for tensor_data_name in cls.optional_tensor_data_names:
+                if tensor_data_name in tensor_data_dict:
+                    tensors.append(tensor_data_dict[tensor_data_name])
+                else:
+                    tensors.append(None)
         return cls(*tensors, *tensor_attributes)
 
     def _apply_fn_to_data(self, fn):
         if hasattr(self, "tensor_data_names") and hasattr(
             self, "tensor_attribute_names"
         ):
             tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
+            if hasattr(self, "optional_tensor_data_names"):
+                for tensor_data_name in self.optional_tensor_data_names:
+                    maybe_tensor = getattr(self, tensor_data_name)
+                    if maybe_tensor is not None:
+                        tensors.append(fn(maybe_tensor))
+                    else:
+                        tensors.append(None)
+
             tensor_attributes = [
                 getattr(self, attr) for attr in self.tensor_attribute_names
             ]
@@ -785,6 +871,12 @@ def __repr__(self):
             repr_str += f"{self.tensor_data_names[0]}={getattr(self, self.tensor_data_names[0])}"
             for tensor_data_name in self.tensor_data_names[1:]:
                 repr_str += f", {tensor_data_name}={getattr(self, tensor_data_name)}"
+            if hasattr(self, "optional_tensor_data_names"):
+                for tensor_data_name in self.optional_tensor_data_names:
+                    repr_str += (
+                        f", {tensor_data_name}={getattr(self, tensor_data_name)}"
+                    )
+
             for tensor_attribute_name in self.tensor_attribute_names:
                 repr_str += (
                     f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}"

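To make the flatten/unflatten contract in the docstring above concrete, here is a small round-trip sketch. It reuses the illustrative MyTensorWithOptionalData class from the commit's new test; the assertions reflect the default implementations added in this diff:

    import torch

    w = torch.nn.Linear(2, 3).weight
    t = MyTensorWithOptionalData(w, None, "attr")

    # __tensor_flatten__ only reports tensor data that is not None
    tensor_data_names, tensor_attributes = t.__tensor_flatten__()
    assert tensor_data_names == ["qdata"]  # "zero_point" is skipped while None

    # __tensor_unflatten__ restores None for optional names missing from the dict
    tensor_data_dict = {name: getattr(t, name) for name in tensor_data_names}
    restored = MyTensorWithOptionalData.__tensor_unflatten__(
        tensor_data_dict, tensor_attributes, t.size(), t.stride()
    )
    assert restored.zero_point is None

    # With a real zero_point, it is flattened like any required tensor
    t2 = MyTensorWithOptionalData(w, torch.zeros_like(w), "attr")
    names2, _ = t2.__tensor_flatten__()
    assert names2 == ["qdata", "zero_point"]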