|
27 | 27 | )
|
28 | 28 | from executorch.exir import EdgeCompileConfig, EdgeProgramManager
|
29 | 29 | from executorch.exir.program._program import to_edge, to_edge_transform_and_lower
|
| 30 | + |
| 31 | +from executorch.export import export as etexport, ExportRecipe, StageType |
30 | 32 | from torch.export import export
|
31 | 33 |
|
32 | 34 |
|
@@ -137,6 +139,33 @@ def get_test_model_with_bundled_program(self):
|
137 | 139 | bundled_program = BundledProgram(et_output, method_test_suites)
|
138 | 140 | return (aten_dialect, edge_program_copy, bundled_program)
|
139 | 141 |
|
| 142 | + def get_test_export_session(self, generate_etrecord=False, to_edge_flow=False): |
| 143 | + f = models.BasicSinMax() |
| 144 | + example_inputs = [f.get_random_inputs()] |
| 145 | + export_recipe = None |
| 146 | + |
| 147 | + if to_edge_flow: |
| 148 | + export_recipe = ExportRecipe( |
| 149 | + pipeline_stages=[ |
| 150 | + StageType.TORCH_EXPORT, |
| 151 | + StageType.TO_EDGE, |
| 152 | + StageType.TO_BACKEND, |
| 153 | + StageType.TO_EXECUTORCH, |
| 154 | + ] |
| 155 | + ) |
| 156 | + else: |
| 157 | + export_recipe = ExportRecipe() |
| 158 | + |
| 159 | + # Test with generate_etrecord=True |
| 160 | + export_session = etexport( |
| 161 | + model=f, |
| 162 | + example_inputs=example_inputs, |
| 163 | + export_recipe=export_recipe, |
| 164 | + generate_etrecord=generate_etrecord, |
| 165 | + ) |
| 166 | + |
| 167 | + return export_session |
| 168 | + |
140 | 169 | # Serialized and deserialized graph modules are not completely the same, so we check
|
141 | 170 | # that they are close enough and match especially on the parameters we care about in the Developer Tools.
|
142 | 171 | def check_graph_closeness(self, graph_a, graph_b):
|
@@ -1298,6 +1327,113 @@ def test_add_all_programs_sequentially(self):
|
1298 | 1327 | json.loads(json.dumps(et_output.delegate_map)),
|
1299 | 1328 | )
|
1300 | 1329 |
|
| 1330 | + def test_executorch_export_with_etrecord_generation(self): |
| 1331 | + """Test that executorch.export generates ETRecord correctly when generate_etrecord=True.""" |
| 1332 | + # Verify that ETRecord was generated and can be retrieved |
| 1333 | + export_session = self.get_test_export_session(generate_etrecord=True) |
| 1334 | + etrecord = export_session.get_etrecord() |
| 1335 | + self.assertIsNotNone(etrecord) |
| 1336 | + self.assert_etrecord_saveable(etrecord) |
| 1337 | + |
| 1338 | + # Verify the executorch program data matches |
| 1339 | + et_manager = export_session.get_executorch_program_manager() |
| 1340 | + self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map) |
| 1341 | + self.assertEqual(etrecord._delegate_map, et_manager.delegate_map) |
| 1342 | + |
| 1343 | + def test_executorch_export_without_etrecord_generation(self): |
| 1344 | + """Test that executorch.export works correctly without ETRecord generation.""" |
| 1345 | + # Test with generate_etrecord=False (default) |
| 1346 | + export_session = self.get_test_export_session(generate_etrecord=False) |
| 1347 | + |
| 1348 | + # Verify that no ETRecord was generated |
| 1349 | + with self.assertRaises(RuntimeError) as context: |
| 1350 | + export_session.get_etrecord() |
| 1351 | + |
| 1352 | + self.assertIn("ETRecord was not generated", str(context.exception)) |
| 1353 | + |
| 1354 | + # Verify that the export session still works correctly |
| 1355 | + self.assertIsNotNone(export_session.get_executorch_program_manager()) |
| 1356 | + self.assertTrue(len(export_session.get_pte_buffer()) > 0) |
| 1357 | + |
| 1358 | + def test_executorch_export_etrecord_save_and_parse(self): |
| 1359 | + """Test that ETRecord generated by executorch.export can be saved and parsed.""" |
| 1360 | + export_session = self.get_test_export_session(generate_etrecord=True) |
| 1361 | + |
| 1362 | + etrecord = export_session.get_etrecord() |
| 1363 | + |
| 1364 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 1365 | + etrecord_path = tmpdirname + "/etrecord_export.bin" |
| 1366 | + |
| 1367 | + etrecord.save(etrecord_path) |
| 1368 | + |
| 1369 | + # Parse ETRecord back and verify |
| 1370 | + parsed_etrecord = parse_etrecord(etrecord_path) |
| 1371 | + |
| 1372 | + # Validate that all components are preserved |
| 1373 | + self.assertIsNotNone(parsed_etrecord.exported_program) |
| 1374 | + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) |
| 1375 | + |
| 1376 | + # Validate executorch program data |
| 1377 | + et_manager = export_session.get_executorch_program_manager() |
| 1378 | + self.assertEqual( |
| 1379 | + parsed_etrecord._debug_handle_map, |
| 1380 | + json.loads(json.dumps(et_manager.debug_handle_map)), |
| 1381 | + ) |
| 1382 | + self.assertEqual( |
| 1383 | + parsed_etrecord._delegate_map, |
| 1384 | + json.loads(json.dumps(et_manager.delegate_map)), |
| 1385 | + ) |
| 1386 | + |
| 1387 | + # Validate export graph id is preserved |
| 1388 | + self.assertIsNotNone(parsed_etrecord.export_graph_id) |
| 1389 | + |
| 1390 | + def test_executorch_export_with_to_edge_flow(self): |
| 1391 | + """Test executorch.export with TO_EDGE flow and ETRecord generation.""" |
| 1392 | + export_session = self.get_test_export_session( |
| 1393 | + generate_etrecord=True, |
| 1394 | + to_edge_flow=True, |
| 1395 | + ) |
| 1396 | + |
| 1397 | + # Verify that ETRecord was generated |
| 1398 | + etrecord = export_session.get_etrecord() |
| 1399 | + self.assertIsNotNone(etrecord) |
| 1400 | + self.assert_etrecord_saveable(etrecord) |
| 1401 | + |
| 1402 | + def test_executorch_export_etrecord_with_to_edge_flow_save_and_parse(self): |
| 1403 | + """Test that ETRecord generated by executorch.export can be saved and parsed.""" |
| 1404 | + export_session = self.get_test_export_session( |
| 1405 | + generate_etrecord=True, |
| 1406 | + to_edge_flow=True, |
| 1407 | + ) |
| 1408 | + |
| 1409 | + etrecord = export_session.get_etrecord() |
| 1410 | + |
| 1411 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 1412 | + etrecord_path = tmpdirname + "/etrecord_export.bin" |
| 1413 | + |
| 1414 | + etrecord.save(etrecord_path) |
| 1415 | + |
| 1416 | + # Parse ETRecord back and verify |
| 1417 | + parsed_etrecord = parse_etrecord(etrecord_path) |
| 1418 | + |
| 1419 | + # Validate that all components are preserved |
| 1420 | + self.assertIsNotNone(parsed_etrecord.exported_program) |
| 1421 | + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) |
| 1422 | + |
| 1423 | + # Validate executorch program data |
| 1424 | + et_manager = export_session.get_executorch_program_manager() |
| 1425 | + self.assertEqual( |
| 1426 | + parsed_etrecord._debug_handle_map, |
| 1427 | + json.loads(json.dumps(et_manager.debug_handle_map)), |
| 1428 | + ) |
| 1429 | + self.assertEqual( |
| 1430 | + parsed_etrecord._delegate_map, |
| 1431 | + json.loads(json.dumps(et_manager.delegate_map)), |
| 1432 | + ) |
| 1433 | + |
| 1434 | + # Validate export graph id is preserved |
| 1435 | + self.assertIsNotNone(parsed_etrecord.export_graph_id) |
| 1436 | + |
1301 | 1437 | def test_update_representative_inputs_with_list(self):
|
1302 | 1438 | """Test update_representative_inputs with a list of ProgramInput objects."""
|
1303 | 1439 | captured_output, edge_output, et_output = self.get_test_model()
|
|
0 commit comments