2929from executorch .exir .serde .export_serialize import SerializedArtifact
3030from executorch .exir .serde .serialize import deserialize , serialize
3131
32+ ProgramInput = List [Value ]
3233ProgramOutput = List [Value ]
3334
3435try :
@@ -49,6 +50,7 @@ class ETRecordReservedFileNames(StrEnum):
4950 DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
5051 DELEGATE_MAP_NAME = "delegate_map"
5152 REFERENCE_OUTPUTS = "reference_outputs"
53+ REPRESENTATIVE_INPUTS = "representative_inputs"
5254
5355
5456@dataclass
@@ -60,6 +62,7 @@ class ETRecord:
6062 Dict [str , Dict [int , Dict [str , Union [str , _DelegateDebugIdentifierMap ]]]]
6163 ] = None
6264 _reference_outputs : Optional [Dict [str , List [ProgramOutput ]]] = None
65+ _representative_inputs : Optional [List [ProgramOutput ]] = None
6366
6467
6568def _handle_exported_program (
@@ -157,6 +160,24 @@ def _get_reference_outputs(
157160 return reference_outputs
158161
159162
163+ def _get_representative_inputs (
164+ bundled_program : BundledProgram ,
165+ ) -> List [ProgramInput ]:
166+ """
167+ Extracts out the inputs from the bundled program, keyed by the method names.
168+ """
169+ for method_test_suite in bundled_program .method_test_suites :
170+ if method_test_suite .method_name == "forward" :
171+ if not method_test_suite .test_cases :
172+ raise ValueError (
173+ "The 'forward' method is defined, but no corresponding input test cases are provided."
174+ )
175+ # Get first example input from the forward method
176+ test_case = method_test_suite .test_cases [0 ]
177+ return test_case .inputs
178+ raise ValueError ("No 'forward' method found in the bundled program." )
179+
180+
160181def generate_etrecord (
161182 et_record : Union [str , os .PathLike , BinaryIO , IO [bytes ]],
162183 edge_dialect_program : Union [EdgeProgramManager , ExirExportedProgram ],
@@ -244,8 +265,17 @@ def generate_etrecord(
244265 # @lint-ignore PYTHONPICKLEISBAD
245266 pickle .dumps (reference_outputs ),
246267 )
268+
269+ representative_inputs = _get_representative_inputs (executorch_program )
270+ etrecord_zip .writestr (
271+ ETRecordReservedFileNames .REPRESENTATIVE_INPUTS ,
272+ # @lint-ignore PYTHONPICKLEISBAD
273+ pickle .dumps (representative_inputs ),
274+ )
247275 executorch_program = executorch_program .executorch_program
248276
277+
278+
249279 etrecord_zip .writestr (
250280 ETRecordReservedFileNames .DEBUG_HANDLE_MAP_NAME ,
251281 json .dumps (executorch_program .debug_handle_map ),
@@ -290,6 +320,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
290320 delegate_map = None
291321 edge_dialect_program = None
292322 reference_outputs = None
323+ representative_inputs = None
293324
294325 serialized_exported_program_files = set ()
295326 serialized_state_dict_files = set ()
@@ -321,6 +352,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
321352 reference_outputs = pickle .loads (
322353 etrecord_zip .read (ETRecordReservedFileNames .REFERENCE_OUTPUTS )
323354 )
355+ elif entry == ETRecordReservedFileNames .REPRESENTATIVE_INPUTS :
356+ # @lint-ignore PYTHONPICKLEISBAD
357+ representative_inputs = pickle .loads (
358+ etrecord_zip .read (ETRecordReservedFileNames .REPRESENTATIVE_INPUTS )
359+ )
324360 else :
325361 if entry .endswith ("state_dict" ):
326362 serialized_state_dict_files .add (entry )
@@ -352,4 +388,5 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
352388 _debug_handle_map = debug_handle_map ,
353389 _delegate_map = delegate_map ,
354390 _reference_outputs = reference_outputs ,
391+ _representative_inputs = representative_inputs ,
355392 )
0 commit comments