Skip to content

Commit 69ca2dc

Browse files
Juntian Liu authored and hinriksnaer committed
Add functionality to map runtime debug_handles to op names
Differential Revision: D77266536
Pull Request resolved: pytorch#11987
1 parent 0032869 commit 69ca2dc

File tree

2 files changed

+43
-19
lines changed

2 files changed

+43
-19
lines changed

devtools/inspector/_inspector.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,12 +1161,15 @@ def _consume_etrecord(self) -> None:
11611161
)
11621162

11631163
# TODO: Make it more extensible to further merge overlapping debug handles
1164-
def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
1164+
def _get_runtime_intermediate_outputs_and_op_names(
1165+
self,
1166+
) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]:
11651167
"""
1166-
Retrieve the raw runtime intermediate outputs(debug handles and value mappings)
1167-
from the event blocks. These outputs will be processed later to merge overlapping debug handles.
1168+
Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings)
1169+
from the event blocks, along with the corresponding debug handles and op names mapping.
11681170
"""
11691171
debug_handle_to_output = {}
1172+
debug_handle_to_op_name = {}
11701173
for event_block in self.event_blocks:
11711174
for event in event_block.events:
11721175
# Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax
@@ -1175,20 +1178,23 @@ def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
11751178
or not event.op_types
11761179
):
11771180
continue
1178-
# Normalize debug_handles to a tuple
1179-
debug_handles = event.debug_handles
1180-
if isinstance(debug_handles, int):
1181-
debug_handles = (debug_handles,)
1181+
# Normalize debug_handle to a tuple
1182+
debug_handle = event.debug_handles
1183+
if isinstance(debug_handle, int):
1184+
debug_handle = (debug_handle,)
11821185
else:
1183-
debug_handles = tuple(debug_handles)
1184-
current_entry = debug_handle_to_output.get(debug_handles, (-1, None))
1185-
# When event has same debug handles, only keep the one with the largest instruction id
1186+
debug_handle = tuple(debug_handle)
1187+
current_entry = debug_handle_to_output.get(debug_handle, (-1, None))
1188+
# When event has same debug_handle, only keep the one with the largest instruction id
11861189
if event._instruction_id > current_entry[0]:
1187-
debug_handle_to_output[debug_handles] = (
1190+
debug_handle_to_output[debug_handle] = (
11881191
event._instruction_id,
11891192
event.debug_data,
11901193
)
1191-
return {k: v[1] for k, v in debug_handle_to_output.items()}
1194+
debug_handle_to_op_name[debug_handle] = event.name
1195+
return {
1196+
k: v[1] for k, v in debug_handle_to_output.items()
1197+
}, debug_handle_to_op_name
11921198

11931199
def to_dataframe(
11941200
self,
@@ -1364,8 +1370,12 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
13641370
raise ValueError(
13651371
"The aot intermediate outputs is required but not populated."
13661372
)
1373+
# The runtime_op_names will be used later to map runtime debug_handle to op_name
1374+
runtime_intermediate_outputs, runtime_op_names = (
1375+
self._get_runtime_intermediate_outputs_and_op_names()
1376+
)
13671377
mapping = map_runtime_aot_intermediate_outputs(
1368-
self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs()
1378+
self._aot_intermediate_outputs, runtime_intermediate_outputs
13691379
)
13701380
metric = distance.strip().upper()
13711381
if metric == "MSE":

devtools/inspector/tests/inspector_test.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
537537
)
538538
)
539539

540-
def test_get_runtime_intermediate_outputs(self):
540+
def test_get_runtime_intermediate_outputs_and_op_names(self):
541541
# Create a context manager to patch functions called by Inspector.__init__
542542
with patch.object(
543543
_inspector, "parse_etrecord", return_value=None
@@ -560,25 +560,39 @@ def test_get_runtime_intermediate_outputs(self):
560560
EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events())
561561
]
562562

563-
runtime_outputs = inspector_instance._get_runtime_intermediate_outputs()
564-
# This output should be a dictionary with 5 keys
563+
runtime_outputs, op_names = (
564+
inspector_instance._get_runtime_intermediate_outputs_and_op_names()
565+
)
566+
# These outputs and op_names dictionaries should all have 5 keys
565567
self.assertEqual(
566568
len(runtime_outputs),
567569
5,
568570
)
569-
# Check that keys (0,) and (1,) are not in the dictionary(skip OPERATOR_CALL and op_types are empty)
571+
self.assertEqual(
572+
len(op_names),
573+
5,
574+
)
575+
576+
# Check that keys (0,) and (1,) are not in these two dictionaries(skip OPERATOR_CALL and op_types are empty)
570577
self.assertNotIn((0,), runtime_outputs)
571578
self.assertNotIn((1,), runtime_outputs)
579+
self.assertNotIn((0,), op_names)
580+
self.assertNotIn((1,), op_names)
572581

573582
# Same debug_handle but different instruction_id, should record the last one
574583
self.assertIn((4,), runtime_outputs)
584+
self.assertIn((4,), op_names)
575585
self.assertTrue(
576586
torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
577587
)
588+
self.assertEqual(op_names[(4,)], "op_3")
589+
578590
# Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
579591
for key in range(5, 9):
580592
self.assertIn((key,), runtime_outputs)
593+
self.assertIn((key,), op_names)
581594
self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
595+
self.assertEqual(op_names[(key,)], f"op_{key-1}")
582596

583597
def test_calculate_numeric_gap(self):
584598
# Create a context manager to patch functions called by Inspector.__init__
@@ -608,8 +622,8 @@ def test_calculate_numeric_gap(self):
608622
}
609623

610624
inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs
611-
inspector_instance._get_runtime_intermediate_outputs = (
612-
lambda: runtime_intermediate_outputs
625+
inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
626+
lambda: (runtime_intermediate_outputs, {})
613627
)
614628

615629
df = inspector_instance.calculate_numeric_gap(distance="L1")

0 commit comments

Comments
 (0)