@@ -681,7 +681,7 @@ def test_calculate_numeric_gap(self):
         aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
         runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}

-        inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: (
+        inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: (
             aot_intermediate_outputs,
             aot_debug_handle_to_op_name,
         )
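Context for the one-line fix above: an attribute assigned directly on an instance is a plain function, not a bound method, so it receives no implicit self, and its parameter count must match the call site exactly; presumably the patched hook is now invoked with one argument. A minimal self-contained sketch of the pattern (all names hypothetical, not the real Inspector API):

import unittest


class FakeInspector:
    def run(self):
        # The caller now passes one argument to the hook.
        return self._get_outputs("etrecord")


inspector = FakeInspector()
# Stored on the instance, the lambda gets no implicit ``self``, so it
# must declare exactly one parameter to absorb the argument.
inspector._get_outputs = lambda x: ("outputs", "op_names")
assert inspector.run() == ("outputs", "op_names")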
@@ -838,6 +838,122 @@ def _gen_random_runtime_output(
     ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
         return [torch.randn(RAW_DATA_SIZE)]

+    def test_disable_debug_handle_validation_with_symbolic_shapes(self):
+        """
+        Test that symbolic-shape-related nodes lose their from_node info during
+        dynamic-shape export, and that the disable_debug_handle_valdiation
+        parameter of propagate_back_debug_handle lets that validation be bypassed.
+        """
+        from executorch.devtools.inspector._inspector_utils import (
+            propagate_back_debug_handle,
+        )
+
+        class SymbolicShapeModel(torch.nn.Module):
+            """Model that will have symbolic-shape-related operations after export."""
+
+            def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+                # Reading a dynamic dim creates sym_size nodes during export
+                batch_size = x.shape[0]
+                x = x + torch.rand((batch_size, 1))
+                # Masking operation that creates gt nodes
+                valid_mask = mask > 0.5
+                x = torch.where(valid_mask, x, torch.zeros_like(x))
+                return x
+
+        # Create model and dynamic inputs
+        model = SymbolicShapeModel()
+        batch_size = 2
+        seq_len = 4
+        x = torch.randn(batch_size, seq_len)
+        mask = torch.rand(batch_size, seq_len)
+        example_inputs = (x, mask)
+
+        # Export with dynamic shapes to create symbolic-shape-related nodes
+        dynamic_shapes = {
+            "x": {0: torch.export.Dim("batch_size", min=1, max=10)},
+            "mask": {0: torch.export.Dim("batch_size", min=1, max=10)},
+        }
+
+        exported_program = torch.export.export(
+            model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
+        )
+
881+ """
882+ In this case origina aten graph has sym_size_int_2 node but when we look at
883+ nodes metadata in edge_program_manager, its sym_size node's from_node says
884+ sym_size_int_3 which is not in the original aten graph.
885+ """
+        # Create edge program - this is where from_node info can be lost for symbolic shape nodes
+        edge_program_manager: EdgeProgramManager = to_edge(exported_program)
+        edge_program_manager_copy = copy.deepcopy(edge_program_manager)
+        et_program_manager: ExecutorchProgramManager = (
+            edge_program_manager.to_executorch()
+        )
+
+        with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
+            etrecord_path = tmp_file.name
+
+            # Generate ETRecord with the exported program (aten graph)
+            generate_etrecord(
+                etrecord_path,
+                edge_program_manager_copy,
+                et_program_manager,
+                exported_program=exported_program,
+            )
+
+            # Create Inspector and get etrecord
+            with patch.object(
+                _inspector, "gen_etdump_object", return_value=None
+            ), patch.object(EventBlock, "_gen_from_etdump"):
+                inspector_instance = Inspector(
+                    etdump_path=ETDUMP_PATH,
+                    etrecord=etrecord_path,
+                )
+
+                # Extract the necessary values from the inspector's etrecord
+                exported_program_from_etrecord = (
+                    inspector_instance._etrecord.exported_program
+                )
+                export_graph_id = inspector_instance._etrecord.export_graph_id
+                edge_dialect_program = inspector_instance._etrecord.edge_dialect_program
+
+                # Ensure we have all the necessary components
+                self.assertIsNotNone(exported_program_from_etrecord)
+                self.assertIsNotNone(export_graph_id)
+                self.assertIsNotNone(edge_dialect_program)
+
+                # With validation enabled, propagate_back_debug_handle should
+                # return False because symbolic shape nodes lost from_node info
+                validation_enabled_result = propagate_back_debug_handle(
+                    exported_program_from_etrecord,
+                    export_graph_id,
+                    edge_dialect_program,
+                    disable_debug_handle_valdiation=False,
+                )
+
+                # Validation failure is reported through the return value
+                self.assertFalse(
+                    validation_enabled_result,
+                    "propagate_back_debug_handle should return False when validation "
+                    "is enabled and symbolic shape nodes lose from_node info",
+                )
+
+                # With validation disabled, the same call should succeed,
+                # falling back to a best-effort comparison
+                validation_disabled_result = propagate_back_debug_handle(
+                    exported_program_from_etrecord,
+                    export_graph_id,
+                    edge_dialect_program,
+                    disable_debug_handle_valdiation=True,
+                )
+
+                # Best-effort propagation reports success via the return value
+                self.assertTrue(
+                    validation_disabled_result,
+                    "propagate_back_debug_handle should return True when validation is "
+                    "disabled, allowing best effort comparison even when from_node info is lost",
+                )
+
     def _gen_random_events(self) -> List[Event]:
         events = []
         for i in range(2):
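The sym_size_int_2 vs. sym_size_int_3 mismatch described in the test's inner comment can be checked directly, since from_node provenance lives in each FX node's meta dict. A minimal sketch of that inspection using only public torch.export APIs (the exact format of the from_node entry varies across PyTorch versions, so treat the printout as illustrative):

import torch


class M(torch.nn.Module):
    def forward(self, x):
        # x.shape[0] becomes a sym_size node under dynamic-shape export
        return x + x.shape[0]


ep = torch.export.export(
    M(),
    (torch.randn(2, 3),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=10)}},
)
for node in ep.graph.nodes:
    if "sym_size" in node.name:
        # from_node records which node this one was traced from, when present
        print(node.name, node.meta.get("from_node"))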
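The two assertions also pin down a return-value contract that suggests a try-strict-then-fall-back call pattern. A hypothetical convenience wrapper (not part of the Inspector API), assuming only the signature the test itself exercises:

from executorch.devtools.inspector._inspector_utils import (
    propagate_back_debug_handle,
)


def propagate_with_fallback(exported_program, export_graph_id, edge_dialect_program) -> str:
    # Strict pass first: returns False when validation fails, e.g. when
    # symbolic shape nodes have lost their from_node info.
    if propagate_back_debug_handle(
        exported_program,
        export_graph_id,
        edge_dialect_program,
        disable_debug_handle_valdiation=False,
    ):
        return "strict"
    # Best-effort pass: validation is skipped, so this returns True even
    # when from_node info is lost.
    if propagate_back_debug_handle(
        exported_program,
        export_graph_id,
        edge_dialect_program,
        disable_debug_handle_valdiation=True,
    ):
        return "best_effort"
    return "failed"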