|
24 | 24 | ETRecord, |
25 | 25 | ETRecordReservedFileNames, |
26 | 26 | ) |
27 | | -from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge |
28 | | -from executorch.exir.program._program import to_edge_transform_and_lower |
| 27 | +from executorch.exir import EdgeCompileConfig, EdgeProgramManager |
| 28 | +from executorch.exir.program._program import to_edge, to_edge_transform_and_lower |
29 | 29 | from torch.export import export |
30 | 30 |
|
31 | 31 |
|
@@ -105,11 +105,13 @@ def assert_etrecord_saveable(self, etrecord: ETRecord) -> None: |
105 | 105 | self.assertIsNotNone(etrecord._debug_handle_map) |
106 | 106 | self.assertIsNotNone(etrecord._delegate_map) |
107 | 107 |
|
108 | | - def get_test_model(self): |
| 108 | + def get_test_model(self, generate_etrecord=False): |
109 | 109 | f = models.BasicSinMax() |
110 | 110 | aten_dialect = export(f, f.get_random_inputs(), strict=True) |
111 | 111 | edge_program: EdgeProgramManager = to_edge( |
112 | | - aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False) |
| 112 | + aten_dialect, |
| 113 | + compile_config=EdgeCompileConfig(_check_ir_validity=False), |
| 114 | + generate_etrecord=generate_etrecord, |
113 | 115 | ) |
114 | 116 | edge_program_copy = copy.deepcopy(edge_program) |
115 | 117 | return (aten_dialect, edge_program_copy, edge_program.to_executorch()) |
@@ -392,6 +394,82 @@ def test_get_etrecord_from_executorch_program_manager_without_generation(self): |
392 | 394 |
|
393 | 395 | self.assertIn("ETRecord was not generated", str(context.exception)) |
394 | 396 |
|
| 397 | + def test_to_edge_with_etrecord_generation(self): |
| 398 | + """Test that to_edge generates ETRecord correctly.""" |
| 399 | + aten_program, edge_manager, _ = self.get_test_model(generate_etrecord=True) |
| 400 | + |
| 401 | + # Verify that ETRecord was generated and attached |
| 402 | + self.assertIsNotNone(edge_manager._etrecord) |
| 403 | + etrecord = edge_manager._etrecord |
| 404 | + self.assert_legal_etrecord_in_edge_program(etrecord) |
| 405 | + |
| 406 | + # Verify the exported program matches the input |
| 407 | + self.check_graph_closeness( |
| 408 | + etrecord.exported_program, |
| 409 | + aten_program.graph_module, |
| 410 | + ) |
| 411 | + self.assertEqual( |
| 412 | + etrecord.export_graph_id, |
| 413 | + id(aten_program.graph), |
| 414 | + ) |
| 415 | + |
| 416 | + # Verify the edge dialect program matches the edge manager |
| 417 | + self.check_graph_closeness( |
| 418 | + etrecord.edge_dialect_program, |
| 419 | + edge_manager.exported_program().graph_module, |
| 420 | + ) |
| 421 | + |
| 422 | + def test_to_edge_without_etrecord_generation(self): |
| 423 | + """Test that to_edge works correctly without ETRecord generation.""" |
| 424 | + # Test with generate_etrecord=False (default) |
| 425 | + _, edge_manager, et_manager = self.get_test_model() |
| 426 | + |
| 427 | + # Verify that no ETRecord was generated |
| 428 | + self.assertIsNone(edge_manager._etrecord) |
| 429 | + |
| 430 | + # Test get_etrecord method should raise RuntimeError |
| 431 | + with self.assertRaises(RuntimeError): |
| 432 | + et_manager.get_etrecord() |
| 433 | + |
| 434 | + def test_to_edge_etrecord_save_and_parse(self): |
| 435 | + """Test that ETRecord generated by to_edge can be saved and parsed.""" |
| 436 | + aten_program, _, et_manager = self.get_test_model(generate_etrecord=True) |
| 437 | + |
| 438 | + etrecord = et_manager.get_etrecord() |
| 439 | + |
| 440 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 441 | + etrecord_path = tmpdirname + "/etrecord_to_edge.bin" |
| 442 | + |
| 443 | + etrecord.save(etrecord_path) |
| 444 | + |
| 445 | + # Parse ETRecord back and verify |
| 446 | + parsed_etrecord = parse_etrecord(etrecord_path) |
| 447 | + |
| 448 | + # Validate that all components are preserved |
| 449 | + # Note: Skip graph structure comparison due to transformation differences |
| 450 | + self.check_graph_closeness( |
| 451 | + etrecord.exported_program, parsed_etrecord.exported_program |
| 452 | + ) |
| 453 | + self.check_graph_closeness( |
| 454 | + etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program |
| 455 | + ) |
| 456 | + |
| 457 | + # Validate executorch program data |
| 458 | + self.assertEqual( |
| 459 | + parsed_etrecord._debug_handle_map, |
| 460 | + json.loads(json.dumps(et_manager.debug_handle_map)), |
| 461 | + ) |
| 462 | + self.assertEqual( |
| 463 | + parsed_etrecord._delegate_map, |
| 464 | + json.loads(json.dumps(et_manager.delegate_map)), |
| 465 | + ) |
| 466 | + |
| 467 | + # Validate export graph id |
| 468 | + self.assertEqual( |
| 469 | + parsed_etrecord.export_graph_id, |
| 470 | + id(aten_program.graph), |
| 471 | + ) |
| 472 | + |
395 | 473 | def test_to_edge_transform_and_lower_etrecord_save_and_parse(self): |
396 | 474 | """Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed.""" |
397 | 475 | f = models.BasicSinMax() |
|
0 commit comments