From b1ca10449f528aa7a0f74b9d3b40ba5de3bd5c3b Mon Sep 17 00:00:00 2001 From: Krishn Parasar Date: Sat, 11 Oct 2025 03:18:40 +0530 Subject: [PATCH 1/2] testing subclassing --- test/test_utils.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index b46d600053..60a7a390b9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -392,6 +392,49 @@ def fake_linear(func, types, args, kwargs): counter["calls"], 2, "Expected fake_linear to be called via aten.t.default" ) + def test_subclassing(self): + class Parent(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + Parent._ATEN_OP_TABLE[Parent]["op_parent"] = "parent_impl" + Parent._TORCH_FN_TABLE[Parent]["fn_parent"] = "parent_fn_impl" + + class Child(Parent): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + # ensure child has copied parent ops + self.assertEqual(Child._ATEN_OP_TABLE[Child]["op_parent"], "parent_impl") + self.assertEqual(Child._TORCH_FN_TABLE[Child]["fn_parent"], "parent_fn_impl") + + # ensure the top-level dicts are distinct (not inherited) + self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE) + self.assertIsNot(Parent._TORCH_FN_TABLE, Child._TORCH_FN_TABLE) + + # change the parent's op after subclass creation — should not leak + Parent._ATEN_OP_TABLE[Parent]["new_op"] = "added_later" + self.assertNotIn("new_op", Child._ATEN_OP_TABLE[Child]) + + def test_multiple_inheritance(self): + class A(TorchAOBaseTensor): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + class B(TorchAOBaseTensor): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + A._ATEN_OP_TABLE[A]["shared"] = "from_a" + B._ATEN_OP_TABLE[B]["shared"] = "from_b" + + class C(A, B): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + # C(A, B) should inherit from A then B, so B wins + self.assertEqual(C._ATEN_OP_TABLE[C]["shared"], "from_b") + if __name__ == "__main__": unittest.main() From 6f490d2b5b262dd15769cea7a00b951d4d16cd7c Mon Sep 17 00:00:00 2001 From: Krishn Parasar Date: Sat, 11 Oct 2025 20:52:11 +0530 Subject: [PATCH 2/2] Adding test with real op implementation the new test with real op fails for the same line self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE). Also when it is called using the child tensor, it works. For the inheritance test case, it fails at self.assertEqual(C._ATEN_OP_TABLE[C]["shared"], "from_b") AssertionError: 'from_a' != 'from_b' --- test/test_utils.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index 60a7a390b9..206b935937 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -416,6 +416,49 @@ class Child(Parent): Parent._ATEN_OP_TABLE[Parent]["new_op"] = "added_later" self.assertNotIn("new_op", Child._ATEN_OP_TABLE[Child]) + def test_subclassing_with_real_op(self): + counter = {"calls": 0} + + class Parent(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + def __new__(cls, qdata, attr): + r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape) + r.qdata = qdata + r.attr = attr + return r + + def __init__(self, qdata, attr): + pass + + # Real op implementation + @Parent.implements([torch.ops.aten.cat.default]) + def _cat_op(func, types, args, kwargs): + counter["calls"] += 1 + return func(*args, **kwargs) + + class Child(Parent): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + # Table checks + self.assertIn(torch.ops.aten.cat.default, Parent._ATEN_OP_TABLE[Parent]) + self.assertIn(torch.ops.aten.cat.default, Child._ATEN_OP_TABLE[Child]) + + # Ensure child table is distinct + self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE) + + # calling the op through the child tensor + t1 = torch.randn(2, 3) + t2 = torch.randn(2, 3) + child_tensor1 = Child(t1, "a") + child_tensor2 = Child(t2, "b") + + torch.ops.aten.cat.default([child_tensor1, child_tensor2], 0) + + self.assertEqual(counter["calls"], 1) + def test_multiple_inheritance(self): class A(TorchAOBaseTensor): tensor_data_names = ["a"]