Skip to content

Commit 9d899c9

Browse files
author
Juntian Liu
authored
Integrated debug_handle to operator name mapping into inspector
Differential Revision: D77269265 Pull Request resolved: #12001
1 parent d4cc258 commit 9d899c9

File tree

5 files changed

+132
-50
lines changed

5 files changed

+132
-50
lines changed

devtools/inspector/_inspector.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
EXCLUDED_COLUMNS_WHEN_PRINTING,
4949
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT,
5050
EXCLUDED_EVENTS_WHEN_PRINTING,
51+
find_op_names,
5152
find_populated_event,
5253
FORWARD,
5354
gen_etdump_object,
@@ -68,6 +69,7 @@
6869
from executorch.devtools.inspector.numerical_comparator import (
6970
L1Comparator,
7071
MSEComparator,
72+
SNRComparator,
7173
)
7274
from executorch.exir import ExportedProgram
7375

@@ -1084,8 +1086,6 @@ def __init__(
10841086
# Key str is method name; value is list of ProgramOutputs because of list of test cases
10851087
self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
10861088
self._enable_module_hierarchy = enable_module_hierarchy
1087-
self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None
1088-
self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None
10891089
self._consume_etrecord()
10901090

10911091
def _consume_etrecord(self) -> None:
@@ -1146,19 +1146,26 @@ def _consume_etrecord(self) -> None:
11461146
event_block.reference_output = self._reference_outputs[FORWARD][
11471147
index
11481148
]
1149-
# Capture intermediate outputs only if _representative_inputs are provided
1150-
# when using bundled program to create the etrecord
1149+
1150+
def _get_aot_intermediate_outputs_and_op_names(
1151+
self,
1152+
) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]:
1153+
"""
1154+
Capture intermediate outputs only if _representative_inputs are provided
1155+
when using bundled program to create the etrecord
1156+
"""
11511157
if self._etrecord._representative_inputs is None:
1152-
return
1158+
return {}, {}
11531159
export_program = self._etrecord.edge_dialect_program
11541160
graph_module = export_program.module()
1155-
self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping(
1161+
aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(
11561162
graph_module
11571163
)
11581164
capturer = IntermediateOutputCapturer(graph_module)
1159-
self._aot_intermediate_outputs = capturer.run_and_capture(
1165+
aot_intermediate_outputs = capturer.run_and_capture(
11601166
self._etrecord._representative_inputs
11611167
)
1168+
return aot_intermediate_outputs, aot_debug_handle_to_op_name
11621169

11631170
# TODO: Make it more extensible to further merge overlapping debug handles
11641171
def _get_runtime_intermediate_outputs_and_op_names(
@@ -1366,22 +1373,27 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
13661373
pd.DataFrame: A DataFrame listing corresponding operator outputs from
13671374
both stages and their computed numerical gaps.
13681375
"""
1369-
if self._aot_intermediate_outputs is None:
1376+
aot_intermediate_outputs, aot_debug_handle_to_op_name = (
1377+
self._get_aot_intermediate_outputs_and_op_names()
1378+
)
1379+
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_name) == 0:
13701380
raise ValueError(
1371-
"The aot intermediate outputs is required but not populated."
1381+
"calculate_numerical_gap error: The aot debug information is required but not populated"
13721382
)
13731383
# The runtime_op_names will be used later to map runtime debug_handle to op_name
1374-
runtime_intermediate_outputs, runtime_op_names = (
1384+
runtime_intermediate_outputs, runtime_debug_handle_to_op_name = (
13751385
self._get_runtime_intermediate_outputs_and_op_names()
13761386
)
13771387
mapping = map_runtime_aot_intermediate_outputs(
1378-
self._aot_intermediate_outputs, runtime_intermediate_outputs
1388+
aot_intermediate_outputs, runtime_intermediate_outputs
13791389
)
13801390
metric = distance.strip().upper()
13811391
if metric == "MSE":
13821392
comparator = MSEComparator()
13831393
elif metric == "L1":
13841394
comparator = L1Comparator()
1395+
elif metric == "SNR":
1396+
comparator = SNRComparator()
13851397
else:
13861398
raise ValueError(f"Unsupported distance metric {distance!r}")
13871399

@@ -1394,9 +1406,13 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
13941406
continue
13951407
rows.append(
13961408
{
1397-
"aot_debug_handle": aot_debug_handle,
1409+
"aot_ops": find_op_names(
1410+
aot_debug_handle, aot_debug_handle_to_op_name
1411+
),
13981412
"aot_intermediate_output": aot_intermediate_output,
1399-
"runtime_debug_handle": runtime_debug_handle,
1413+
"runtime_ops": find_op_names(
1414+
runtime_debug_handle, runtime_debug_handle_to_op_name
1415+
),
14001416
"runtime_intermediate_output": runtime_intermediate_output,
14011417
"gap": comparator.compare(
14021418
aot_intermediate_output, runtime_intermediate_output

devtools/inspector/_inspector_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,3 +784,23 @@ def get_aot_debug_handle_to_op_name_mapping(
784784
)
785785
debug_handle_to_op_name[key] = node.name
786786
return debug_handle_to_op_name
787+
788+
789+
def find_op_names(
790+
target_debug_handle: Tuple[int, ...],
791+
debug_handle_to_op_name: Dict[Tuple[int, ...], str],
792+
) -> List[str]:
793+
"""
794+
Record the operator names only if their debug handles are part of the target debug handle.
795+
The debug handles in `debug_handle_to_op_name` have undergone merging and remain unchanged,
796+
and this function identifies operations corresponding to these transformed handles.
797+
"""
798+
dh_set = set(target_debug_handle)
799+
result = []
800+
801+
for key_tuple, op_name in debug_handle_to_op_name.items():
802+
# Check if key is a subset of the target_debug_handle
803+
if set(key_tuple).issubset(dh_set):
804+
result.append(op_name)
805+
806+
return result

devtools/inspector/tests/inspector_test.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
TimeScale,
4545
)
4646
from executorch.devtools.inspector.tests.inspector_test_utils import (
47+
check_if_debug_handle_to_op_name_match,
4748
check_if_final_outputs_match,
4849
model_registry,
4950
)
@@ -468,25 +469,7 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self):
468469
events=events,
469470
)
470471

471-
def test_no_capture_when_representative_inputs_are_none(self):
472-
# Create a context manager to patch functions called by Inspector.__init__
473-
with patch.object(
474-
_inspector, "parse_etrecord", return_value=None
475-
), patch.object(
476-
_inspector, "gen_etdump_object", return_value=None
477-
), patch.object(
478-
EventBlock, "_gen_from_etdump"
479-
), patch.object(
480-
_inspector, "gen_graphs_from_etrecord"
481-
):
482-
# Call the constructor of Inspector
483-
inspector_instance = Inspector(
484-
etdump_path=ETDUMP_PATH,
485-
etrecord=ETRECORD_PATH,
486-
)
487-
self.assertIsNone(inspector_instance._aot_intermediate_outputs)
488-
489-
def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
472+
def test_etrecord_populates_correct_aot_intermediate_outputs(self):
490473
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
491474
etrecord_path = tmp_file.name
492475
mod = model_registry["ConvLinearModel"]()
@@ -505,7 +488,6 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
505488
generate_etrecord(
506489
etrecord_path, edge_program_manager_copy, et_program_manager
507490
)
508-
original_consume_etrecord = Inspector._consume_etrecord
509491
with patch.object(
510492
Inspector, "_consume_etrecord", return_value=None
511493
), patch.object(
@@ -529,11 +511,17 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
529511
_representative_inputs=aten_model.example_inputs[0],
530512
)
531513
inspector_instance._etrecord = etrecord
532-
Inspector._consume_etrecord = original_consume_etrecord
533-
inspector_instance._consume_etrecord()
514+
aot_intermediate_outputs, aot_debug_handle_to_op_name = (
515+
inspector_instance._get_aot_intermediate_outputs_and_op_names()
516+
)
534517
self.assertTrue(
535518
check_if_final_outputs_match(
536-
"ConvLinearModel", inspector_instance._aot_intermediate_outputs
519+
"ConvLinearModel", aot_intermediate_outputs
520+
)
521+
)
522+
self.assertTrue(
523+
check_if_debug_handle_to_op_name_match(
524+
"ConvLinearModel", aot_debug_handle_to_op_name
537525
)
538526
)
539527

@@ -605,6 +593,7 @@ def test_calculate_numeric_gap(self):
605593
), patch.object(
606594
_inspector, "gen_graphs_from_etrecord"
607595
):
596+
608597
# Call the constructor of Inspector
609598
inspector_instance = Inspector(
610599
etdump_path=ETDUMP_PATH,
@@ -621,43 +610,44 @@ def test_calculate_numeric_gap(self):
621610
(1,): torch.tensor([3.0, 6.0, 5.0]),
622611
}
623612

624-
inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs
613+
aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
614+
runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
615+
616+
inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: (
617+
aot_intermediate_outputs,
618+
aot_debug_handle_to_op_name,
619+
)
625620
inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
626-
lambda: (runtime_intermediate_outputs, {})
621+
lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name)
627622
)
628623

629624
df = inspector_instance.calculate_numeric_gap(distance="L1")
630625
self.assertIsInstance(df, pd.DataFrame)
631626
self.assertEqual(len(df), 2)
632627
cols = set(df.columns)
633628
expected_cols = {
634-
"aot_debug_handle",
629+
"aot_ops",
635630
"aot_intermediate_output",
636-
"runtime_debug_handle",
631+
"runtime_ops",
637632
"runtime_intermediate_output",
638633
"gap",
639634
}
640635
self.assertEqual(cols, expected_cols)
641-
founded_aot_debug_handle = set(df["aot_debug_handle"])
642-
self.assertEqual(
643-
founded_aot_debug_handle, set(aot_intermediate_outputs.keys())
644-
)
645-
for _, row in df.iterrows():
646-
aot_debuh_handle = row["aot_debug_handle"]
636+
for i, row in df.iterrows():
637+
# Dummpy key to get the expected aot/runtime internmediate outputs
638+
key = (i,)
647639
# aot_intermediate_output should equal aot_intermediate_outputs[h]
648640
self.assertTrue(
649641
torch.allclose(
650642
row["aot_intermediate_output"],
651-
aot_intermediate_outputs[aot_debuh_handle],
643+
aot_intermediate_outputs[key],
652644
)
653645
)
654-
# runtime_debug_hanlde equals aot_debug_handle at this case
655-
self.assertEqual(row["runtime_debug_handle"], aot_debuh_handle)
656646
# runtime_intermediate_output should equal runtime_intermediate_outputs[h]
657647
self.assertTrue(
658648
torch.allclose(
659649
row["runtime_intermediate_output"],
660-
runtime_intermediate_outputs[aot_debuh_handle],
650+
runtime_intermediate_outputs[key],
661651
)
662652
)
663653
# gap should equal 3.0

devtools/inspector/tests/inspector_test_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,26 @@ def get_expected_intermediate_outputs():
8383
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
8484
}
8585

86+
@staticmethod
87+
def get_expected_debug_handle_to_op_name():
88+
"""
89+
Returns the expected debug handle and op name mapping for this model for the given input.
90+
"""
91+
return {
92+
(10,): "aten_convolution_default",
93+
(11,): "aten_view_copy_default",
94+
(12,): "aten_permute_copy_default",
95+
(13,): "aten_addmm_default",
96+
(14,): "aten_add_tensor",
97+
(15,): "aten_sub_tensor",
98+
(16,): "aten_mul_tensor",
99+
(17,): "aten_add_tensor_1",
100+
(18,): "aten_div_tensor",
101+
(19,): "aten_relu_default",
102+
(20,): "aten_sigmoid_default",
103+
(21,): "aten_split_with_sizes_copy_default",
104+
}
105+
86106

87107
# Global model registry
88108
model_registry = {
@@ -116,3 +136,21 @@ def check_if_final_outputs_match(model_name, actual_outputs_with_handles):
116136
if not torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5):
117137
return False
118138
return True
139+
140+
141+
def check_if_debug_handle_to_op_name_match(model_name, actual_debug_handle_to_op_name):
142+
"""
143+
Checks if the actual op names match the expected op names for the specified model.
144+
Returns True if all match, otherwise returns False.
145+
"""
146+
model_instance = model_registry[model_name]
147+
expected_debug_handle_to_op_name = (
148+
model_instance.get_expected_debug_handle_to_op_name()
149+
)
150+
if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name):
151+
return False
152+
for debug_handle, expected_op_name in expected_debug_handle_to_op_name.items():
153+
actual_op_name = actual_debug_handle_to_op_name.get(debug_handle)
154+
if actual_op_name != expected_op_name:
155+
return False
156+
return True

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
convert_to_float_tensor,
3333
create_debug_handle_to_op_node_mapping,
3434
EDGE_DIALECT_GRAPH_KEY,
35+
find_op_names,
3536
find_populated_event,
3637
gen_graphs_from_etrecord,
3738
get_aot_debug_handle_to_op_name_mapping,
@@ -472,6 +473,23 @@ def test_node_op_type_mismatch(self):
472473
# Test that the filter doesn't match the mock node (op_type mismatch)
473474
self.assertFalse(node_filter.matches(mock_node_op_type_mismatch))
474475

476+
def test_find_op_names_empty_debug_handle(self):
477+
debug_handle = ()
478+
debug_handle_to_op_name = {(1, 2): "op1", (3, 4): "op2"}
479+
self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), [])
480+
481+
def test_find_op_names_no_matching_handles(self):
482+
debug_handle = (1, 2)
483+
debug_handle_to_op_name = {(3, 4): "op1", (5, 6): "op2"}
484+
self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), [])
485+
486+
def test_find_op_names_matching_handles(self):
487+
debug_handle = (1, 2, 3)
488+
debug_handle_to_op_name = {(1, 2): "op1", (2, 3): "op2", (4, 5, 6): "op3"}
489+
self.assertEqual(
490+
find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2"]
491+
)
492+
475493

476494
def gen_mock_operator_graph_with_expected_map() -> (
477495
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)