4444 TimeScale ,
4545)
4646from executorch .devtools .inspector .tests .inspector_test_utils import (
47+ check_if_debug_handle_to_op_name_match ,
4748 check_if_final_outputs_match ,
4849 model_registry ,
4950)
@@ -468,25 +469,7 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self):
468469 events = events ,
469470 )
470471
471- def test_no_capture_when_representative_inputs_are_none (self ):
472- # Create a context manager to patch functions called by Inspector.__init__
473- with patch .object (
474- _inspector , "parse_etrecord" , return_value = None
475- ), patch .object (
476- _inspector , "gen_etdump_object" , return_value = None
477- ), patch .object (
478- EventBlock , "_gen_from_etdump"
479- ), patch .object (
480- _inspector , "gen_graphs_from_etrecord"
481- ):
482- # Call the constructor of Inspector
483- inspector_instance = Inspector (
484- etdump_path = ETDUMP_PATH ,
485- etrecord = ETRECORD_PATH ,
486- )
487- self .assertIsNone (inspector_instance ._aot_intermediate_outputs )
488-
489- def test_consume_etrecord_populates_correct_aot_intermediate_outputs (self ):
472+ def test_etrecord_populates_correct_aot_intermediate_outputs (self ):
490473 with tempfile .NamedTemporaryFile (suffix = ".bin" ) as tmp_file :
491474 etrecord_path = tmp_file .name
492475 mod = model_registry ["ConvLinearModel" ]()
@@ -505,7 +488,6 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
505488 generate_etrecord (
506489 etrecord_path , edge_program_manager_copy , et_program_manager
507490 )
508- original_consume_etrecord = Inspector ._consume_etrecord
509491 with patch .object (
510492 Inspector , "_consume_etrecord" , return_value = None
511493 ), patch .object (
@@ -529,11 +511,17 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
529511 _representative_inputs = aten_model .example_inputs [0 ],
530512 )
531513 inspector_instance ._etrecord = etrecord
532- Inspector ._consume_etrecord = original_consume_etrecord
533- inspector_instance ._consume_etrecord ()
514+ aot_intermediate_outputs , aot_debug_handle_to_op_name = (
515+ inspector_instance ._get_aot_intermediate_outputs_and_op_names ()
516+ )
534517 self .assertTrue (
535518 check_if_final_outputs_match (
536- "ConvLinearModel" , inspector_instance ._aot_intermediate_outputs
519+ "ConvLinearModel" , aot_intermediate_outputs
520+ )
521+ )
522+ self .assertTrue (
523+ check_if_debug_handle_to_op_name_match (
524+ "ConvLinearModel" , aot_debug_handle_to_op_name
537525 )
538526 )
539527
@@ -605,6 +593,7 @@ def test_calculate_numeric_gap(self):
605593 ), patch .object (
606594 _inspector , "gen_graphs_from_etrecord"
607595 ):
596+
608597 # Call the constructor of Inspector
609598 inspector_instance = Inspector (
610599 etdump_path = ETDUMP_PATH ,
@@ -621,43 +610,44 @@ def test_calculate_numeric_gap(self):
621610 (1 ,): torch .tensor ([3.0 , 6.0 , 5.0 ]),
622611 }
623612
624- inspector_instance ._aot_intermediate_outputs = aot_intermediate_outputs
613+ aot_debug_handle_to_op_name = {(0 ,): "op_0" , (1 ,): "op_1" }
614+ runtime_debug_handle_to_op_name = {(0 ,): "op_0" , (1 ,): "op_1" }
615+
616+ inspector_instance ._get_aot_intermediate_outputs_and_op_names = lambda : (
617+ aot_intermediate_outputs ,
618+ aot_debug_handle_to_op_name ,
619+ )
625620 inspector_instance ._get_runtime_intermediate_outputs_and_op_names = (
626- lambda : (runtime_intermediate_outputs , {} )
621+ lambda : (runtime_intermediate_outputs , runtime_debug_handle_to_op_name )
627622 )
628623
629624 df = inspector_instance .calculate_numeric_gap (distance = "L1" )
630625 self .assertIsInstance (df , pd .DataFrame )
631626 self .assertEqual (len (df ), 2 )
632627 cols = set (df .columns )
633628 expected_cols = {
634- "aot_debug_handle " ,
629+ "aot_ops " ,
635630 "aot_intermediate_output" ,
636- "runtime_debug_handle " ,
631+ "runtime_ops " ,
637632 "runtime_intermediate_output" ,
638633 "gap" ,
639634 }
640635 self .assertEqual (cols , expected_cols )
641- founded_aot_debug_handle = set (df ["aot_debug_handle" ])
642- self .assertEqual (
643- founded_aot_debug_handle , set (aot_intermediate_outputs .keys ())
644- )
645- for _ , row in df .iterrows ():
646- aot_debuh_handle = row ["aot_debug_handle" ]
636+ for i , row in df .iterrows ():
637+ # Dummpy key to get the expected aot/runtime internmediate outputs
638+ key = (i ,)
647639 # aot_intermediate_output should equal aot_intermediate_outputs[h]
648640 self .assertTrue (
649641 torch .allclose (
650642 row ["aot_intermediate_output" ],
651- aot_intermediate_outputs [aot_debuh_handle ],
643+ aot_intermediate_outputs [key ],
652644 )
653645 )
654- # runtime_debug_hanlde equals aot_debug_handle at this case
655- self .assertEqual (row ["runtime_debug_handle" ], aot_debuh_handle )
656646 # runtime_intermediate_output should equal runtime_intermediate_outputs[h]
657647 self .assertTrue (
658648 torch .allclose (
659649 row ["runtime_intermediate_output" ],
660- runtime_intermediate_outputs [aot_debuh_handle ],
650+ runtime_intermediate_outputs [key ],
661651 )
662652 )
663653 # gap should equal 3.0
0 commit comments