@@ -681,7 +681,7 @@ def test_calculate_numeric_gap(self):
         aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
         runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
 
-        inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: (
+        inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: (
             aot_intermediate_outputs,
             aot_debug_handle_to_op_name,
         )
@@ -838,6 +838,133 @@ def _gen_random_runtime_output(
     ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
         return [torch.randn(RAW_DATA_SIZE)]
 
+    def test_disable_debug_handle_validation_with_symbolic_shapes(self):
+        """
+        Test that symbolic shape related nodes lose their from_node info during
+        dynamic shape export, and that the disable_debug_handle_valdiation parameter
+        of propagate_back_debug_handle allows validation to be bypassed.
+        """
+        from executorch.devtools.inspector._inspector_utils import (
+            propagate_back_debug_handle,
+        )
+
+        class SymbolicShapeModel(torch.nn.Module):
+            """Model that will have symbolic shape related operations after export."""
+
+            def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+                # This will create symbolic shape nodes during dynamic export
+                batch_size = x.shape[0]
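+                # (Under dynamic-shape export, batch_size traces to a SymInt rather
+                # than a plain int, which yields sym_size_int nodes in the graph.)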
+                x = x + torch.rand((batch_size, 1))
+                # Masking operation that creates gt/lt nodes
+                valid_mask = mask > 0.5
+                x = torch.where(valid_mask, x, torch.zeros_like(x))
+                return x
+
+        # Create model and dynamic inputs
+        model = SymbolicShapeModel()
+        batch_size = 2
+        seq_len = 4
+        x = torch.randn(batch_size, seq_len)
+        mask = torch.rand(batch_size, seq_len)
+        example_inputs = (x, mask)
+
+        # Export with dynamic shapes to create symbolic shape related nodes
+        dynamic_shapes = {
+            "x": {0: torch.export.Dim("batch_size", min=1, max=10)},
+            "mask": {0: torch.export.Dim("batch_size", min=1, max=10)},
+        }
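+        # (torch.export.Dim marks dim 0 of each input as dynamic, allowed to
+        # range over [1, 10] at runtime.)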
+
+        exported_program = torch.export.export(
+            model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
+        )
+
+        """
+        In this case the original aten graph has a sym_size_int_2 node, but when we
+        look at the node metadata in edge_program_manager, the sym_size node's
+        from_node says sym_size_int_3, which is not in the original aten graph.
+        """
+        # Create edge program - this is where from_node info can be lost for symbolic shape nodes
+        edge_program_manager: EdgeProgramManager = to_edge(exported_program)
+        edge_program_manager_copy = copy.deepcopy(edge_program_manager)
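+        # (The deepcopy preserves the edge dialect program for ETRecord generation,
+        # since to_executorch() below mutates the edge program manager in place.)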
+        et_program_manager: ExecutorchProgramManager = (
+            edge_program_manager.to_executorch()
+        )
+
+        with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
+            etrecord_path = tmp_file.name
+
+            # Generate ETRecord with the exported program (aten graph)
+            generate_etrecord(
+                etrecord_path,
+                edge_program_manager_copy,
+                et_program_manager,
+                exported_program=exported_program,
+            )
+
+            # Create Inspector and get etrecord
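+            # (Patching gen_etdump_object and EventBlock._gen_from_etdump lets the
+            # Inspector be constructed without a real ETDump file on disk.)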
+            with patch.object(
+                _inspector, "gen_etdump_object", return_value=None
+            ), patch.object(EventBlock, "_gen_from_etdump"):
+                inspector_instance = Inspector(
+                    etdump_path=ETDUMP_PATH,
+                    etrecord=etrecord_path,
+                )
+
+                # Extract the necessary values from the inspector's etrecord
+                exported_program_from_etrecord = (
+                    inspector_instance._etrecord.exported_program
+                )
+                export_graph_id = inspector_instance._etrecord.export_graph_id
+                edge_dialect_program = inspector_instance._etrecord.edge_dialect_program
+
+                # Ensure we have all the necessary components
+                self.assertIsNotNone(exported_program_from_etrecord)
+                self.assertIsNotNone(export_graph_id)
+                self.assertIsNotNone(edge_dialect_program)
+
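+                # propagate_back_debug_handle maps debug handles from the edge dialect
+                # program back onto the aten-level exported program; its validation
+                # relies on each node's from_node provenance metadata.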
+                # Test propagate_back_debug_handle with validation enabled (should return False)
+                # This demonstrates the issue with symbolic shape nodes losing from_node info
+                validation_enabled_result = propagate_back_debug_handle(
+                    exported_program_from_etrecord,
+                    export_graph_id,
+                    edge_dialect_program,
+                    disable_debug_handle_valdiation=False,
+                )
+
+                # With validation enabled, it should return False when from_node info is lost
+                self.assertFalse(
+                    validation_enabled_result,
+                    "propagate_back_debug_handle should return False when validation is enabled "
+                    "and symbolic shape nodes lose from_node info",
+                )
+
+                # Test propagate_back_debug_handle with validation disabled (should succeed)
+                # This shows how the disable_debug_handle_valdiation flag allows the function to work
+                validation_disabled_result = propagate_back_debug_handle(
+                    exported_program_from_etrecord,
+                    export_graph_id,
+                    edge_dialect_program,
+                    disable_debug_handle_valdiation=True,
+                )
+
+                # With validation disabled, it should return True even when from_node info is lost
+                self.assertTrue(
+                    validation_disabled_result,
+                    "propagate_back_debug_handle should return True when validation is disabled, "
+                    "allowing best effort comparison even when from_node info is lost",
+                )
+
     def _gen_random_events(self) -> List[Event]:
         events = []
         for i in range(2):