Skip to content

Commit 378c700

Browse files
authored
Optionally disable debug handle validation
Differential Revision: D81784685 Pull Request resolved: #14182
1 parent 91fd3b3 commit 378c700

File tree

3 files changed

+151
-5
lines changed

3 files changed

+151
-5
lines changed

devtools/inspector/_inspector.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,7 @@ def _consume_etrecord(self) -> None:
11691169

11701170
def _get_aot_intermediate_outputs_and_op_names(
11711171
self,
1172+
disable_debug_handle_valdiation: bool = False,
11721173
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
11731174
"""
11741175
Capture intermediate outputs only if _representative_inputs are provided
@@ -1184,6 +1185,7 @@ def _get_aot_intermediate_outputs_and_op_names(
11841185
self._etrecord.exported_program,
11851186
self._etrecord.export_graph_id,
11861187
self._etrecord.edge_dialect_program,
1188+
disable_debug_handle_valdiation,
11871189
):
11881190
export_program = self._etrecord.exported_program
11891191
else:
@@ -1404,7 +1406,9 @@ def get_exported_program(
14041406
else self._etrecord.graph_map.get(graph)
14051407
)
14061408

1407-
def calculate_numeric_gap(self, distance: str = "MSE"):
1409+
def calculate_numeric_gap(
1410+
self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False
1411+
):
14081412
"""
14091413
Compares logged intermediate outputs from the exported graph (in ETRecord)
14101414
with runtime outputs (in ETDump) using a user-specified numerical comparator.
@@ -1416,12 +1420,19 @@ def calculate_numeric_gap(self, distance: str = "MSE"):
14161420
14171421
Args:
14181422
distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".
1423+
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes, and inbuilt ops like gt/lt etc.,
1424+
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we lose connection
1425+
between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding
1426+
node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This
1427+
flag allows one to override such behavior and make best effort comparison.
14191428
14201429
Returns:
14211430
pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps.
14221431
"""
14231432
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
1424-
self._get_aot_intermediate_outputs_and_op_names()
1433+
self._get_aot_intermediate_outputs_and_op_names(
1434+
disable_debug_handle_valdiation
1435+
)
14251436
)
14261437
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
14271438
raise ValueError(
@@ -1451,6 +1462,15 @@ def calculate_numeric_gap(self, distance: str = "MSE"):
14511462
) in mapping.items():
14521463
if aot_intermediate_output is None or runtime_intermediate_output is None:
14531464
continue
1465+
# If aot outputs length is > 1 then comparison fails since we dont really have
1466+
# any instances where runtime intermediate output is a tuple or list
1467+
# This does not happen when edge dialect program is reference for comparison
1468+
# but happens in aten graph where ops like unbind remain undecomposed
1469+
if (
1470+
isinstance(aot_intermediate_output, Sequence)
1471+
and len(aot_intermediate_output) > 1
1472+
):
1473+
continue
14541474
rows.append(
14551475
{
14561476
"aot_ops": find_op_names(

devtools/inspector/_inspector_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
965965
# Ensure both sequences have the same length
966966
if len(a) != len(b):
967967
raise ValueError(
968-
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
968+
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}."
969969
)
970970

971971
# Compare each element in the sequences and return the list of results
@@ -990,6 +990,9 @@ def get_ancestor_node_identifiers(node: Node) -> List[str]:
990990
Returns: the identifiers of all its ancestor nodes
991991
"""
992992

993+
if FROM_NODE_KEY not in node.meta:
994+
return []
995+
993996
node_source = node.meta[FROM_NODE_KEY]
994997
node_source = node_source[-1]
995998
ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"]
@@ -1111,6 +1114,7 @@ def propagate_back_debug_handle(
11111114
exported_program: ExportedProgram,
11121115
exported_program_graph_id: int,
11131116
edge_dialect_program: ExportedProgram,
1117+
disable_debug_handle_valdiation: bool = False,
11141118
) -> bool:
11151119
"""
11161120
Propagate debug handle from edge dialect program back to the exported program while maintain the correctness
@@ -1124,6 +1128,10 @@ def propagate_back_debug_handle(
11241128
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
11251129
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
11261130
1131+
disable_debug_handle_validation is used to avoid _verify_graph_match() in case of debug handle mismatch.
1132+
This can happen when we are comparing against aten graph in which case not all debug handles are matched
1133+
in aten graph. Example of this is when symbolic shape nodes are re-exported.
1134+
11271135
Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
11281136
"""
11291137
# 1. Extract mapping from ancestor node identifiers to debug handles
@@ -1137,7 +1145,9 @@ def propagate_back_debug_handle(
11371145
)
11381146

11391147
# 3. Verify if every debug handle in edge dialect program has a corresponding node
1140-
if not _verify_graph_match(edge_dialect_program, matched_debug_handles):
1148+
if not disable_debug_handle_valdiation and not _verify_graph_match(
1149+
edge_dialect_program, matched_debug_handles
1150+
):
11411151
return False
11421152

11431153
# 4. Apply debug handles to the exported program

devtools/inspector/tests/inspector_test.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def test_calculate_numeric_gap(self):
681681
aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
682682
runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
683683

684-
inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: (
684+
inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: (
685685
aot_intermediate_outputs,
686686
aot_debug_handle_to_op_name,
687687
)
@@ -838,6 +838,122 @@ def _gen_random_runtime_output(
838838
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
839839
return [torch.randn(RAW_DATA_SIZE)]
840840

841+
def test_disable_debug_handle_validation_with_symbolic_shapes(self):
842+
"""
843+
Test that demonstrates the issue with symbolic shape related nodes losing from_node info
844+
during dynamic shape based export, and shows how disable_debug_handle_valdiation parameter
845+
in propagate_back_debug_handle allows validation to be bypassed.
846+
"""
847+
from executorch.devtools.inspector._inspector_utils import (
848+
propagate_back_debug_handle,
849+
)
850+
851+
class SymbolicShapeModel(torch.nn.Module):
852+
"""Model that will have symbolic shape related operations after export."""
853+
854+
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
855+
# This will create symbolic shape nodes during dynamic export
856+
batch_size = x.shape[0]
857+
x = x + torch.rand((batch_size, 1))
858+
# Masking operation that creates gt/lt nodes
859+
valid_mask = mask > 0.5
860+
x = torch.where(valid_mask, x, torch.zeros_like(x))
861+
return x
862+
863+
# Create model and dynamic inputs
864+
model = SymbolicShapeModel()
865+
batch_size = 2
866+
seq_len = 4
867+
x = torch.randn(batch_size, seq_len)
868+
mask = torch.rand(batch_size, seq_len)
869+
example_inputs = (x, mask)
870+
871+
# Export with dynamic shapes to create symbolic shape related nodes
872+
dynamic_shapes = {
873+
"x": {0: torch.export.Dim("batch_size", min=1, max=10)},
874+
"mask": {0: torch.export.Dim("batch_size", min=1, max=10)},
875+
}
876+
877+
exported_program = torch.export.export(
878+
model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
879+
)
880+
881+
"""
882+
In this case the original aten graph has a sym_size_int_2 node, but when we look at
883+
nodes metadata in edge_program_manager, its sym_size node's from_node says
884+
sym_size_int_3 which is not in the original aten graph.
885+
"""
886+
# Create edge program - this is where from_node info can be lost for symbolic shape nodes
887+
edge_program_manager: EdgeProgramManager = to_edge(exported_program)
888+
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
889+
et_program_manager: ExecutorchProgramManager = (
890+
edge_program_manager.to_executorch()
891+
)
892+
893+
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
894+
etrecord_path = tmp_file.name
895+
896+
# Generate ETRecord with the exported program (aten graph)
897+
generate_etrecord(
898+
etrecord_path,
899+
edge_program_manager_copy,
900+
et_program_manager,
901+
exported_program=exported_program,
902+
)
903+
904+
# Create Inspector and get etrecord
905+
with patch.object(
906+
_inspector, "gen_etdump_object", return_value=None
907+
), patch.object(EventBlock, "_gen_from_etdump"):
908+
inspector_instance = Inspector(
909+
etdump_path=ETDUMP_PATH,
910+
etrecord=etrecord_path,
911+
)
912+
913+
# Extract the necessary values from the inspector's etrecord
914+
exported_program_from_etrecord = (
915+
inspector_instance._etrecord.exported_program
916+
)
917+
export_graph_id = inspector_instance._etrecord.export_graph_id
918+
edge_dialect_program = inspector_instance._etrecord.edge_dialect_program
919+
920+
# Ensure we have all the necessary components
921+
self.assertIsNotNone(exported_program_from_etrecord)
922+
self.assertIsNotNone(export_graph_id)
923+
self.assertIsNotNone(edge_dialect_program)
924+
925+
# Test propagate_back_debug_handle with validation enabled (should fail or return False)
926+
# This demonstrates the issue with symbolic shape nodes losing from_node info
927+
validation_enabled_result = propagate_back_debug_handle(
928+
exported_program_from_etrecord,
929+
export_graph_id,
930+
edge_dialect_program,
931+
disable_debug_handle_valdiation=False,
932+
)
933+
934+
# With validation enabled, it should return False when from_node info is lost
935+
self.assertFalse(
936+
validation_enabled_result,
937+
"propagate_back_debug_handle should return False when validation is enabled "
938+
"and symbolic shape nodes lose from_node info",
939+
)
940+
941+
# Test propagate_back_debug_handle with validation disabled (should succeed)
942+
# This shows how the disable_debug_handle_valdiation flag allows the function to work
943+
validation_disabled_result = propagate_back_debug_handle(
944+
exported_program_from_etrecord,
945+
export_graph_id,
946+
edge_dialect_program,
947+
disable_debug_handle_valdiation=True,
948+
)
949+
950+
# With validation disabled, it should return True even when from_node info is lost
951+
self.assertTrue(
952+
validation_disabled_result,
953+
"propagate_back_debug_handle should return True when validation is disabled, "
954+
"allowing best effort comparison even when from_node info is lost",
955+
)
956+
841957
def _gen_random_events(self) -> List[Event]:
842958
events = []
843959
for i in range(2):

0 commit comments

Comments
 (0)