10 | 10 | import unittest |
11 | 11 | from typing import Dict, Tuple |
12 | 12 |
13 | | -import torch |
| 13 | +import executorch.exir.tests.models as models |
14 | 14 |
| 15 | +import torch |
15 | 16 | from executorch.devtools import generate_etrecord, parse_etrecord |
16 | 17 |
17 | 18 | from executorch.devtools.debug_format.base_schema import ( |
41 | 42 | map_runtime_aot_intermediate_outputs, |
42 | 43 | merge_runtime_overlapping_debug_handles, |
43 | 44 | NodeFilter, |
| 45 | + propagate_back_debug_handle, |
44 | 46 | TimeScale, |
45 | 47 | ) |
46 | 48 | from executorch.devtools.inspector.numerical_comparator import L1Comparator |
| 49 | +from executorch.exir import to_edge |
| 50 | +from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY |
| 51 | +from torch.export import export |
47 | 52 |
48 | 53 |
49 | 54 | class TestInspectorUtils(unittest.TestCase): |
@@ -583,6 +588,113 @@ def test_compare_intermediate_outputs_sequence_and_non_sequence(self): |
583 | 588 | with self.assertRaises(ValueError): |
584 | 589 | compare_intermediate_outputs(a, b, L1Comparator()) |
585 | 590 |
| 591 | + def test_equip_debug_handle_to_export_program_success(self): |
| 592 | + """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" |
| 593 | + # Create a test model |
| 594 | + model = models.FeedForwardBlock(5, 10) |
| 595 | + inputs = (torch.rand(5, 5),) |
| 596 | + |
| 597 | + # Export the model |
| 598 | + exported_program = export(model, inputs) |
| 599 | + export_graph_id = id(exported_program.graph) |
| 600 | + |
| 601 | + # Convert to edge dialect |
| 602 | + edge_dialect_program = to_edge(exported_program).exported_program() |
| 603 | + |
| 604 | + # Call propagate_back_debug_handle |
| 605 | + result = propagate_back_debug_handle( |
| 606 | + exported_program, export_graph_id, edge_dialect_program |
| 607 | + ) |
| 608 | + |
| 609 | + self.assertTrue(result) |
| 610 | + |
| 611 | + # Check that debug handles are properly equipped in the exported program |
| 612 | + exported_program_debug_handles = [] |
| 613 | + for node in exported_program.graph.nodes: |
| 614 | + if node.op not in ("placeholder", "output"): |
| 615 | + self.assertIn(DEBUG_HANDLE_KEY, node.meta) |
| 616 | + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) |
| 617 | + exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) |
| 618 | + |
| 619 | + edge_dialect_program_debug_handles = [] |
| 620 | + for node in edge_dialect_program.graph.nodes: |
| 621 | + if node.op not in ("placeholder", "output"): |
| 622 | + self.assertIn(DEBUG_HANDLE_KEY, node.meta) |
| 623 | + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) |
| 624 | + edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) |
| 625 | + |
| 626 | + # The 0th op in the exported program (layer_norm) is decomposed into the 0th and 1st ops in the edge dialect graph (native_layer_norm and getitem),
| 627 | + # so all three should share the same debug handle
| 628 | + self.assertEqual( |
| 629 | + exported_program_debug_handles[0], edge_dialect_program_debug_handles[0] |
| 630 | + ) |
| 631 | + self.assertEqual( |
| 632 | + exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] |
| 633 | + ) |
| 634 | + |
| 635 | + def test_equip_debug_handle_to_export_program_failure(self): |
| 636 | + """Test that propagate_back_debug_handle returns False when there's a mismatch.""" |
| 637 | + # Create a test model |
| 638 | + model = models.FeedForwardBlock(5, 10) |
| 639 | + inputs = (torch.rand(5, 5),) |
| 640 | + |
| 641 | + exported_program = export(model, inputs) |
| 642 | + edge_dialect_program = to_edge(exported_program).exported_program() |
| 643 | + |
| 644 | + # Create a different exported program (reexport) to cause mismatch |
| 645 | + reexported_program = export(model, inputs) |
| 646 | + reexport_graph_id = id(reexported_program.graph) |
| 647 | + |
| 648 | + # Call propagate_back_debug_handle with mismatched programs |
| 649 | + # This should return False: edge_dialect_program was lowered from the original export, so its provenance does not match the reexported graph
| 650 | + result = propagate_back_debug_handle( |
| 651 | + reexported_program, reexport_graph_id, edge_dialect_program |
| 652 | + ) |
| 653 | + |
| 654 | + # Check that it returns False due to mismatch |
| 655 | + self.assertFalse(result) |
| 656 | + |
| 657 | + def test_equip_debug_handle_to_export_program_op_to_be_removed_in_to_edge(self): |
| 658 | + """Test that propagate_back_debug_handle returns True and properly equips debug handles when an op is removed during to_edge."""
| 659 | + |
| 660 | + class M(torch.nn.Module): |
| 661 | + """ |
| 662 | + Simple model with ops that will be removed in to_edge |
| 663 | + """ |
| 664 | + |
| 665 | + def __init__(self) -> None: |
| 666 | + super().__init__() |
| 667 | + |
| 668 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 669 | + x = x + 1 |
| 670 | + x = x.to(x.dtype) |
| 671 | + x = x + 1 |
| 672 | + return x |
| 673 | + |
| 674 | + inputs = (torch.rand(5, 5),) |
| 675 | + exported_program = torch.export.export(M(), inputs) |
| 676 | + export_graph_id = id(exported_program.graph) |
| 677 | + edge_dialect_program = to_edge(exported_program).exported_program() |
| 678 | + |
| 679 | + self.assertTrue( |
| 680 | + propagate_back_debug_handle( |
| 681 | + exported_program, export_graph_id, edge_dialect_program |
| 682 | + ) |
| 683 | + ) |
| 684 | + |
| 685 | + # Only the two add ops in the exported program are kept in the edge dialect program, so the removed op (the no-op dtype cast) gets debug handle 3
| 686 | + debug_handle_for_removed_node = 3 |
| 687 | + |
| 688 | + for node in exported_program.graph.nodes: |
| 689 | + if node.name == "add": |
| 690 | + self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 1) |
| 691 | + elif node.name == "add_1": |
| 692 | + self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2) |
| 693 | + elif node.op not in ("placeholder", "output"): |
| 694 | + self.assertEqual( |
| 695 | + node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node |
| 696 | + ) |
| 697 | + |
586 | 698 |
587 | 699 | def gen_mock_operator_graph_with_expected_map() -> ( |
588 | 700 | Tuple[OperatorGraph, Dict[int, OperatorNode]] |
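
For context, here is a minimal sketch of the flow these tests exercise, based only on the calls visible in the diff: export a model, capture the graph id before `to_edge`, lower to the edge dialect, then call `propagate_back_debug_handle` to copy the edge-dialect debug handles back onto the original export graph. The import path for `propagate_back_debug_handle` is an assumption; the diff adds it to the same import block as `NodeFilter` and `TimeScale`.

```python
import torch
from torch.export import export

from executorch.exir import to_edge
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY

# Assumed module path: the diff imports propagate_back_debug_handle alongside
# NodeFilter and TimeScale from the inspector utilities.
from executorch.devtools.inspector._inspector_utils import propagate_back_debug_handle


class TwoAdds(torch.nn.Module):
    """Toy model mirroring M from the third test (the dtype cast is a no-op)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + 1
        x = x.to(x.dtype)  # removed during to_edge
        return x + 1


exported_program = export(TwoAdds(), (torch.rand(5, 5),))
# Capture the graph id before to_edge so provenance can be matched later.
export_graph_id = id(exported_program.graph)
edge_dialect_program = to_edge(exported_program).exported_program()

# Propagate edge-dialect debug handles back onto the original export graph.
if propagate_back_debug_handle(exported_program, export_graph_id, edge_dialect_program):
    for node in exported_program.graph.nodes:
        if node.op not in ("placeholder", "output"):
            print(node.name, node.meta[DEBUG_HANDLE_KEY])
```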