|
7 | 7 | from unittest.mock import patch
|
8 | 8 |
|
9 | 9 | import torch
|
| 10 | +from torch.utils._python_dispatch import return_and_correct_aliasing |
10 | 11 |
|
11 | 12 | from torchao.utils import TorchAOBaseTensor, torch_version_at_least
|
12 | 13 |
|
@@ -49,6 +50,47 @@ def __init__(self, data):
|
49 | 50 | with self.assertRaisesRegex(NotImplementedError, "arg_types"):
|
50 | 51 | l.weight = torch.nn.Parameter(MyTensor(l.weight))
|
51 | 52 |
|
| 53 | + def test_default_impls(self): |
| 54 | + """Making sure some common functions has default implementations, such as |
| 55 | + __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to |
| 56 | + """ |
| 57 | + |
| 58 | + class MyTensor(TorchAOBaseTensor): |
| 59 | + tensor_data_names = ["qdata"] |
| 60 | + tensor_attribute_names = ["attr"] |
| 61 | + |
| 62 | + def __new__(cls, qdata, attr): |
| 63 | + shape = qdata.shape |
| 64 | + return torch.Tensor._make_wrapper_subclass(cls, shape) # type: ignore[attr-defined] |
| 65 | + |
| 66 | + def __init__(self, qdata, attr): |
| 67 | + self.qdata = qdata |
| 68 | + self.attr = attr |
| 69 | + |
| 70 | + implements = MyTensor.implements |
| 71 | + |
| 72 | + @implements(torch.ops.aten.detach.default) |
| 73 | + def _(func, types, args, kwargs): |
| 74 | + return return_and_correct_aliasing( |
| 75 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) |
| 76 | + ) |
| 77 | + |
| 78 | + l = torch.nn.Linear(1, 1) |
| 79 | + l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr")) |
| 80 | + lp_tensor = l.weight |
| 81 | + tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() |
| 82 | + tensor_data_dict = { |
| 83 | + name: getattr(lp_tensor, name) for name in tensor_data_name_dict |
| 84 | + } |
| 85 | + outer_size = lp_tensor.size() |
| 86 | + outer_stride = lp_tensor.stride() |
| 87 | + reconstructed = type(lp_tensor).__tensor_unflatten__( |
| 88 | + tensor_data_dict, tensor_attributes, outer_size, outer_stride |
| 89 | + ) |
| 90 | + self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata)) |
| 91 | + self.assertEqual(lp_tensor.attr, reconstructed.attr) |
| 92 | + print(lp_tensor) |
| 93 | + |
52 | 94 |
|
53 | 95 | if __name__ == "__main__":
|
54 | 96 | unittest.main()
|
0 commit comments