Commit 30f5850

Support more ops in TorchAOBaseTensor (#2609)
Summary:
* detach
* clone
* alias
* contiguous
* copy_
* to

Test Plan: python test/test_utils.py

Reviewers:
Subscribers:
Tasks:
Tags:

stack-info: PR: #2598, branch: jerryzh168/stack/13
1 parent 3932909 commit 30f5850
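
For orientation, here is a condensed sketch of the usage pattern this commit enables. It mirrors the updated test below; the class and field names come from that test, and it assumes a torchao build that includes this change:

```python
import torch
from torchao.utils import TorchAOBaseTensor


class MyTensor(TorchAOBaseTensor):
    # declaring these two class attributes is enough for TorchAOBaseTensor to
    # provide default detach/clone/alias/contiguous/copy_/to implementations
    tensor_data_names = ["qdata"]
    tensor_attribute_names = ["attr", "device"]

    def __new__(cls, qdata, attr, device=None):
        kwargs = {"device": device if device is not None else qdata.device}
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)

    def __init__(self, qdata, attr, device=None):
        self.qdata = qdata
        self.attr = attr


t = MyTensor(torch.randn(4, 4), "attr")
t = t.detach().clone().contiguous()       # default impls, no per-op registration
t2 = MyTensor(torch.randn(4, 4), "attr")
t.copy_(t2)                               # copies qdata since metadata matches
# t = t.to("cuda")                        # `to` also works when CUDA is available
```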

File tree

2 files changed: +164 additions, -27 deletions

test/test_utils.py

Lines changed: 42 additions & 13 deletions
@@ -7,8 +7,8 @@
 from unittest.mock import patch
 
 import torch
-from torch.utils._python_dispatch import return_and_correct_aliasing
 
+from torchao.testing.utils import skip_if_no_cuda
 from torchao.utils import TorchAOBaseTensor, torch_version_at_least
 
 
@@ -47,37 +47,39 @@ def __init__(self, data):
                 self.data = data
 
         l = torch.nn.Linear(10, 10)
+        # since we did not define `tensor_data_names` and `tensor_attribute_names` for MyTensor,
+        # the following call will error out because `detach` is defined in `TorchAOBaseTensor`
+        # but relies on `tensor_data_names` and `tensor_attribute_names` being defined to work;
+        # users can either specify `tensor_data_names` and `tensor_attribute_names` or manually
+        # implement the detach op
         with self.assertRaisesRegex(NotImplementedError, "arg_types"):
             l.weight = torch.nn.Parameter(MyTensor(l.weight))
 
+    @skip_if_no_cuda()
     def test_default_impls(self):
         """Making sure some common functions have default implementations, such as
         __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
         """
 
         class MyTensor(TorchAOBaseTensor):
             tensor_data_names = ["qdata"]
-            tensor_attribute_names = ["attr"]
+            tensor_attribute_names = ["attr", "device"]
 
-            def __new__(cls, qdata, attr):
+            def __new__(cls, qdata, attr, device=None):
                 shape = qdata.shape
-                return torch.Tensor._make_wrapper_subclass(cls, shape)  # type: ignore[attr-defined]
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-            def __init__(self, qdata, attr):
+            def __init__(self, qdata, attr, device=None):
                 self.qdata = qdata
                 self.attr = attr
 
-        implements = MyTensor.implements
-
-        @implements(torch.ops.aten.detach.default)
-        def _(func, types, args, kwargs):
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-            )
-
         l = torch.nn.Linear(1, 1)
         l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
         lp_tensor = l.weight
+        # test __tensor_flatten__ and __tensor_unflatten__
         tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
         tensor_data_dict = {
             name: getattr(lp_tensor, name) for name in tensor_data_name_dict
@@ -89,8 +91,35 @@ def _(func, types, args, kwargs):
         )
         self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata))
         self.assertEqual(lp_tensor.attr, reconstructed.attr)
+
+        # `to` / `_to_copy`
+        original_device = lp_tensor.device
+        lp_tensor = lp_tensor.to("cuda")
+        self.assertEqual(lp_tensor.device.type, "cuda")
+        lp_tensor = lp_tensor.to(original_device)
+        self.assertEqual(lp_tensor.device, original_device)
+
+        # __repr__
         print(lp_tensor)
 
+        # other ops
+        lp_tensor = lp_tensor.detach()
+        # explicitly testing aten.alias
+        lp_tensor = torch.ops.aten.alias(lp_tensor)
+        lp_tensor = lp_tensor.clone()
+        lp_tensor = lp_tensor.contiguous()
+
+        # copy_
+        another_tensor = torch.nn.Linear(1, 1).weight
+        # attribute has to be the same
+        another_lp_tensor = MyTensor(another_tensor, "attr")
+        # initially tensor values are not the same
+        self.assertNotEqual(lp_tensor.qdata[0], another_lp_tensor.qdata[0])
+        lp_tensor.copy_(another_lp_tensor)
+        self.assertEqual(lp_tensor.attr, "attr")
+        # after copy_, the tensor values should match
+        self.assertEqual(lp_tensor.qdata[0], another_lp_tensor.qdata[0])
+
 
 if __name__ == "__main__":
     unittest.main()
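
The default `copy_` registered in torchao/utils.py (next file) only copies the wrapped tensor data when the two tensors' metadata matches, and raises `ValueError` otherwise. A small illustrative sketch of that behavior, again using a test-style subclass rather than code from the commit:

```python
import torch
from torchao.utils import TorchAOBaseTensor


class MyTensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata"]
    tensor_attribute_names = ["attr"]

    def __new__(cls, qdata, attr):
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape)

    def __init__(self, qdata, attr):
        self.qdata = qdata
        self.attr = attr


a = MyTensor(torch.zeros(2, 2), "attr")
b = MyTensor(torch.ones(2, 2), "attr")
a.copy_(b)  # ok: same type, shape, and attributes, so qdata is copied in place

c = MyTensor(torch.ones(3, 3), "attr")
try:
    a.copy_(c)  # qdata shapes differ, so the metadata check fails
except ValueError as e:
    print(e)
```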

torchao/utils.py

Lines changed: 122 additions & 14 deletions
@@ -15,6 +15,7 @@
 
 import torch
 import torch.nn.utils.parametrize as parametrize
+from torch.utils._python_dispatch import return_and_correct_aliasing
 
 __all__ = [
     "benchmark_model",
@@ -409,6 +410,9 @@ def _(func, types, args, kwargs):
     if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"):
         cls._ATEN_OP_OR_TORCH_FN_TABLE = {}
 
+    if cls not in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+        cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] = {}
+
     if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
         aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
 
@@ -419,12 +423,83 @@ def decorator(func):
         def wrapper(f, types, args, kwargs):
             return func(f, types, args, kwargs)
 
-        cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
+        cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper
         return func
 
     return decorator
 
 
+def _implements_common_tensor_ops(cls):
+    implements = cls.implements
+    aten = torch.ops.aten
+
+    @implements(
+        [aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous]
+    )
+    def _(func, types, args, kwargs):
+        return return_and_correct_aliasing(
+            func,
+            args,
+            kwargs,
+            args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)),
+        )
+
+    def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
+        _tensor_shape_match = all(
+            getattr(self, t_name).shape == getattr(src, t_name).shape
+            for t_name in self.tensor_data_names
+        )
+        _attr_match = all(
+            getattr(self, a_name) == getattr(src, a_name)
+            for a_name in self.tensor_attribute_names
+        )
+        return (
+            type(self) == type(src)
+            and self.shape == src.shape
+            and _tensor_shape_match
+            and _attr_match
+        )
+
+    @implements(aten.copy_.default)
+    def _(func, types, args, kwargs):
+        self = args[0]
+        src = args[1]
+        if _same_metadata(self, src):
+            self_tensors = self.__tensor_flatten__()[0]
+            for tensor_name in self_tensors:
+                getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+            return
+        raise ValueError(
+            f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+        )
+
+    @implements(aten._to_copy.default)
+    def _(func, types, args, kwargs):
+        self = args[0]
+        if hasattr(self, "tensor_data_names") and hasattr(
+            self, "tensor_attribute_names"
+        ):
+            kwargs = self._get_to_kwargs(*args[1:], **kwargs)
+            device = kwargs.pop("device")
+            tensors = [
+                getattr(self, name).to(device) for name in self.tensor_data_names
+            ]
+            # change device
+            tensor_attributes = [
+                getattr(self, attr_name) if attr_name != "device" else device
+                for attr_name in self.tensor_attribute_names
+            ]
+            t = self.__class__(
+                *tensors,
+                *tensor_attributes,
+            )
+            return return_and_correct_aliasing(func, args, kwargs, t)
+
+        raise NotImplementedError(
+            "Subclasses must implement `aten._to_copy.default` or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
+        )
+
+
 def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
     """Use this util function for a common `__torch_function__` implementation
     that dispatches to ops/functions registered with `_implements`
@@ -436,9 +511,10 @@ class MyTensor(torch.Tensor):
     kwargs = {} if kwargs is None else kwargs
     if (
         hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE")
-        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE
+        and cls in cls._ATEN_OP_OR_TORCH_FN_TABLE
+        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE[cls]
     ):
-        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, types, args, kwargs)
 
     with torch._C.DisableTorchFunctionSubclass():
         return func(*args, **kwargs)
@@ -454,9 +530,10 @@ class MyTensor(torch.Tensor):
     """
     if (
         hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE")
-        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE
+        and cls in cls._ATEN_OP_OR_TORCH_FN_TABLE
+        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE[cls]
     ):
-        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, types, args, kwargs)
 
     arg_types = tuple(type(arg) for arg in args)
     kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
@@ -576,7 +653,28 @@ class PlainAQTTensorImpl(...):
 
     """
 
+    @classmethod
+    def __init_subclass__(cls, **kwargs):
+        if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"):
+            cls._ATEN_OP_OR_TORCH_FN_TABLE = {}
+
+        if cls not in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+            cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] = {}
+
+        # define the common ops if the tensor_data_names and tensor_attribute_names are defined
+        if hasattr(cls, "tensor_data_names") and hasattr(cls, "tensor_attribute_names"):
+            cls._implements_common_tensor_ops()
+
+        # inherit the torch function and dispatch implementations from direct parent classes
+        # e.g. for `class C(B, A)`, C.__bases__ == (B, A)
+        for parent in cls.__bases__:
+            if parent in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+                cls._ATEN_OP_OR_TORCH_FN_TABLE[cls].update(
+                    cls._ATEN_OP_OR_TORCH_FN_TABLE[parent]
+                )
+
     implements = classmethod(_implements)
+    _implements_common_tensor_ops = classmethod(_implements_common_tensor_ops)
     __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
     __torch_function__ = classmethod(_dispatch__torch_function__)
     register_layout = classmethod(_register_layout)
@@ -591,7 +689,7 @@ def __tensor_flatten__(self):
                 getattr(self, attr) for attr in self.tensor_attribute_names
             ]
         raise NotImplementedError(
-            "Subclasses must implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance"
+            "Subclasses should implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
         )
 
     @classmethod
@@ -602,13 +700,20 @@ def __tensor_unflatten__(
         return cls(*tensors, *tensor_attributes)
 
     def _apply_fn_to_data(self, fn):
-        tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
-        tensor_attributes = [
-            getattr(self, attr) for attr in self.tensor_attribute_names
-        ]
-        return self.__class__(
-            *tensors,
-            *tensor_attributes,
+        if hasattr(self, "tensor_data_names") and hasattr(
+            self, "tensor_attribute_names"
+        ):
+            tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
+            tensor_attributes = [
+                getattr(self, attr) for attr in self.tensor_attribute_names
+            ]
+            return self.__class__(
+                *tensors,
+                *tensor_attributes,
+            )
+
+        raise NotImplementedError(
+            "Subclasses should implement _apply_fn_to_data or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
         )
 
     def __repr__(self):
@@ -624,7 +729,10 @@ def __repr__(self):
                     f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}"
                 )
             return f"{self.__class__.__name__}({repr_str})"
-        raise NotImplementedError("Subclasses must implement __repr__")
+
+        raise NotImplementedError(
+            "Subclasses must implement __repr__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
+        )
 
     def get_layout(self):
         if not hasattr(self, "_layout"):
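
The new `__init_subclass__` gives every subclass its own entry in `_ATEN_OP_OR_TORCH_FN_TABLE` and merges in the tables of its direct parents, so a child class inherits registered implementations and can override an op for itself without affecting the parent. A rough sketch of that behavior (illustrative names, assuming a torchao build with this change):

```python
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TorchAOBaseTensor


class Base(TorchAOBaseTensor):
    tensor_data_names = ["qdata"]
    tensor_attribute_names = ["attr"]

    def __new__(cls, qdata, attr):
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape)

    def __init__(self, qdata, attr):
        self.qdata = qdata
        self.attr = attr


class Child(Base):
    # no registrations needed here: __init_subclass__ gives Child its own op
    # table, registers the common ops, and merges in Base's entries
    pass


@Child.implements(torch.ops.aten.clone.default)
def _(func, types, args, kwargs):
    # Child-only override; Base keeps its own clone entry because each class
    # now has a separate dictionary in `_ATEN_OP_OR_TORCH_FN_TABLE`
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )


b = Base(torch.randn(2, 2), "attr").clone()   # default clone from Base's table
c = Child(torch.randn(2, 2), "attr").clone()  # Child's override is used
```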
