Skip to content

Commit a198b37

Browse files
Juntian Liufacebook-github-bot
authored andcommitted
Updated AOT debug_handle to operator names mapping (#12366)
Summary: Pull Request resolved: #12366 This PR updates the AOT debug handle to operator names mapping. Previously, each debug handle was mapped to a single operator name, but this update allows for multiple operator names to be associated with a single debug handle by storing them in a list. Reviewed By: Gasoonjia Differential Revision: D78118798
1 parent a1e3d48 commit a198b37

File tree

5 files changed

+71
-44
lines changed

5 files changed

+71
-44
lines changed

devtools/inspector/_inspector.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,7 @@ def _consume_etrecord(self) -> None:
11591159

11601160
def _get_aot_intermediate_outputs_and_op_names(
11611161
self,
1162-
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]:
1162+
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
11631163
"""
11641164
Capture intermediate outputs only if _representative_inputs are provided
11651165
when using bundled program to create the etrecord
@@ -1180,13 +1180,13 @@ def _get_aot_intermediate_outputs_and_op_names(
11801180
# TODO: Make it more extensible to further merge overlapping debug handles
11811181
def _get_runtime_intermediate_outputs_and_op_names(
11821182
self,
1183-
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]:
1183+
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
11841184
"""
11851185
Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings)
11861186
from the event blocks, along with the corresponding debug handles and op names mapping.
11871187
"""
11881188
debug_handle_to_output = {}
1189-
debug_handle_to_op_name = {}
1189+
debug_handle_to_op_names = {}
11901190
for event_block in self.event_blocks:
11911191
for event in event_block.events:
11921192
# Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax
@@ -1208,12 +1208,13 @@ def _get_runtime_intermediate_outputs_and_op_names(
12081208
event._instruction_id,
12091209
event.debug_data,
12101210
)
1211-
debug_handle_to_op_name[debug_handle] = event.name
1211+
# TODO: One debug handle can be associated with multiple op names
1212+
debug_handle_to_op_names[debug_handle] = [event.name]
12121213

12131214
merge_runtime_overlapping_debug_handles(debug_handle_to_output)
12141215
return {
12151216
k: v[1] for k, v in debug_handle_to_output.items()
1216-
}, debug_handle_to_op_name
1217+
}, debug_handle_to_op_names
12171218

12181219
def to_dataframe(
12191220
self,
@@ -1385,15 +1386,15 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
13851386
pd.DataFrame: A DataFrame listing corresponding operator outputs from
13861387
both stages and their computed numerical gaps.
13871388
"""
1388-
aot_intermediate_outputs, aot_debug_handle_to_op_name = (
1389+
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
13891390
self._get_aot_intermediate_outputs_and_op_names()
13901391
)
1391-
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_name) == 0:
1392+
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
13921393
raise ValueError(
13931394
"Missing etrecord or missing representative inputs within etrecord, both of which are required for calculating numerical gap"
13941395
)
13951396
# The runtime_op_names will be used later to map runtime debug_handle to op_name
1396-
runtime_intermediate_outputs, runtime_debug_handle_to_op_name = (
1397+
runtime_intermediate_outputs, runtime_debug_handle_to_op_names = (
13971398
self._get_runtime_intermediate_outputs_and_op_names()
13981399
)
13991400
mapping = map_runtime_aot_intermediate_outputs(
@@ -1419,11 +1420,11 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
14191420
rows.append(
14201421
{
14211422
"aot_ops": find_op_names(
1422-
aot_debug_handle, aot_debug_handle_to_op_name
1423+
aot_debug_handle, aot_debug_handle_to_op_names
14231424
),
14241425
"aot_intermediate_output": aot_intermediate_output,
14251426
"runtime_ops": find_op_names(
1426-
runtime_debug_handle, runtime_debug_handle_to_op_name
1427+
runtime_debug_handle, runtime_debug_handle_to_op_names
14271428
),
14281429
"runtime_intermediate_output": runtime_intermediate_output,
14291430
"gap": compare_intermediate_outputs(

devtools/inspector/_inspector_utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -814,13 +814,13 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
814814

815815
def get_aot_debug_handle_to_op_name_mapping(
816816
graph_module: torch.fx.GraphModule,
817-
) -> Dict[DebugHandle, str]:
817+
) -> Dict[DebugHandle, List[str]]:
818818
"""
819819
Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module.
820820
Parameters:
821821
graph_module (torch.fx.GraphModule): The graph module to get the mapping from.
822822
Returns:
823-
Dict[DebugHandle, str]: A dictionary mapping debug handles to operator names.
823+
Dict[DebugHandle, List[str]]: A dictionary mapping debug handles to operator names.
824824
"""
825825
node_filters = [
826826
NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
@@ -836,26 +836,29 @@ def get_aot_debug_handle_to_op_name_mapping(
836836
if isinstance(debug_handle, int)
837837
else tuple(debug_handle)
838838
)
839-
debug_handle_to_op_name[key] = node.name
839+
if key in debug_handle_to_op_name:
840+
debug_handle_to_op_name[key].append(node.name)
841+
else:
842+
debug_handle_to_op_name[key] = [node.name]
840843
return debug_handle_to_op_name
841844

842845

843846
def find_op_names(
844847
target_debug_handle: DebugHandle,
845-
debug_handle_to_op_name: Dict[DebugHandle, str],
848+
debug_handle_to_op_names: Dict[DebugHandle, List[str]],
846849
) -> List[str]:
847850
"""
848851
Record the operator names only if their debug handles are part of the target debug handle.
849-
The debug handles in `debug_handle_to_op_name` have undergone merging and remain unchanged,
852+
The debug handles in `debug_handle_to_op_names` have undergone merging and remain unchanged,
850853
and this function identifies operations corresponding to these transformed handles.
851854
"""
852855
dh_set = set(target_debug_handle)
853856
result = []
854857

855-
for key_tuple, op_name in debug_handle_to_op_name.items():
858+
for key_tuple, op_name in debug_handle_to_op_names.items():
856859
# Check if key is a subset of the target_debug_handle
857860
if set(key_tuple).issubset(dh_set):
858-
result.append(op_name)
861+
result.extend(op_name)
859862

860863
return result
861864

devtools/inspector/tests/inspector_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
TimeScale,
4545
)
4646
from executorch.devtools.inspector.tests.inspector_test_utils import (
47-
check_if_debug_handle_to_op_name_match,
47+
check_if_debug_handle_to_op_names_match,
4848
check_if_final_outputs_match,
4949
model_registry,
5050
)
@@ -522,17 +522,18 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
522522
_representative_inputs=aten_model.example_inputs[0],
523523
)
524524
inspector_instance._etrecord = etrecord
525-
aot_intermediate_outputs, aot_debug_handle_to_op_name = (
525+
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
526526
inspector_instance._get_aot_intermediate_outputs_and_op_names()
527527
)
528528
self.assertTrue(
529529
check_if_final_outputs_match(
530530
"ConvLinearModel", aot_intermediate_outputs
531531
)
532532
)
533+
533534
self.assertTrue(
534-
check_if_debug_handle_to_op_name_match(
535-
"ConvLinearModel", aot_debug_handle_to_op_name
535+
check_if_debug_handle_to_op_names_match(
536+
"ConvLinearModel", aot_debug_handle_to_op_names
536537
)
537538
)
538539

@@ -584,14 +585,14 @@ def test_get_runtime_intermediate_outputs_and_op_names(self):
584585
self.assertTrue(
585586
torch.allclose(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
586587
)
587-
self.assertEqual(op_names[(4,)], "op_3")
588+
self.assertEqual(op_names[(4,)], ["op_3"])
588589

589590
# Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
590591
for key in range(5, 9):
591592
self.assertIn((key,), runtime_outputs)
592593
self.assertIn((key,), op_names)
593594
self.assertEqual(runtime_outputs[(key,)][0].size(0), RAW_DATA_SIZE)
594-
self.assertEqual(op_names[(key,)], f"op_{key-1}")
595+
self.assertEqual(op_names[(key,)], [f"op_{key-1}"])
595596

596597
def test_calculate_numeric_gap(self):
597598
# Create a context manager to patch functions called by Inspector.__init__

devtools/inspector/tests/inspector_test_utils.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,22 @@ def get_expected_intermediate_outputs():
7676
}
7777

7878
@staticmethod
79-
def get_expected_debug_handle_to_op_name():
79+
def get_expected_debug_handle_to_op_names():
8080
"""
81-
Returns the expected debug handle and op name mapping for this model for the given input.
81+
Returns the expected debug handle and op names mapping for this model for the given input.
8282
"""
8383
return {
84-
(1,): "aten_convolution_default",
85-
(2,): "aten_view_copy_default",
86-
(3,): "aten_addmm_default",
87-
(4,): "aten_add_tensor",
88-
(5,): "aten_sub_tensor",
89-
(6,): "aten_mul_tensor",
90-
(7,): "aten_add_tensor_1",
91-
(8,): "aten_div_tensor",
92-
(9,): "aten_relu_default",
93-
(10,): "aten_sigmoid_default",
94-
(11,): "aten_split_with_sizes_copy_default",
84+
(1,): ["aten_convolution_default"],
85+
(2,): ["aten_view_copy_default"],
86+
(3,): ["aten_permute_copy_default", "aten_addmm_default"],
87+
(4,): ["aten_add_tensor"],
88+
(5,): ["aten_sub_tensor"],
89+
(6,): ["aten_mul_tensor"],
90+
(7,): ["aten_add_tensor_1"],
91+
(8,): ["aten_div_tensor"],
92+
(9,): ["aten_relu_default"],
93+
(10,): ["aten_sigmoid_default"],
94+
(11,): ["aten_split_with_sizes_copy_default"],
9595
}
9696

9797

@@ -129,14 +129,14 @@ def check_if_final_outputs_match(model_name, actual_outputs_with_handles):
129129
return True
130130

131131

132-
def check_if_debug_handle_to_op_name_match(model_name, actual_debug_handle_to_op_name):
132+
def check_if_debug_handle_to_op_names_match(model_name, actual_debug_handle_to_op_name):
133133
"""
134134
Checks if the actual op names match the expected op names for the specified model.
135135
Returns True if all match, otherwise returns False.
136136
"""
137137
model_instance = model_registry[model_name]
138138
expected_debug_handle_to_op_name = (
139-
model_instance.get_expected_debug_handle_to_op_name()
139+
model_instance.get_expected_debug_handle_to_op_names()
140140
)
141141
if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name):
142142
return False

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self):
450450
)
451451
node.meta["debug_handle"] = 1
452452
debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
453-
expected_result = {(1,): "op1"}
453+
expected_result = {(1,): ["op1"]}
454454
self.assertEqual(debug_handle_to_op_name, expected_result)
455455

456456
def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self):
@@ -469,8 +469,8 @@ def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self):
469469
(
470470
1,
471471
2,
472-
): "op1",
473-
(3,): "op2",
472+
): ["op1"],
473+
(3,): ["op2"],
474474
}
475475
self.assertEqual(debug_handle_to_op_name, expected_result)
476476

