Skip to content

Commit a7c2b04

Browse files
committed
Revert #14182: "Optionally disable debug handle validation"
This reverts commit 378c700 (#14182). It broke unittest / windows. ghstack-source-id: 715d9ce ghstack-comment-id: 3286690859 Pull-Request: #14281
1 parent eb1099e commit a7c2b04

File tree

3 files changed

+5
-151
lines changed

3 files changed

+5
-151
lines changed

devtools/inspector/_inspector.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,7 +1169,6 @@ def _consume_etrecord(self) -> None:
11691169

11701170
def _get_aot_intermediate_outputs_and_op_names(
11711171
self,
1172-
disable_debug_handle_valdiation: bool = False,
11731172
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
11741173
"""
11751174
Capture intermediate outputs only if _representative_inputs are provided
@@ -1185,7 +1184,6 @@ def _get_aot_intermediate_outputs_and_op_names(
11851184
self._etrecord.exported_program,
11861185
self._etrecord.export_graph_id,
11871186
self._etrecord.edge_dialect_program,
1188-
disable_debug_handle_valdiation,
11891187
):
11901188
export_program = self._etrecord.exported_program
11911189
else:
@@ -1406,9 +1404,7 @@ def get_exported_program(
14061404
else self._etrecord.graph_map.get(graph)
14071405
)
14081406

1409-
def calculate_numeric_gap(
1410-
self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False
1411-
):
1407+
def calculate_numeric_gap(self, distance: str = "MSE"):
14121408
"""
14131409
Compares logged intermediate outputs from the exported graph (in ETRecord)
14141410
with runtime outputs (in ETDump) using a user-specific numerical comparator.
@@ -1420,19 +1416,12 @@ def calculate_numeric_gap(
14201416
14211417
Args:
14221418
distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".
1423-
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes, and inbuilt ops like gt/lt etc.,
1424-
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection
1425-
between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding
1426-
node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This
1427-
flag allows one to override such behavior and make best effort comparison.
14281419
14291420
Returns:
14301421
pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps.
14311422
"""
14321423
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
1433-
self._get_aot_intermediate_outputs_and_op_names(
1434-
disable_debug_handle_valdiation
1435-
)
1424+
self._get_aot_intermediate_outputs_and_op_names()
14361425
)
14371426
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
14381427
raise ValueError(
@@ -1462,15 +1451,6 @@ def calculate_numeric_gap(
14621451
) in mapping.items():
14631452
if aot_intermediate_output is None or runtime_intermediate_output is None:
14641453
continue
1465-
# If aot outputs length is > 1 then comparison fails since we dont really have
1466-
# any instances where runtime intermediate output is a tuple or list
1467-
# This does not happen when edge dialect program is reference for comparison
1468-
# but happens in aten graph where ops like unbind remain undecomposed
1469-
if (
1470-
isinstance(aot_intermediate_output, Sequence)
1471-
and len(aot_intermediate_output) > 1
1472-
):
1473-
continue
14741454
rows.append(
14751455
{
14761456
"aot_ops": find_op_names(

devtools/inspector/_inspector_utils.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
965965
# Ensure both sequences have the same length
966966
if len(a) != len(b):
967967
raise ValueError(
968-
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}."
968+
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
969969
)
970970

971971
# Compare each element in the sequences and return the list of results
@@ -990,9 +990,6 @@ def get_ancestor_node_identifiers(node: Node) -> List[str]:
990990
Returns: the identifiers of all its ancestor nodes
991991
"""
992992

993-
if FROM_NODE_KEY not in node.meta:
994-
return []
995-
996993
node_source = node.meta[FROM_NODE_KEY]
997994
node_source = node_source[-1]
998995
ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"]
@@ -1114,7 +1111,6 @@ def propagate_back_debug_handle(
11141111
exported_program: ExportedProgram,
11151112
exported_program_graph_id: int,
11161113
edge_dialect_program: ExportedProgram,
1117-
disable_debug_handle_valdiation: bool = False,
11181114
) -> bool:
11191115
"""
11201116
Propagate debug handle from edge dialect program back to the exported program while maintain the correctness
@@ -1128,10 +1124,6 @@ def propagate_back_debug_handle(
11281124
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
11291125
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
11301126
1131-
disable_debug_handle_validation is used to avoid _verify_graph_match() in case of debug handle mismatch.
1132-
This can happen when we are comparing against aten graph in which case not all debug handles are matched
1133-
in aten graph. Example of this is when symbolic shape nodes are re-exported.
1134-
11351127
Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
11361128
"""
11371129
# 1. Extract mapping from ancestor node identifiers to debug handles
@@ -1145,9 +1137,7 @@ def propagate_back_debug_handle(
11451137
)
11461138

11471139
# 3. Verify if every debug handle in edge dialect program has a corresponding node
1148-
if not disable_debug_handle_valdiation and not _verify_graph_match(
1149-
edge_dialect_program, matched_debug_handles
1150-
):
1140+
if not _verify_graph_match(edge_dialect_program, matched_debug_handles):
11511141
return False
11521142

11531143
# 4. Apply debug handles to the exported program

devtools/inspector/tests/inspector_test.py

Lines changed: 1 addition & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def test_calculate_numeric_gap(self):
681681
aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
682682
runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
683683

684-
inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: (
684+
inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: (
685685
aot_intermediate_outputs,
686686
aot_debug_handle_to_op_name,
687687
)
@@ -838,122 +838,6 @@ def _gen_random_runtime_output(
838838
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
839839
return [torch.randn(RAW_DATA_SIZE)]
840840

841-
def test_disable_debug_handle_validation_with_symbolic_shapes(self):
842-
"""
843-
Test that demonstrates the issue with symbolic shape related nodes losing from_node info
844-
during dynamic shape based export, and shows how disable_debug_handle_valdiation parameter
845-
in propagate_back_debug_handle allows validation to be bypassed.
846-
"""
847-
from executorch.devtools.inspector._inspector_utils import (
848-
propagate_back_debug_handle,
849-
)
850-
851-
class SymbolicShapeModel(torch.nn.Module):
852-
"""Model that will have symbolic shape related operations after export."""
853-
854-
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
855-
# This will create symbolic shape nodes during dynamic export
856-
batch_size = x.shape[0]
857-
x = x + torch.rand((batch_size, 1))
858-
# Masking operation that creates gt/lt nodes
859-
valid_mask = mask > 0.5
860-
x = torch.where(valid_mask, x, torch.zeros_like(x))
861-
return x
862-
863-
# Create model and dynamic inputs
864-
model = SymbolicShapeModel()
865-
batch_size = 2
866-
seq_len = 4
867-
x = torch.randn(batch_size, seq_len)
868-
mask = torch.rand(batch_size, seq_len)
869-
example_inputs = (x, mask)
870-
871-
# Export with dynamic shapes to create symbolic shape related nodes
872-
dynamic_shapes = {
873-
"x": {0: torch.export.Dim("batch_size", min=1, max=10)},
874-
"mask": {0: torch.export.Dim("batch_size", min=1, max=10)},
875-
}
876-
877-
exported_program = torch.export.export(
878-
model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
879-
)
880-
881-
"""
882-
In this case origina aten graph has sym_size_int_2 node but when we look at
883-
nodes metadata in edge_program_manager, its sym_size node's from_node says
884-
sym_size_int_3 which is not in the original aten graph.
885-
"""
886-
# Create edge program - this is where from_node info can be lost for symbolic shape nodes
887-
edge_program_manager: EdgeProgramManager = to_edge(exported_program)
888-
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
889-
et_program_manager: ExecutorchProgramManager = (
890-
edge_program_manager.to_executorch()
891-
)
892-
893-
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
894-
etrecord_path = tmp_file.name
895-
896-
# Generate ETRecord with the exported program (aten graph)
897-
generate_etrecord(
898-
etrecord_path,
899-
edge_program_manager_copy,
900-
et_program_manager,
901-
exported_program=exported_program,
902-
)
903-
904-
# Create Inspector and get etrecord
905-
with patch.object(
906-
_inspector, "gen_etdump_object", return_value=None
907-
), patch.object(EventBlock, "_gen_from_etdump"):
908-
inspector_instance = Inspector(
909-
etdump_path=ETDUMP_PATH,
910-
etrecord=etrecord_path,
911-
)
912-
913-
# Extract the necessary values from the inspector's etrecord
914-
exported_program_from_etrecord = (
915-
inspector_instance._etrecord.exported_program
916-
)
917-
export_graph_id = inspector_instance._etrecord.export_graph_id
918-
edge_dialect_program = inspector_instance._etrecord.edge_dialect_program
919-
920-
# Ensure we have all the necessary components
921-
self.assertIsNotNone(exported_program_from_etrecord)
922-
self.assertIsNotNone(export_graph_id)
923-
self.assertIsNotNone(edge_dialect_program)
924-
925-
# Test propagate_back_debug_handle with validation enabled (should fail or return False)
926-
# This demonstrates the issue with symbolic shape nodes losing from_node info
927-
validation_enabled_result = propagate_back_debug_handle(
928-
exported_program_from_etrecord,
929-
export_graph_id,
930-
edge_dialect_program,
931-
disable_debug_handle_valdiation=False,
932-
)
933-
934-
# With validation enabled, it should return False when from_node info is lost
935-
self.assertFalse(
936-
validation_enabled_result,
937-
"propagate_back_debug_handle should return False when validation is enabled "
938-
"and symbolic shape nodes lose from_node info",
939-
)
940-
941-
# Test propagate_back_debug_handle with validation disabled (should succeed)
942-
# This shows how the disable_debug_handle_valdiation flag allows the function to work
943-
validation_disabled_result = propagate_back_debug_handle(
944-
exported_program_from_etrecord,
945-
export_graph_id,
946-
edge_dialect_program,
947-
disable_debug_handle_valdiation=True,
948-
)
949-
950-
# With validation disabled, it should return True even when from_node info is lost
951-
self.assertTrue(
952-
validation_disabled_result,
953-
"propagate_back_debug_handle should return True when validation is disabled, "
954-
"allowing best effort comparison even when from_node info is lost",
955-
)
956-
957841
def _gen_random_events(self) -> List[Event]:
958842
events = []
959843
for i in range(2):

0 commit comments

Comments
 (0)