diff --git a/test/test_utils.py b/test/test_utils.py index b46d600053..206b935937 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -392,6 +392,92 @@ def fake_linear(func, types, args, kwargs): counter["calls"], 2, "Expected fake_linear to be called via aten.t.default" ) + def test_subclassing(self): + class Parent(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + Parent._ATEN_OP_TABLE[Parent]["op_parent"] = "parent_impl" + Parent._TORCH_FN_TABLE[Parent]["fn_parent"] = "parent_fn_impl" + + class Child(Parent): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + # ensure child has copied parent ops + self.assertEqual(Child._ATEN_OP_TABLE[Child]["op_parent"], "parent_impl") + self.assertEqual(Child._TORCH_FN_TABLE[Child]["fn_parent"], "parent_fn_impl") + + # ensure the top-level dicts are distinct (not inherited) + self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE) + self.assertIsNot(Parent._TORCH_FN_TABLE, Child._TORCH_FN_TABLE) + + # change the parent's op after subclass creation — should not leak + Parent._ATEN_OP_TABLE[Parent]["new_op"] = "added_later" + self.assertNotIn("new_op", Child._ATEN_OP_TABLE[Child]) + + def test_subclassing_with_real_op(self): + counter = {"calls": 0} + + class Parent(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + def __new__(cls, qdata, attr): + r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape) + r.qdata = qdata + r.attr = attr + return r + + def __init__(self, qdata, attr): + pass + + # Real op implementation + @Parent.implements([torch.ops.aten.cat.default]) + def _cat_op(func, types, args, kwargs): + counter["calls"] += 1 + return func(*args, **kwargs) + + class Child(Parent): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + # Table checks + self.assertIn(torch.ops.aten.cat.default, Parent._ATEN_OP_TABLE[Parent]) + self.assertIn(torch.ops.aten.cat.default, Child._ATEN_OP_TABLE[Child]) + + # Ensure child table is distinct + self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE) + + # calling the op through the child tensor + t1 = torch.randn(2, 3) + t2 = torch.randn(2, 3) + child_tensor1 = Child(t1, "a") + child_tensor2 = Child(t2, "b") + + torch.ops.aten.cat.default([child_tensor1, child_tensor2], 0) + + self.assertEqual(counter["calls"], 1) + + def test_multiple_inheritance(self): + class A(TorchAOBaseTensor): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + class B(TorchAOBaseTensor): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + A._ATEN_OP_TABLE[A]["shared"] = "from_a" + B._ATEN_OP_TABLE[B]["shared"] = "from_b" + + class C(A, B): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + # C(A, B) should inherit from A then B, so B wins + self.assertEqual(C._ATEN_OP_TABLE[C]["shared"], "from_b") + if __name__ == "__main__": unittest.main()