diff --git a/test/test_utils.py b/test/test_utils.py index 5704e9963c..c5bbf45a96 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -105,34 +105,11 @@ def __init__(self, data): with self.assertRaisesRegex(NotImplementedError, "arg_types"): l.weight = torch.nn.Parameter(MyTensor(l.weight)) - @skip_if_no_cuda() - def test_default_impls(self): - """Making sure some common functions has default implementations, such as - __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to - """ - - class MyTensor(TorchAOBaseTensor): - tensor_data_names = ["qdata"] - tensor_attribute_names = ["attr", "device"] - - def __new__(cls, qdata, attr, device=None): - shape = qdata.shape - if device is None: - device = qdata.device - kwargs = {"device": device} - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, qdata, attr, device=None): - self.qdata = qdata - self.attr = attr - - l = torch.nn.Linear(2, 3) - l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr")) - lp_tensor = l.weight + def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy): # test __tensor_flatten__ and __tensor_unflatten__ - tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() + tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__() tensor_data_dict = { - name: getattr(lp_tensor, name) for name in tensor_data_name_dict + name: getattr(lp_tensor, name) for name in tensor_data_names } outer_size = lp_tensor.size() outer_stride = lp_tensor.stride() @@ -150,31 +127,121 @@ def __init__(self, qdata, attr, device=None): self.assertEqual(lp_tensor.device, original_device) # __repr__ - print(lp_tensor) + _ = str(lp_tensor) # other ops lp_tensor = lp_tensor.detach() # explicitly testing aten.alias lp_tensor = torch.ops.aten.alias(lp_tensor) lp_tensor = lp_tensor.clone() - # making qdata not contiguous - lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1).contiguous() - lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1) - self.assertFalse(lp_tensor.qdata.is_contiguous()) - lp_tensor = lp_tensor.contiguous() - # making sure contiguous call works - self.assertTrue(lp_tensor.qdata.is_contiguous()) + # get all tensor_data_names for both + # non optional and valid optional tensors + tensor_data_names = lp_tensor.tensor_data_names.copy() + if hasattr(lp_tensor, "optional_tensor_data_names"): + for tensor_data_name in lp_tensor.optional_tensor_data_names: + if getattr(lp_tensor, tensor_data_name) is not None: + tensor_data_names.append(tensor_data_name) + + # for each of the tensor data, we try to + # make it non-contiguous and then use + # lp_tensor.contiguous() call to make sure + # contiguous() works + for tensor_data_name in tensor_data_names: + tensor = getattr(lp_tensor, tensor_data_name) + # making qdata not contiguous + tensor = tensor.transpose(0, 1).contiguous() + tensor = tensor.transpose(0, 1) + setattr(lp_tensor, tensor_data_name, tensor) + self.assertFalse(getattr(lp_tensor, tensor_data_name).is_contiguous()) + lp_tensor = lp_tensor.contiguous() + # making sure contiguous call works + self.assertTrue(getattr(lp_tensor, tensor_data_name).is_contiguous()) # copy_ + # making sure that initially tensor values are not the same so we can test copy_ + self.assertNotEqual(lp_tensor.qdata[0][0], lp_tensor_for_copy.qdata[0][0]) + # copy_ requires the attributes to be the same + for tensor_attr_name in lp_tensor.tensor_attribute_names: + self.assertEqual( + getattr(lp_tensor, tensor_attr_name), + getattr(lp_tensor_for_copy, tensor_attr_name), + ) + lp_tensor.copy_(lp_tensor_for_copy) + # after copy_, the tensor values should match + for tensor_data_name in tensor_data_names: + self.assertTrue( + torch.equal( + getattr(lp_tensor, tensor_data_name), + getattr(lp_tensor_for_copy, tensor_data_name), + ) + ) + + @skip_if_no_cuda() + def test_default_impls(self): + """Making sure some common functions has default implementations, such as + __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to + """ + + class MyTensor(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr", "device"] + + def __new__(cls, qdata, attr, device=None): + shape = qdata.shape + if device is None: + device = qdata.device + kwargs = {"device": device} + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, qdata, attr, device=None): + self.qdata = qdata + self.attr = attr + + l = torch.nn.Linear(2, 3) + l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr")) + lp_tensor = l.weight + another_tensor = torch.nn.Linear(2, 3).weight # attribute has to be the same - another_lp_tensor = MyTensor(another_tensor, "attr") - # initially tensor values are not the same - self.assertNotEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0]) - lp_tensor.copy_(another_lp_tensor) - self.assertEqual(lp_tensor.attr, "attr") - # after copy_, the tensor values should match - self.assertEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0]) + lp_tensor_for_copy = MyTensor(another_tensor, "attr") + self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) + + @skip_if_no_cuda() + def test_default_impls_with_optional_data(self): + class MyTensorWithOptionalData(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + optional_tensor_data_names = ["zero_point"] + tensor_attribute_names = ["attr", "device"] + + def __new__(cls, qdata, zero_point=None, attr=1.0, device=None): + shape = qdata.shape + if device is None: + device = qdata.device + kwargs = {"device": device} + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, qdata, zero_point=None, attr=1.0, device=None): + self.qdata = qdata + self.zero_point = zero_point + self.attr = attr + + # test both the optional Tensor is None + # and not None + l = torch.nn.Linear(2, 3) + lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr") + l = torch.nn.Linear(2, 3) + lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr") + self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) + + l = torch.nn.Linear(2, 3) + lp_tensor = MyTensorWithOptionalData( + l.weight, torch.zeros_like(l.weight), "attr" + ) + l = torch.nn.Linear(2, 3) + lp_tensor_for_copy = MyTensorWithOptionalData( + l.weight, torch.zeros_like(l.weight), "attr" + ) + self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) if __name__ == "__main__": diff --git a/torchao/utils.py b/torchao/utils.py index 4a5bca699b..40ca9e3702 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -510,6 +510,16 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool: getattr(self, t_name).shape == getattr(src, t_name).shape for t_name in self.tensor_data_names ) + _optional_tensor_shape_match = True + if hasattr(self, "optional_tensor_data_names"): + # either both are None or both are not Tensors and the shape match + _optional_tensor_shape_match = all( + getattr(self, t_name).shape == getattr(src, t_name).shape + if getattr(self, t_name) is not None + else getattr(src, t_name) is None + for t_name in self.optional_tensor_data_names + ) + _attr_match = all( getattr(self, a_name) == getattr(src, a_name) for a_name in self.tensor_attribute_names @@ -518,6 +528,7 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool: type(self) == type(src) and self.shape == src.shape and _tensor_shape_match + and _optional_tensor_shape_match and _attr_match ) @@ -545,6 +556,14 @@ def _(func, types, args, kwargs): tensors = [ getattr(self, name).to(device) for name in self.tensor_data_names ] + if hasattr(self, "optional_tensor_data_names"): + for tensor_data_name in self.optional_tensor_data_names: + maybe_tensor = getattr(self, tensor_data_name) + if maybe_tensor is not None: + tensors.append(maybe_tensor.to(device)) + else: + tensors.append(None) + # change device tensor_attributes = [ getattr(self, attr_name) if attr_name != "device" else device @@ -712,6 +731,52 @@ class PlainAQTTensorImpl(...): tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) + class variables to define to simplify implmentation of tensor subclasses: + `tensor_data_names` (List[str]): list of names of all requires tensor_data, order should match + the `__init__` list of tensor subclass + `optional_tensor_data_names` (List[str]): it's optional to define this field to have the additional boilerplate functions been implemented for you, but this will be need if there are some optional Tensor attributes, when defined, this will be a list of names of Tensors that can be optional + `tensor_attribute_names` (List[str]): list of names of non-Tensor attributes, + order should match the `__init__` list of tensor subclass, following all the `tensor_data_names` arguments and `optional_tensor_data_names` + + If `tensor_data_names` and `tensor_attribute_names` are defined, there are some additional + functions that will be added, this includes: + `__tensor_flatten__`: flattens a subclassed tensor instance, returns a tuple, first element is tensor data names for valid tensor data, + second element is a list of non-Tensor attributes + `__tensor_unflatten__`: takes a tensor_data_dict (a map from tensor name to Tensor), and list of non-tensor attributes, returns a new instance of the subclassed tensor + `_apply_fn_to_data`: takes a function (Tensor -> Tensor), applies function to all tensor data and + recreate a new subclassed Tensor with the transformed tensor data + `__repr__`: the string representation of the subclassed tensor instance + torch ops: torch.Tensor.contiguous + aten ops: aten.detach.default, aten.clone.default, aten.alias,default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to) + + Example: + class MyTensor(torch.Tensor): + tensor_data_names = ["a", "b"] + optional_tensor_data_names = ["c", "d"] + tensor_attribute_names = ["e", "f"] + + def __new__( + cls, + a: Tensor, + b: Tensor, + c: Optional[Tensor], + d: Optional[Tensor], + e: int, + f: str + ): + pass + + def __init__( + self, + a: Tensor, + b: Tensor, + c: Optional[Tensor], + d: Optional[Tensor], + e: int, + f: str + ): + pass + """ @classmethod @@ -746,7 +811,14 @@ def __tensor_flatten__(self): if hasattr(self, "tensor_data_names") and hasattr( self, "tensor_attribute_names" ): - return self.tensor_data_names, [ + tensor_data_names = self.tensor_data_names.copy() + if hasattr(self, "optional_tensor_data_names"): + for tensor_data_name in self.optional_tensor_data_names: + maybe_tensor = getattr(self, tensor_data_name) + if maybe_tensor is not None: + tensor_data_names.append(tensor_data_name) + + return tensor_data_names, [ getattr(self, attr) for attr in self.tensor_attribute_names ] raise NotImplementedError( @@ -758,6 +830,12 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): tensors = [tensor_data_dict[name] for name in cls.tensor_data_names] + if hasattr(cls, "optional_tensor_data_names"): + for tensor_data_name in cls.optional_tensor_data_names: + if tensor_data_name in tensor_data_dict: + tensors.append(tensor_data_dict[tensor_data_name]) + else: + tensors.append(None) return cls(*tensors, *tensor_attributes) def _apply_fn_to_data(self, fn): @@ -765,6 +843,14 @@ def _apply_fn_to_data(self, fn): self, "tensor_attribute_names" ): tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names] + if hasattr(self, "optional_tensor_data_names"): + for tensor_data_name in self.optional_tensor_data_names: + maybe_tensor = getattr(self, tensor_data_name) + if maybe_tensor is not None: + tensors.append(fn(maybe_tensor)) + else: + tensors.append(None) + tensor_attributes = [ getattr(self, attr) for attr in self.tensor_attribute_names ] @@ -785,6 +871,12 @@ def __repr__(self): repr_str += f"{self.tensor_data_names[0]}={getattr(self, self.tensor_data_names[0])}" for tensor_data_name in self.tensor_data_names[1:]: repr_str += f", {tensor_data_name}={getattr(self, tensor_data_name)}" + if hasattr(self, "optional_tensor_data_names"): + for tensor_data_name in self.optional_tensor_data_names: + repr_str += ( + f", {tensor_data_name}={getattr(self, tensor_data_name)}" + ) + for tensor_attribute_name in self.tensor_attribute_names: repr_str += ( f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}"