Skip to content

Commit 3932909

Browse files
authored
Add more utils to TorchAOBaseTensor (#2597)
Summary: Added default impls for: * __tensor_flatten__ and __tensor_unflatten__ when tensor_data_names and tensor_attribute_names are defined * __repr__ * _apply_fn_to_data Next * more op definitions Test Plan: python test/test_utils.py Reviewers: Subscribers: Tasks: Tags: stack-info: PR: #2597, branch: jerryzh168/stack/12
1 parent 5fe4ebd commit 3932909

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

test/test_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from unittest.mock import patch
88

99
import torch
10+
from torch.utils._python_dispatch import return_and_correct_aliasing
1011

1112
from torchao.utils import TorchAOBaseTensor, torch_version_at_least
1213

@@ -49,6 +50,47 @@ def __init__(self, data):
4950
with self.assertRaisesRegex(NotImplementedError, "arg_types"):
5051
l.weight = torch.nn.Parameter(MyTensor(l.weight))
5152

53+
def test_default_impls(self):
54+
"""Making sure some common functions has default implementations, such as
55+
__tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
56+
"""
57+
58+
class MyTensor(TorchAOBaseTensor):
59+
tensor_data_names = ["qdata"]
60+
tensor_attribute_names = ["attr"]
61+
62+
def __new__(cls, qdata, attr):
63+
shape = qdata.shape
64+
return torch.Tensor._make_wrapper_subclass(cls, shape) # type: ignore[attr-defined]
65+
66+
def __init__(self, qdata, attr):
67+
self.qdata = qdata
68+
self.attr = attr
69+
70+
implements = MyTensor.implements
71+
72+
@implements(torch.ops.aten.detach.default)
73+
def _(func, types, args, kwargs):
74+
return return_and_correct_aliasing(
75+
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
76+
)
77+
78+
l = torch.nn.Linear(1, 1)
79+
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
80+
lp_tensor = l.weight
81+
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
82+
tensor_data_dict = {
83+
name: getattr(lp_tensor, name) for name in tensor_data_name_dict
84+
}
85+
outer_size = lp_tensor.size()
86+
outer_stride = lp_tensor.stride()
87+
reconstructed = type(lp_tensor).__tensor_unflatten__(
88+
tensor_data_dict, tensor_attributes, outer_size, outer_stride
89+
)
90+
self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata))
91+
self.assertEqual(lp_tensor.attr, reconstructed.attr)
92+
print(lp_tensor)
93+
5294

5395
if __name__ == "__main__":
5496
unittest.main()

torchao/utils.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,15 +584,46 @@ class PlainAQTTensorImpl(...):
584584
_get_to_kwargs = _get_to_kwargs
585585

586586
def __tensor_flatten__(self):
587-
raise NotImplementedError("Subclasses must implement __tensor_flatten__")
587+
if hasattr(self, "tensor_data_names") and hasattr(
588+
self, "tensor_attribute_names"
589+
):
590+
return self.tensor_data_names, [
591+
getattr(self, attr) for attr in self.tensor_attribute_names
592+
]
593+
raise NotImplementedError(
594+
"Subclasses must implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance"
595+
)
588596

589597
@classmethod
590598
def __tensor_unflatten__(
591599
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
592600
):
593-
raise NotImplementedError("Subclasses must implement __tensor_unflatten__")
601+
tensors = [tensor_data_dict[name] for name in cls.tensor_data_names]
602+
return cls(*tensors, *tensor_attributes)
603+
604+
def _apply_fn_to_data(self, fn):
605+
tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
606+
tensor_attributes = [
607+
getattr(self, attr) for attr in self.tensor_attribute_names
608+
]
609+
return self.__class__(
610+
*tensors,
611+
*tensor_attributes,
612+
)
594613

595614
def __repr__(self):
615+
if hasattr(self, "tensor_data_names") and hasattr(
616+
self, "tensor_attribute_names"
617+
):
618+
repr_str = ""
619+
repr_str += f"{self.tensor_data_names[0]}={getattr(self, self.tensor_data_names[0])}"
620+
for tensor_data_name in self.tensor_data_names[1:]:
621+
repr_str += f", {tensor_data_name}={getattr(self, tensor_data_name)}"
622+
for tensor_attribute_name in self.tensor_attribute_names:
623+
repr_str += (
624+
f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}"
625+
)
626+
return f"{self.__class__.__name__}({repr_str})"
596627
raise NotImplementedError("Subclasses must implement __repr__")
597628

598629
def get_layout(self):

0 commit comments

Comments
 (0)