Skip to content

Commit 09a2e88

Browse files
authored
Add support for checking more than one output from delegate in numerical comparator
Differential Revision: D80272038 Pull Request resolved: #13722
1 parent d5a5164 commit 09a2e88

File tree

10 files changed

+370
-106
lines changed

10 files changed

+370
-106
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class ETRecordReservedFileNames(StrEnum):
5252
ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module"
5353
DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
5454
DELEGATE_MAP_NAME = "delegate_map"
55+
INSTRUCTION_ID_TO_NUM_OUTS_MAP_NAME = "instruction_id_to_num_outs_map"
5556
REFERENCE_OUTPUTS = "reference_outputs"
5657
REPRESENTATIVE_INPUTS = "representative_inputs"
5758

@@ -67,6 +68,9 @@ def __init__(
6768
_delegate_map: Optional[
6869
Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
6970
] = None,
71+
_instruction_id_to_num_outs_map: Optional[
72+
Dict[str, Dict[int, Union[int, List[int]]]]
73+
] = None,
7074
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None,
7175
_representative_inputs: Optional[List[ProgramInput]] = None,
7276
):
@@ -92,6 +96,7 @@ def __init__(
9296
self.graph_map = graph_map
9397
self._debug_handle_map = _debug_handle_map
9498
self._delegate_map = _delegate_map
99+
self._instruction_id_to_num_outs_map = _instruction_id_to_num_outs_map
95100
self._reference_outputs = _reference_outputs
96101
self._representative_inputs = _representative_inputs
97102

@@ -172,6 +177,12 @@ def _save_metadata(self, etrecord_zip: ZipFile) -> None:
172177
json.dumps(self._delegate_map),
173178
)
174179

180+
if self._instruction_id_to_num_outs_map is not None:
181+
etrecord_zip.writestr(
182+
ETRecordReservedFileNames.INSTRUCTION_ID_TO_NUM_OUTS_MAP_NAME,
183+
json.dumps(self._instruction_id_to_num_outs_map),
184+
)
185+
175186
if self._reference_outputs is not None:
176187
etrecord_zip.writestr(
177188
ETRecordReservedFileNames.REFERENCE_OUTPUTS,
@@ -284,6 +295,7 @@ def add_executorch_program(
284295
if (
285296
self._debug_handle_map is not None
286297
or self._delegate_map is not None
298+
or self._instruction_id_to_num_outs_map is not None
287299
or self._reference_outputs is not None
288300
or self._representative_inputs is not None
289301
):
@@ -293,13 +305,18 @@ def add_executorch_program(
293305
)
294306

295307
# Process executorch program and extract data
296-
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
297-
_process_executorch_program(executorch_program)
298-
)
308+
(
309+
debug_handle_map,
310+
delegate_map,
311+
instruction_id_to_num_outs_map,
312+
reference_outputs,
313+
representative_inputs,
314+
) = _process_executorch_program(executorch_program)
299315

300316
# Set the extracted data
301317
self._debug_handle_map = debug_handle_map
302318
self._delegate_map = delegate_map
319+
self._instruction_id_to_num_outs_map = instruction_id_to_num_outs_map
303320
self._reference_outputs = reference_outputs
304321
self._representative_inputs = representative_inputs
305322

@@ -593,7 +610,9 @@ def _process_executorch_program(
593610
executorch_program: Union[
594611
ExecutorchProgram, ExecutorchProgramManager, BundledProgram
595612
]
596-
) -> tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[List]]:
613+
) -> tuple[
614+
Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict], Optional[List]
615+
]:
597616
"""Process executorch program and return debug maps and bundled program data."""
598617
if isinstance(executorch_program, BundledProgram):
599618
reference_outputs = _get_reference_outputs(executorch_program)
@@ -602,11 +621,30 @@ def _process_executorch_program(
602621
debug_handle_map = executorch_program.executorch_program.debug_handle_map
603622
# pyre-ignore[16]: Item `None` of `typing.Union[None, exir.program._program.ExecutorchProgram, exir.program._program.ExecutorchProgramManager]` has no attribute `debug_handle_map`
604623
delegate_map = executorch_program.executorch_program.delegate_map
605-
return debug_handle_map, delegate_map, reference_outputs, representative_inputs
624+
# pyre-ignore[16]: Item `None` of `typing.Union[None, exir.program._program.ExecutorchProgram, exir.program._program.ExecutorchProgramManager]` has no attribute `instruction_id_to_num_outs_map`
625+
instruction_id_to_num_outs_map = (
626+
executorch_program.executorch_program.instruction_id_to_num_outs_map
627+
)
628+
return (
629+
debug_handle_map,
630+
delegate_map,
631+
instruction_id_to_num_outs_map,
632+
reference_outputs,
633+
representative_inputs,
634+
)
606635
else:
607636
debug_handle_map = executorch_program.debug_handle_map
608637
delegate_map = executorch_program.delegate_map
609-
return debug_handle_map, delegate_map, None, None
638+
instruction_id_to_num_outs_map = (
639+
executorch_program.instruction_id_to_num_outs_map
640+
)
641+
return (
642+
debug_handle_map,
643+
delegate_map,
644+
instruction_id_to_num_outs_map,
645+
None,
646+
None,
647+
)
610648

611649

612650
def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
@@ -640,6 +678,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
640678
graph_map: Dict[str, ExportedProgram] = {}
641679
debug_handle_map = None
642680
delegate_map = None
681+
instruction_id_to_num_outs_map = None
643682
exported_program = None
644683
edge_dialect_program = None
645684
reference_outputs = None
@@ -659,6 +698,12 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
659698
delegate_map = json.loads(
660699
etrecord_zip.read(ETRecordReservedFileNames.DELEGATE_MAP_NAME)
661700
)
701+
elif entry == ETRecordReservedFileNames.INSTRUCTION_ID_TO_NUM_OUTS_MAP_NAME:
702+
instruction_id_to_num_outs_map = json.loads(
703+
etrecord_zip.read(
704+
ETRecordReservedFileNames.INSTRUCTION_ID_TO_NUM_OUTS_MAP_NAME
705+
)
706+
)
662707
elif entry == ETRecordReservedFileNames.ETRECORD_IDENTIFIER:
663708
continue
664709
elif entry == ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM:
@@ -724,6 +769,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
724769
graph_map=graph_map,
725770
_debug_handle_map=debug_handle_map,
726771
_delegate_map=delegate_map,
772+
_instruction_id_to_num_outs_map=instruction_id_to_num_outs_map,
727773
_reference_outputs=reference_outputs,
728774
_representative_inputs=representative_inputs,
729775
export_graph_id=export_graph_id,

devtools/etrecord/tests/etrecord_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,10 @@ def test_etrecord_generation(self):
219219
etrecord._debug_handle_map,
220220
json.loads(json.dumps(et_output.debug_handle_map)),
221221
)
222+
self.assertEqual(
223+
etrecord._instruction_id_to_num_outs_map,
224+
json.loads(json.dumps(et_output.instruction_id_to_num_outs_map)),
225+
)
222226

