Skip to content

Commit 4479a5f

Browse files
committed
make etrecord set representive IO
representive input and reference output in etrecord will not be set during export flow. To continue supporting the two functionalities, this diff creates two class methods to customize IO in etrecord. Differential Revision: [D79386896](https://our.internmc.facebook.com/intern/diff/D79386896/) [ghstack-poisoned]
1 parent 4b786bd commit 4479a5f

File tree

2 files changed

+370
-5
lines changed

2 files changed

+370
-5
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
6969
] = None,
7070
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None,
71-
_representative_inputs: Optional[List[ProgramOutput]] = None,
71+
_representative_inputs: Optional[List[ProgramInput]] = None,
7272
):
7373
self.exported_program = exported_program
7474
self.export_graph_id = export_graph_id
@@ -345,6 +345,56 @@ def add_edge_dialect_program(
345345
# Set the extracted data
346346
self.edge_dialect_program = processed_edge_dialect_program
347347

348+
def update_representative_inputs(
349+
self,
350+
representative_inputs: Union[List[ProgramInput], BundledProgram],
351+
) -> None:
352+
"""
353+
Update the representative inputs in the ETRecord.
354+
355+
This method allows users to customize the representative inputs that will be
356+
included when the ETRecord is saved. The representative inputs can be provided
357+
directly as a list or extracted from a BundledProgram.
358+
359+
Args:
360+
representative_inputs: Either a list of ProgramInput objects or a BundledProgram
361+
from which representative inputs will be extracted.
362+
"""
363+
if isinstance(representative_inputs, BundledProgram):
364+
self._representative_inputs = _get_representative_inputs(
365+
representative_inputs
366+
)
367+
else:
368+
self._representative_inputs = representative_inputs
369+
370+
def update_reference_outputs(
371+
self,
372+
reference_outputs: Union[
373+
Dict[str, List[ProgramOutput]], List[ProgramOutput], BundledProgram
374+
],
375+
) -> None:
376+
"""
377+
Update the reference outputs in the ETRecord.
378+
379+
This method allows users to customize the reference outputs that will be
380+
included when the ETRecord is saved. The reference outputs can be provided
381+
directly as a dictionary mapping method names to lists of outputs, as a
382+
single list of outputs (which will be treated as {"forward": List[ProgramOutput]}),
383+
or extracted from a BundledProgram.
384+
385+
Args:
386+
reference_outputs: Either a dictionary mapping method names to lists of
387+
ProgramOutput objects, a single list of ProgramOutput objects (treated
388+
as outputs for the "forward" method), or a BundledProgram from which
389+
reference outputs will be extracted.
390+
"""
391+
if isinstance(reference_outputs, BundledProgram):
392+
self._reference_outputs = _get_reference_outputs(reference_outputs)
393+
elif isinstance(reference_outputs, list):
394+
self._reference_outputs = {"forward": reference_outputs}
395+
else:
396+
self._reference_outputs = reference_outputs
397+
348398

349399
def _get_reference_outputs(
350400
bundled_program: BundledProgram,

0 commit comments

Comments
 (0)