|
25 | 25 | ETRecord,
|
26 | 26 | ETRecordReservedFileNames,
|
27 | 27 | )
|
28 |
| -from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge |
29 |
| -from executorch.exir.program._program import to_edge_transform_and_lower |
| 28 | +from executorch.exir import EdgeCompileConfig, EdgeProgramManager |
| 29 | +from executorch.exir.program._program import to_edge, to_edge_transform_and_lower |
30 | 30 | from torch.export import export
|
31 | 31 |
|
32 | 32 |
|
@@ -106,11 +106,13 @@ def assert_etrecord_saveable(self, etrecord: ETRecord) -> None:
|
106 | 106 | self.assertIsNotNone(etrecord._debug_handle_map)
|
107 | 107 | self.assertIsNotNone(etrecord._delegate_map)
|
108 | 108 |
|
109 |
| - def get_test_model(self): |
| 109 | + def get_test_model(self, generate_etrecord=False): |
110 | 110 | f = models.BasicSinMax()
|
111 | 111 | aten_dialect = export(f, f.get_random_inputs(), strict=True)
|
112 | 112 | edge_program: EdgeProgramManager = to_edge(
|
113 |
| - aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False) |
| 113 | + aten_dialect, |
| 114 | + compile_config=EdgeCompileConfig(_check_ir_validity=False), |
| 115 | + generate_etrecord=generate_etrecord, |
114 | 116 | )
|
115 | 117 | edge_program_copy = copy.deepcopy(edge_program)
|
116 | 118 | return (aten_dialect, edge_program_copy, edge_program.to_executorch())
|
@@ -428,6 +430,82 @@ def test_get_etrecord_from_executorch_program_manager_without_generation(self):
|
428 | 430 |
|
429 | 431 | self.assertIn("ETRecord was not generated", str(context.exception))
|
430 | 432 |
|
| 433 | + def test_to_edge_with_etrecord_generation(self): |
| 434 | + """Test that to_edge generates ETRecord correctly.""" |
| 435 | + aten_program, edge_manager, _ = self.get_test_model(generate_etrecord=True) |
| 436 | + |
| 437 | + # Verify that ETRecord was generated and attached |
| 438 | + self.assertIsNotNone(edge_manager._etrecord) |
| 439 | + etrecord = edge_manager._etrecord |
| 440 | + self.assert_legal_etrecord_in_edge_program(etrecord) |
| 441 | + |
| 442 | + # Verify the exported program matches the input |
| 443 | + self.check_graph_closeness( |
| 444 | + etrecord.exported_program, |
| 445 | + aten_program.graph_module, |
| 446 | + ) |
| 447 | + self.assertEqual( |
| 448 | + etrecord.export_graph_id, |
| 449 | + id(aten_program.graph), |
| 450 | + ) |
| 451 | + |
| 452 | + # Verify the edge dialect program matches the edge manager |
| 453 | + self.check_graph_closeness( |
| 454 | + etrecord.edge_dialect_program, |
| 455 | + edge_manager.exported_program().graph_module, |
| 456 | + ) |
| 457 | + |
| 458 | + def test_to_edge_without_etrecord_generation(self): |
| 459 | + """Test that to_edge works correctly without ETRecord generation.""" |
| 460 | + # Test with generate_etrecord=False (default) |
| 461 | + _, edge_manager, et_manager = self.get_test_model() |
| 462 | + |
| 463 | + # Verify that no ETRecord was generated |
| 464 | + self.assertIsNone(edge_manager._etrecord) |
| 465 | + |
| 466 | + # Test get_etrecord method should raise RuntimeError |
| 467 | + with self.assertRaises(RuntimeError): |
| 468 | + et_manager.get_etrecord() |
| 469 | + |
| 470 | + def test_to_edge_etrecord_save_and_parse(self): |
| 471 | + """Test that ETRecord generated by to_edge can be saved and parsed.""" |
| 472 | + aten_program, _, et_manager = self.get_test_model(generate_etrecord=True) |
| 473 | + |
| 474 | + etrecord = et_manager.get_etrecord() |
| 475 | + |
| 476 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 477 | + etrecord_path = tmpdirname + "/etrecord_to_edge.bin" |
| 478 | + |
| 479 | + etrecord.save(etrecord_path) |
| 480 | + |
| 481 | + # Parse ETRecord back and verify |
| 482 | + parsed_etrecord = parse_etrecord(etrecord_path) |
| 483 | + |
| 484 | + # Validate that all components are preserved |
| 485 | + # Note: Skip graph structure comparison due to transformation differences |
| 486 | + self.check_graph_closeness( |
| 487 | + etrecord.exported_program, parsed_etrecord.exported_program |
| 488 | + ) |
| 489 | + self.check_graph_closeness( |
| 490 | + etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program |
| 491 | + ) |
| 492 | + |
| 493 | + # Validate executorch program data |
| 494 | + self.assertEqual( |
| 495 | + parsed_etrecord._debug_handle_map, |
| 496 | + json.loads(json.dumps(et_manager.debug_handle_map)), |
| 497 | + ) |
| 498 | + self.assertEqual( |
| 499 | + parsed_etrecord._delegate_map, |
| 500 | + json.loads(json.dumps(et_manager.delegate_map)), |
| 501 | + ) |
| 502 | + |
| 503 | + # Validate export graph id |
| 504 | + self.assertEqual( |
| 505 | + parsed_etrecord.export_graph_id, |
| 506 | + id(aten_program.graph), |
| 507 | + ) |
| 508 | + |
431 | 509 | def test_to_edge_transform_and_lower_etrecord_save_and_parse(self):
|
432 | 510 | """Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed."""
|
433 | 511 | f = models.BasicSinMax()
|
|
0 commit comments