@@ -68,7 +68,7 @@ def __init__(
6868 Dict [str , Dict [int , Dict [str , Union [str , _DelegateDebugIdentifierMap ]]]]
6969 ] = None ,
7070 _reference_outputs : Optional [Dict [str , List [ProgramOutput ]]] = None ,
71- _representative_inputs : Optional [List [ProgramOutput ]] = None ,
71+ _representative_inputs : Optional [List [ProgramInput ]] = None ,
7272 ):
7373 self .exported_program = exported_program
7474 self .export_graph_id = export_graph_id
@@ -345,6 +345,56 @@ def add_edge_dialect_program(
345345 # Set the extracted data
346346 self .edge_dialect_program = processed_edge_dialect_program
347347
348+ def update_representative_inputs (
349+ self ,
350+ representative_inputs : Union [List [ProgramInput ], BundledProgram ],
351+ ) -> None :
352+ """
353+ Update the representative inputs in the ETRecord.
354+
355+ This method allows users to customize the representative inputs that will be
356+ included when the ETRecord is saved. The representative inputs can be provided
357+ directly as a list or extracted from a BundledProgram.
358+
359+ Args:
360+ representative_inputs: Either a list of ProgramInput objects or a BundledProgram
361+ from which representative inputs will be extracted.
362+ """
363+ if isinstance (representative_inputs , BundledProgram ):
364+ self ._representative_inputs = _get_representative_inputs (
365+ representative_inputs
366+ )
367+ else :
368+ self ._representative_inputs = representative_inputs
369+
370+ def update_reference_outputs (
371+ self ,
372+ reference_outputs : Union [
373+ Dict [str , List [ProgramOutput ]], List [ProgramOutput ], BundledProgram
374+ ],
375+ ) -> None :
376+ """
377+ Update the reference outputs in the ETRecord.
378+
379+ This method allows users to customize the reference outputs that will be
380+ included when the ETRecord is saved. The reference outputs can be provided
381+ directly as a dictionary mapping method names to lists of outputs, as a
382+ single list of outputs (which will be treated as {"forward": List[ProgramOutput]}),
383+ or extracted from a BundledProgram.
384+
385+ Args:
386+ reference_outputs: Either a dictionary mapping method names to lists of
387+ ProgramOutput objects, a single list of ProgramOutput objects (treated
388+ as outputs for the "forward" method), or a BundledProgram from which
389+ reference outputs will be extracted.
390+ """
391+ if isinstance (reference_outputs , BundledProgram ):
392+ self ._reference_outputs = _get_reference_outputs (reference_outputs )
393+ elif isinstance (reference_outputs , list ):
394+ self ._reference_outputs = {"forward" : reference_outputs }
395+ else :
396+ self ._reference_outputs = reference_outputs
397+
348398
349399def _get_reference_outputs (
350400 bundled_program : BundledProgram ,
0 commit comments