223227
def test_etrecord_generation_with_bundled_program(self):
224228
(

devtools/inspector/_inspector.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,8 @@ class Event:
317317
op_type: List of op types corresponding to the event.
318318
delegate_debug_identifier: Supplemental identifier used in combination with instruction id.
319319
debug_handles: Debug handles in the model graph to which this event is correlated.
320+
num_outputs: Indicates the number of outputs generated by the node.
321+
Right now only used for call_delegate nodes that output more than one tensor.
320322
stack_trace: A dictionary mapping the name of each associated op to its stack trace.
321323
module_hierarchy: A dictionary mapping the name of each associated op to its module hierarchy.
322324
is_delegated_op: Whether or not the event was delegated.
@@ -337,6 +339,7 @@ class Event:
337339
op_types: List[str] = dataclasses.field(default_factory=list)
338340
delegate_debug_identifier: Optional[Union[int, str]] = None
339341
debug_handles: Optional[Union[int, Sequence[int]]] = None
342+
num_outputs: int = 1
340343
stack_traces: Dict[str, str] = dataclasses.field(default_factory=dict)
341344
module_hierarchy: Dict[str, Dict] = dataclasses.field(default_factory=dict)
342345
is_delegated_op: Optional[bool] = None
@@ -928,6 +931,7 @@ def _gen_resolve_debug_handles(
928931
self,
929932
handle_map: Dict[str, List[int]],
930933
delegate_map: Optional[Dict[str, DelegateMetadata]] = None,
934+
instruction_id_to_num_outs_map: Dict[int, int] = None,
931935
):
932936
"""
933937
Given mappings from instruction id to debug handles, populate the
@@ -945,6 +949,10 @@ def _gen_resolve_debug_handles(
945949
if (instruction_id := str(event._instruction_id)) not in handle_map:
946950
continue
947951

952+
num_outputs = 1
953+
if instruction_id_to_num_outs_map is not None:
954+
num_outputs = instruction_id_to_num_outs_map.get(instruction_id, 1)
955+
event.num_outputs = num_outputs
948956
# For non-delegated event, handles are found in handle_map
949957
if (delegate_debug_id := event.delegate_debug_identifier) is None:
950958
event.debug_handles = handle_map[instruction_id]
@@ -1131,6 +1139,7 @@ def _consume_etrecord(self) -> None:
11311139
if self._etrecord._delegate_map is not None
11321140
else None
11331141
),
1142+
self._etrecord._instruction_id_to_num_outs_map[FORWARD],
11341143
)
11351144

11361145
# (2) Event Metadata Association
@@ -1196,7 +1205,7 @@ def _get_aot_intermediate_outputs_and_op_names(
11961205
# TODO: Make it more extensible to further merge overlapping debug handles
11971206
def _get_runtime_intermediate_outputs_and_op_names(
11981207
self,
1199-
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
1208+
) -> Tuple[Dict[DebugHandle, Tuple[Any, int]], Dict[DebugHandle, List[str]]]:
12001209
"""
12011210
Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings)
12021211
from the event blocks, along with the corresponding debug handles and op names mapping.
@@ -1217,12 +1226,15 @@ def _get_runtime_intermediate_outputs_and_op_names(
12171226
debug_handle = (debug_handle,)
12181227
else:
12191228
debug_handle = tuple(debug_handle)
1220-
current_entry = debug_handle_to_output.get(debug_handle, (-1, None))
1229+
current_entry = debug_handle_to_output.get(
1230+
debug_handle, (-1, None, event.num_outputs)
1231+
)
12211232
# When event has same debug_handle, only keep the one with the largest instruction id
12221233
if event._instruction_id > current_entry[0]:
12231234
debug_handle_to_output[debug_handle] = (
12241235
event._instruction_id,
12251236
event.debug_data,
1237+
event.num_outputs,
12261238
)
12271239
# TODO: One debug handle can be associated with multiple op names
12281240
debug_handle_to_op_names[debug_handle] = [event.name]
@@ -1231,7 +1243,7 @@ def _get_runtime_intermediate_outputs_and_op_names(
12311243
debug_handle_to_output
12321244
)
12331245
return {
1234-
k: v[1] for k, v in debug_handle_to_output.items()
1246+
k: (v[1], v[2]) for k, v in debug_handle_to_output.items()
12351247
}, debug_handle_to_op_names
12361248

12371249
def to_dataframe(

0 commit comments

Comments
 (0)