Skip to content

Commit b51110f

Browse files
committed
Update on "[et] generate debug handle before opeartor decomposition"
This diff update the debug handle generation, from each node in the edge program having a individual debug handle, to all nodes having a same ancestor in export graph sharing a same debug handle, which update the start point of tracing our node transformation from edge graph to exported graph. Differential Revision: [D76860368](https://our.internmc.facebook.com/intern/diff/D76860368/) [ghstack-poisoned]
2 parents 5fad521 + 1b7d708 commit b51110f

File tree

1 file changed

+64
-1
lines changed

1 file changed

+64
-1
lines changed

exir/tests/test_passes.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
from executorch.exir.tensor import TensorSpec
6868
from executorch.exir.tests.common import register_additional_test_aten_ops
6969
from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
70-
from executorch.exir.tests.models import MLP, Mul
70+
from executorch.exir.tests.models import FeedForwardBlock, MLP, Mul
7171
from functorch.experimental import control_flow
7272

7373
from torch import nn
@@ -870,6 +870,69 @@ def test_debug_handle_generator_pass(self) -> None:
870870
if node.op != "placeholder" and node.op != "output":
871871
self.assertIn("debug_handle", node.meta)
872872

873+
def test_debug_handle_generator_pass_generate_same_debug_handle_on_ops_sharing_same_source(
874+
self,
875+
) -> None:
876+
eager_model = FeedForwardBlock(256, 512)
877+
inputs = (torch.randn(12, 256),)
878+
879+
graph_module = (
880+
to_edge(export(eager_model, inputs, strict=True))
881+
.exported_program()
882+
.graph_module
883+
)
884+
885+
same_source_nodes = {
886+
"aten_native_layer_norm_default": (
887+
"aten_native_layer_norm_default",
888+
"getitem",
889+
),
890+
"getitem": ("aten_native_layer_norm_default", "getitem"),
891+
"aten_permute_copy_default": (
892+
"aten_permute_copy_default",
893+
"aten_addmm_default",
894+
),
895+
"aten_addmm_default": ("aten_permute_copy_default", "aten_addmm_default"),
896+
"aten_native_dropout_default": ("aten_native_dropout_default", "getitem_1"),
897+
"getitem_1": ("aten_native_dropout_default", "getitem_1"),
898+
"aten_relu_default": ("aten_relu_default",),
899+
"aten_permute_copy_default_1": (
900+
"aten_permute_copy_default_1",
901+
"aten_addmm_default_1",
902+
),
903+
"aten_addmm_default_1": (
904+
"aten_permute_copy_default_1",
905+
"aten_addmm_default_1",
906+
),
907+
"aten_native_dropout_default_1": (
908+
"aten_native_dropout_default_1",
909+
"getitem_2",
910+
),
911+
"getitem_2": ("aten_native_dropout_default_1", "getitem_2"),
912+
}
913+
914+
node_name_to_debug_handle = {}
915+
916+
# Node having same source should have same debug handle
917+
for node in graph_module.graph.nodes:
918+
if node.op != "placeholder" and node.op != "output":
919+
self.assertIn("debug_handle", node.meta)
920+
if node.name in node_name_to_debug_handle:
921+
for node_name_with_same_debug_handle in same_source_nodes[
922+
node.name
923+
]:
924+
self.assertEqual(
925+
node_name_to_debug_handle[node_name_with_same_debug_handle],
926+
node.meta["debug_handle"],
927+
)
928+
else:
929+
for node_name_with_same_debug_handle in same_source_nodes[
930+
node.name
931+
]:
932+
node_name_to_debug_handle[node_name_with_same_debug_handle] = (
933+
node.meta["debug_handle"]
934+
)
935+
873936
def test_generate_missing_debug_handles(self) -> None:
874937
eager_model = MLP(2, output_size=4)
875938
inputs = eager_model.get_random_inputs()

0 commit comments

Comments
 (0)