Skip to content

Commit 69074bc

Browse files
kimishpatel, facebook-github-bot
authored and committed
Optionally disable debug handle validation (#14182)
Summary: When an ATen graph contains symbolic shape nodes and built-in ops such as gt/lt, the 'from_node' information is often lost from node.meta during re-export of that graph. As a result, we lose the connection between edge IR nodes and ATen nodes for such ops. By default we validate that every edge IR node has a corresponding node in the ATen IR, and when that validation fails the numeric debugger falls back to the edge IR as the reference graph. This flag allows one to override that behavior and make a best-effort comparison. Reviewed By: Gasoonjia Differential Revision: D81784685
1 parent 8358516 commit 69074bc

File tree

3 files changed

+151
-5
lines changed

3 files changed

+151
-5
lines changed

devtools/inspector/_inspector.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,7 @@ def _consume_etrecord(self) -> None:
11691169

11701170
def _get_aot_intermediate_outputs_and_op_names(
11711171
self,
1172+
disable_debug_handle_valdiation: bool = False,
11721173
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
11731174
"""
11741175
Capture intermediate outputs only if _representative_inputs are provided
@@ -1184,6 +1185,7 @@ def _get_aot_intermediate_outputs_and_op_names(
11841185
self._etrecord.exported_program,
11851186
self._etrecord.export_graph_id,
11861187
self._etrecord.edge_dialect_program,
1188+
disable_debug_handle_valdiation,
11871189
):
11881190
export_program = self._etrecord.exported_program
11891191
else:
@@ -1404,7 +1406,9 @@ def get_exported_program(
14041406
else self._etrecord.graph_map.get(graph)
14051407
)
14061408

1407-
def calculate_numeric_gap(self, distance: str = "MSE"):
1409+
def calculate_numeric_gap(
1410+
self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False
1411+
):
14081412
"""
14091413
Compares logged intermediate outputs from the exported graph (in ETRecord)
14101414
with runtime outputs (in ETDump) using a user-specific numerical comparator.
@@ -1416,12 +1420,19 @@ def calculate_numeric_gap(self, distance: str = "MSE"):
14161420
14171421
Args:
14181422
distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".
1423+
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes, and inbuilt ops like gt/lt etc.,
1424+
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we lose connection
1425+
between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding
1426+
node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This
1427+
flag allows one to override such behavior and make best effort comparison.
14191428
14201429
Returns:
14211430
pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps.
14221431
"""
14231432
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
1424-
self._get_aot_intermediate_outputs_and_op_names()
1433+
self._get_aot_intermediate_outputs_and_op_names(
1434+
disable_debug_handle_valdiation
1435+
)
14251436
)
14261437
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
14271438
raise ValueError(
@@ -1451,6 +1462,15 @@ def calculate_numeric_gap(self, distance: str = "MSE"):
14511462
) in mapping.items():
14521463
if aot_intermediate_output is None or runtime_intermediate_output is None:
14531464
continue
1465+
# If aot outputs length is > 1 then comparison fails since we don't really have
1466+
# any instances where runtime intermediate output is a tuple or list
1467+
# This does not happen when edge dialect program is reference for comparison
1468+
# but happens in aten graph where ops like unbind remain undecomposed
1469+
if (
1470+
isinstance(aot_intermediate_output, Sequence)
1471+
and len(aot_intermediate_output) > 1
1472+
):
1473+
continue
14541474
rows.append(
14551475
{
14561476
"aot_ops": find_op_names(

devtools/inspector/_inspector_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
965965
# Ensure both sequences have the same length
966966
if len(a) != len(b):
967967
raise ValueError(
968-
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
968+
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}."
969969
)
970970

971971
# Compare each element in the sequences and return the list of results
@@ -990,6 +990,9 @@ def get_ancestor_node_identifiers(node: Node) -> List[str]:
990990
Returns: the identifiers of all its ancestor nodes
991991
"""
992992

993+
if FROM_NODE_KEY not in node.meta:
994+
return []
995+
993996
node_source = node.meta[FROM_NODE_KEY]
994997
node_source = node_source[-1]
995998
ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"]
@@ -1111,6 +1114,7 @@ def propagate_back_debug_handle(
11111114
exported_program: ExportedProgram,
11121115
exported_program_graph_id: int,
11131116
edge_dialect_program: ExportedProgram,
1117+
disable_debug_handle_valdiation: bool = False,
11141118
) -> bool:
11151119
"""
11161120
Propagate debug handle from edge dialect program back to the exported program while maintain the correctness
@@ -1124,6 +1128,10 @@ def propagate_back_debug_handle(
11241128
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
11251129
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
11261130
1131+
disable_debug_handle_validation is used to avoid _verify_graph_match() in case of debug handle mismatch.
1132+
This can happen when we are comparing against aten graph in which case not all debug handles are matched
1133+
in aten graph. Example of this is when symbolic shape nodes are re-exported.
1134+
11271135
Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
11281136
"""
11291137
# 1. Extract mapping from ancestor node identifiers to debug handles
@@ -1137,7 +1145,9 @@ def propagate_back_debug_handle(
11371145
)
11381146

11391147
# 3. Verify if every debug handle in edge dialect program has a corresponding node
1140-
if not _verify_graph_match(edge_dialect_program, matched_debug_handles):
1148+
if not disable_debug_handle_valdiation and not _verify_graph_match(
1149+
edge_dialect_program, matched_debug_handles
1150+
):
11411151
return False
11421152

11431153
# 4. Apply debug handles to the exported program

devtools/inspector/tests/inspector_test.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def test_calculate_numeric_gap(self):
681681
aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
682682
runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
683683

684-
inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: (
684+
inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: (
685685
aot_intermediate_outputs,
686686
aot_debug_handle_to_op_name,
687687
)
@@ -838,6 +838,122 @@ def _gen_random_runtime_output(
838838
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
839839
return [torch.randn(RAW_DATA_SIZE)]
840840

841+
def test_disable_debug_handle_validation_with_symbolic_shapes(self):
842+
"""
843+
Test that demonstrates the issue with symbolic shape related nodes losing from_node info
844+
during dynamic shape based export, and shows how disable_debug_handle_valdiation parameter
845+
in propagate_back_debug_handle allows validation to be bypassed.
846+
"""
847+
from executorch.devtools.inspector._inspector_utils import (
848+
propagate_back_debug_handle,
849+
)
850+
851+
class SymbolicShapeModel(torch.nn.Module):
852+
"""Model that will have symbolic shape related operations after export."""
853+
854+
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
855+
# This will create symbolic shape nodes during dynamic export
856+
batch_size = x.shape[0]
857+
x = x + torch.rand((batch_size, 1))
858+
# Masking operation that creates gt/lt nodes
859+
valid_mask = mask > 0.5
860+
x = torch.where(valid_mask, x, torch.zeros_like(x))
861+
return x
862+
863+
# Create model and dynamic inputs
864+
model = SymbolicShapeModel()
865+
batch_size = 2
866+
seq_len = 4
867+
x = torch.randn(batch_size, seq_len)
868+
mask = torch.rand(batch_size, seq_len)
869+
example_inputs = (x, mask)
870+
871+
# Export with dynamic shapes to create symbolic shape related nodes
872+
dynamic_shapes = {
873+
"x": {0: torch.export.Dim("batch_size", min=1, max=10)},
874+
"mask": {0: torch.export.Dim("batch_size", min=1, max=10)},
875+
}
876+
877+
exported_program = torch.export.export(
878+
model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
879+
)
880+
881+
"""
882+
In this case the original aten graph has a sym_size_int_2 node but when we look at
883+
nodes metadata in edge_program_manager, its sym_size node's from_node says
884+
sym_size_int_3 which is not in the original aten graph.
885+
"""
886+
# Create edge program - this is where from_node info can be lost for symbolic shape nodes
887+
edge_program_manager: EdgeProgramManager = to_edge(exported_program)
888+
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
889+
et_program_manager: ExecutorchProgramManager = (
890+
edge_program_manager.to_executorch()
891+
)
892+
893+
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
894+
etrecord_path = tmp_file.name
895+
896+
# Generate ETRecord with the exported program (aten graph)
897+
generate_etrecord(
898+
etrecord_path,
899+
edge_program_manager_copy,
900+
et_program_manager,
901+
exported_program=exported_program,
902+
)
903+
904+
# Create Inspector and get etrecord
905+
with patch.object(
906+
_inspector, "gen_etdump_object", return_value=None
907+
), patch.object(EventBlock, "_gen_from_etdump"):
908+
inspector_instance = Inspector(
909+
etdump_path=ETDUMP_PATH,
910+
etrecord=etrecord_path,
911+
)
912+
913+
# Extract the necessary values from the inspector's etrecord
914+
exported_program_from_etrecord = (
915+
inspector_instance._etrecord.exported_program
916+
)
917+
export_graph_id = inspector_instance._etrecord.export_graph_id
918+
edge_dialect_program = inspector_instance._etrecord.edge_dialect_program
919+
920+
# Ensure we have all the necessary components
921+
self.assertIsNotNone(exported_program_from_etrecord)
922+
self.assertIsNotNone(export_graph_id)
923+
self.assertIsNotNone(edge_dialect_program)
924+
925+
# Test propagate_back_debug_handle with validation enabled (should fail or return False)
926+
# This demonstrates the issue with symbolic shape nodes losing from_node info
927+
validation_enabled_result = propagate_back_debug_handle(
928+
exported_program_from_etrecord,
929+
export_graph_id,
930+
edge_dialect_program,
931+
disable_debug_handle_valdiation=False,
932+
)
933+
934+
# With validation enabled, it should return False when from_node info is lost
935+
self.assertFalse(
936+
validation_enabled_result,
937+
"propagate_back_debug_handle should return False when validation is enabled "
938+
"and symbolic shape nodes lose from_node info",
939+
)
940+
941+
# Test propagate_back_debug_handle with validation disabled (should succeed)
942+
# This shows how the disable_debug_handle_valdiation flag allows the function to work
943+
validation_disabled_result = propagate_back_debug_handle(
944+
exported_program_from_etrecord,
945+
export_graph_id,
946+
edge_dialect_program,
947+
disable_debug_handle_valdiation=True,
948+
)
949+
950+
# With validation disabled, it should return True even when from_node info is lost
951+
self.assertTrue(
952+
validation_disabled_result,
953+
"propagate_back_debug_handle should return True when validation is disabled, "
954+
"allowing best effort comparison even when from_node info is lost",
955+
)
956+
841957
def _gen_random_events(self) -> List[Event]:
842958
events = []
843959
for i in range(2):

0 commit comments

Comments
 (0)