|
67 | 67 | from executorch.exir.tensor import TensorSpec |
68 | 68 | from executorch.exir.tests.common import register_additional_test_aten_ops |
69 | 69 | from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic |
70 | | -from executorch.exir.tests.models import MLP, Mul |
| 70 | +from executorch.exir.tests.models import FeedForwardBlock, MLP, Mul |
71 | 71 | from functorch.experimental import control_flow |
72 | 72 |
|
73 | 73 | from torch import nn |
@@ -870,6 +870,69 @@ def test_debug_handle_generator_pass(self) -> None: |
870 | 870 | if node.op != "placeholder" and node.op != "output": |
871 | 871 | self.assertIn("debug_handle", node.meta) |
872 | 872 |
|
| 873 | + def test_debug_handle_generator_pass_generate_same_debug_handle_on_ops_sharing_same_source( |
| 874 | + self, |
| 875 | + ) -> None: |
| 876 | + eager_model = FeedForwardBlock(256, 512) |
| 877 | + inputs = (torch.randn(12, 256),) |
| 878 | + |
| 879 | + graph_module = ( |
| 880 | + to_edge(export(eager_model, inputs, strict=True)) |
| 881 | + .exported_program() |
| 882 | + .graph_module |
| 883 | + ) |
| 884 | + |
| 885 | + same_source_nodes = { |
| 886 | + "aten_native_layer_norm_default": ( |
| 887 | + "aten_native_layer_norm_default", |
| 888 | + "getitem", |
| 889 | + ), |
| 890 | + "getitem": ("aten_native_layer_norm_default", "getitem"), |
| 891 | + "aten_permute_copy_default": ( |
| 892 | + "aten_permute_copy_default", |
| 893 | + "aten_addmm_default", |
| 894 | + ), |
| 895 | + "aten_addmm_default": ("aten_permute_copy_default", "aten_addmm_default"), |
| 896 | + "aten_native_dropout_default": ("aten_native_dropout_default", "getitem_1"), |
| 897 | + "getitem_1": ("aten_native_dropout_default", "getitem_1"), |
| 898 | + "aten_relu_default": ("aten_relu_default",), |
| 899 | + "aten_permute_copy_default_1": ( |
| 900 | + "aten_permute_copy_default_1", |
| 901 | + "aten_addmm_default_1", |
| 902 | + ), |
| 903 | + "aten_addmm_default_1": ( |
| 904 | + "aten_permute_copy_default_1", |
| 905 | + "aten_addmm_default_1", |
| 906 | + ), |
| 907 | + "aten_native_dropout_default_1": ( |
| 908 | + "aten_native_dropout_default_1", |
| 909 | + "getitem_2", |
| 910 | + ), |
| 911 | + "getitem_2": ("aten_native_dropout_default_1", "getitem_2"), |
| 912 | + } |
| 913 | + |
| 914 | + node_name_to_debug_handle = {} |
| 915 | + |
| 916 | + # Node having same source should have same debug handle |
| 917 | + for node in graph_module.graph.nodes: |
| 918 | + if node.op != "placeholder" and node.op != "output": |
| 919 | + self.assertIn("debug_handle", node.meta) |
| 920 | + if node.name in node_name_to_debug_handle: |
| 921 | + for node_name_with_same_debug_handle in same_source_nodes[ |
| 922 | + node.name |
| 923 | + ]: |
| 924 | + self.assertEqual( |
| 925 | + node_name_to_debug_handle[node_name_with_same_debug_handle], |
| 926 | + node.meta["debug_handle"], |
| 927 | + ) |
| 928 | + else: |
| 929 | + for node_name_with_same_debug_handle in same_source_nodes[ |
| 930 | + node.name |
| 931 | + ]: |
| 932 | + node_name_to_debug_handle[node_name_with_same_debug_handle] = ( |
| 933 | + node.meta["debug_handle"] |
| 934 | + ) |
| 935 | + |
873 | 936 | def test_generate_missing_debug_handles(self) -> None: |
874 | 937 | eager_model = MLP(2, output_size=4) |
875 | 938 | inputs = eager_model.get_random_inputs() |
|
0 commit comments