@@ -550,21 +550,43 @@ def test_node_op_type_mismatch(self):
550550

551551
def test_find_op_names_empty_debug_handle(self):
552552
debug_handle = ()
553-
debug_handle_to_op_name = {(1, 2): "op1", (3, 4): "op2"}
553+
debug_handle_to_op_name = {(1, 2): ["op1"], (3, 4): ["op2"]}
554554
self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), [])
555555

556556
def test_find_op_names_no_matching_handles(self):
557557
debug_handle = (1, 2)
558-
debug_handle_to_op_name = {(3, 4): "op1", (5, 6): "op2"}
558+
debug_handle_to_op_name = {(3, 4): ["op1"], (5, 6): ["op2"]}
559559
self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), [])
560560

561561
def test_find_op_names_matching_handles(self):
562562
debug_handle = (1, 2, 3)
563-
debug_handle_to_op_name = {(1, 2): "op1", (2, 3): "op2", (4, 5, 6): "op3"}
563+
debug_handle_to_op_name = {(1, 2): ["op1"], (2, 3): ["op2"], (4, 5, 6): ["op3"]}
564564
self.assertEqual(
565565
find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2"]
566566
)
567567

568+
def test_find_op_names_multiple_ops_single_handle(self):
569+
"""Test when a single debug handle maps to multiple operator names"""
570+
debug_handle = (1, 2, 3)
571+
debug_handle_to_op_name = {(1, 2): ["op1", "op2", "op3"], (4, 5): ["op4"]}
572+
self.assertEqual(
573+
find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2", "op3"]
574+
)
575+
576+
def test_find_op_names_mixed_single_and_multiple_ops(self):
577+
"""Test mix of handles with single and multiple operator names"""
578+
debug_handle = (1, 2, 3, 4, 5)
579+
debug_handle_to_op_name = {
580+
(1, 2): ["op1"],
581+
(3,): ["op2", "op3"],
582+
(4,): ["op4"],
583+
(5,): ["op5", "op6", "op7"], # Multiple ops
584+
}
585+
self.assertEqual(
586+
find_op_names(debug_handle, debug_handle_to_op_name),
587+
["op1", "op2", "op3", "op4", "op5", "op6", "op7"],
588+
)
589+
568590
def test_compare_intermediate_outputs_sequences(self):
569591
a = [1.0, 2.0, 3.0]
570592
b = [1.0, 2.5, 3.5]

0 commit comments

Comments
 (0)