Skip to content

Commit b7a9cf8

Browse files
Juntian Liufacebook-github-bot
authored andcommitted
Add functionality to map runtime debug_handles to op names (#11987)
Summary: This PR adds a functionality to map runtime debug handles to operator names. It will be used later to enhance how numerical discrepancy results are shown, making it easier for users to understand. Differential Revision: D77266536
1 parent 47d1592 commit b7a9cf8

File tree

2 files changed

+44
-19
lines changed

2 files changed

+44
-19
lines changed

devtools/inspector/_inspector.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,12 +1161,16 @@ def _consume_etrecord(self) -> None:
11611161
)
11621162

11631163
# TODO: Make it more extensible to further merge overlapping debug handles
1164-
def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
1164+
def _get_runtime_intermediate_outputs_and_op_names(
1165+
self,
1166+
) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]:
11651167
"""
1166-
Retrieve the raw runtime intermediate outputs(debug handles and value mappings)
1167-
from the event blocks. These outputs will be processed later to merge overlapping debug handles.
1168+
Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings)
1169+
from the event blocks, along with the corresponding debug handles and op names mapping
1170+
These outputs will be processed later to merge overlapping debug handles.
11681171
"""
11691172
debug_handle_to_output = {}
1173+
debug_handle_to_op_name = {}
11701174
for event_block in self.event_blocks:
11711175
for event in event_block.events:
11721176
# Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax
@@ -1175,20 +1179,23 @@ def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
11751179
or not event.op_types
11761180
):
11771181
continue
1178-
# Normalize debug_handles to a tuple
1179-
debug_handles = event.debug_handles
1180-
if isinstance(debug_handles, int):
1181-
debug_handles = (debug_handles,)
1182+
# Normalize debug_handle to a tuple
1183+
debug_handle = event.debug_handles
1184+
if isinstance(debug_handle, int):
1185+
debug_handle = (debug_handle,)
11821186
else:
1183-
debug_handles = tuple(debug_handles)
1184-
current_entry = debug_handle_to_output.get(debug_handles, (-1, None))
1185-
# When event has same debug handles, only keep the one with the largest instruction id
1187+
debug_handle = tuple(debug_handle)
1188+
current_entry = debug_handle_to_output.get(debug_handle, (-1, None))
1189+
# When event has same debug_handle, only keep the one with the largest instruction id
11861190
if event._instruction_id > current_entry[0]:
1187-
debug_handle_to_output[debug_handles] = (
1191+
debug_handle_to_output[debug_handle] = (
11881192
event._instruction_id,
11891193
event.debug_data,
11901194
)
1191-
return {k: v[1] for k, v in debug_handle_to_output.items()}
1195+
debug_handle_to_op_name[debug_handle] = event.name
1196+
return {
1197+
k: v[1] for k, v in debug_handle_to_output.items()
1198+
}, debug_handle_to_op_name
11921199

11931200
def to_dataframe(
11941201
self,
@@ -1364,8 +1371,12 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
13641371
raise ValueError(
13651372
"The aot intermediate outputs is required but not populated."
13661373
)
1374+
# The runtime_op_names will be used later to map runtime debug_handle to op_name
1375+
runtime_intermediate_outputs, runtime_op_names = (
1376+
self._get_runtime_intermediate_outputs_and_op_names()
1377+
)
13671378
mapping = map_runtime_aot_intermediate_outputs(
1368-
self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs()
1379+
self._aot_intermediate_outputs, runtime_intermediate_outputs
13691380
)
13701381
metric = distance.strip().upper()
13711382
if metric == "MSE":

devtools/inspector/tests/inspector_test.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
537537
)
538538
)
539539

540-
def test_get_runtime_intermediate_outputs(self):
540+
def test_get_runtime_intermediate_outputs_and_op_names(self):
541541
# Create a context manager to patch functions called by Inspector.__init__
542542
with patch.object(
543543
_inspector, "parse_etrecord", return_value=None
@@ -560,25 +560,39 @@ def test_get_runtime_intermediate_outputs(self):
560560
EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events())
561561
]
562562

563-
runtime_outputs = inspector_instance._get_runtime_intermediate_outputs()
564-
# This output should be a dictionary with 5 keys
563+
runtime_outputs, op_names = (
564+
inspector_instance._get_runtime_intermediate_outputs_and_op_names()
565+
)
566+
# These outputs and op_names dictionaries should all have 5 keys
565567
self.assertEqual(
566568
len(runtime_outputs),
567569
5,
568570
)
569-
# Check that keys (0,) and (1,) are not in the dictionary(skip OPERATOR_CALL and op_types are empty)
571+
self.assertEqual(
572+
len(op_names),
573+
5,
574+
)
575+
576+
# Check that keys (0,) and (1,) are not in these two dictionaries(skip OPERATOR_CALL and op_types are empty)
570577
self.assertNotIn((0,), runtime_outputs)
571578
self.assertNotIn((1,), runtime_outputs)
579+
self.assertNotIn((0,), op_names)
580+
self.assertNotIn((1,), op_names)
572581

573582
# Same debug_handle but different instruction_id, should record the last one
574583
self.assertIn((4,), runtime_outputs)
584+
self.assertIn((4,), op_names)
575585
self.assertTrue(
576586
torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
577587
)
588+
self.assertEqual(op_names[(4,)], "op_3")
589+
578590
# Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
579591
for key in range(5, 9):
580592
self.assertIn((key,), runtime_outputs)
593+
self.assertIn((key,), op_names)
581594
self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
595+
self.assertEqual(op_names[(key,)], f"op_{key-1}")
582596

583597
def test_calculate_numeric_gap(self):
584598
# Create a context manager to patch functions called by Inspector.__init__
@@ -608,8 +622,8 @@ def test_calculate_numeric_gap(self):
608622
}
609623

610624
inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs
611-
inspector_instance._get_runtime_intermediate_outputs = (
612-
lambda: runtime_intermediate_outputs
625+
inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
626+
lambda: (runtime_intermediate_outputs, {})
613627
)
614628

615629
df = inspector_instance.calculate_numeric_gap(distance="L1")

0 commit comments

Comments
 (0)