77# pyre-unsafe
88
99import copy
10+ import torch
1011import random
1112import statistics
1213import tempfile
1314import unittest
1415from contextlib import redirect_stdout
1516
16- from typing import Callable , List
17+ from typing import Callable , List , Union
1718
1819from unittest .mock import patch
1920
5657
5758OP_TYPE = "aten::add"
5859EVENT_BLOCK_NAME = "block_0"
59- EVENTS_SIZE = 5
60+ EVENTS_SIZE = 10
6061RAW_DATA_SIZE = 10
6162ETDUMP_PATH = "unittest_etdump_path"
6263ETRECORD_PATH = "unittest_etrecord_path"
@@ -72,7 +73,7 @@ def test_perf_data(self) -> None:
7273 self .assertAlmostEqual (perfData .p50 , statistics .median (random_floats ))
7374
7475 def test_event_block_to_dataframe (self ) -> None :
75- eventBlock = EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_random_events ())
76+ eventBlock = EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_events ())
7677
7778 df = eventBlock .to_dataframe ()
7879 # Check some fields of the returned dataframe
@@ -154,7 +155,7 @@ def test_inspector_print_data_tabular(self):
154155 # The mock inspector instance starts with having an empty event blocks list.
155156 # Add non-empty event blocks to test print_data_tabular().
156157 inspector_instance .event_blocks = [
157- EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_random_events ())
158+ EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_events ())
158159 ]
159160 # Call print_data_tabular(), make sure it doesn't crash
160161 with redirect_stdout (None ):
@@ -535,17 +536,111 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
535536 )
536537 )
537538
539+ def test_get_runtime_intermediate_outputs (self ):
540+ # Create a context manager to patch functions called by Inspector.__init__
541+ with patch .object (
542+ _inspector , "parse_etrecord" , return_value = None
543+ ), patch .object (
544+ _inspector , "gen_etdump_object" , return_value = None
545+ ), patch .object (
546+ EventBlock , "_gen_from_etdump"
547+ ), patch .object (
548+ _inspector , "gen_graphs_from_etrecord"
549+ ):
550+ # Call the constructor of Inspector
551+ inspector_instance = Inspector (
552+ etdump_path = ETDUMP_PATH ,
553+ etrecord = ETRECORD_PATH ,
554+ )
555+
556+ # The mock inspector instance starts with having an empty event blocks list.
557+ # Add pre-defined event blocks to test _get_runtime_outputs().
558+ inspector_instance .event_blocks = [
559+ EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_events ())
560+ ]
561+
562+ runtime_outputs = inspector_instance ._get_runtime_intermediate_outputs ()
563+ # This output should be a dictionary with 5 keys
564+ self .assertEqual (len (runtime_outputs ), 5 , )
565+ # Check that keys (0,) and (1,) are not in the dictionary(skip OPERATOR_CALL and op_types are empty)
566+ self .assertNotIn ((0 ,), runtime_outputs )
567+ self .assertNotIn ((1 ,), runtime_outputs )
568+
569+ # Same debug_handle but different instruction_id, should record the last one
570+ self .assertIn ((4 ,), runtime_outputs )
571+ self .assertTrue (torch .equal (runtime_outputs [(4 ,)][0 ], torch .tensor ([4.0 , 5.0 , 6.0 ])))
572+ # Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
573+ for key in range (5 , 9 ):
574+ self .assertIn ((key ,), runtime_outputs )
575+ self .assertEqual (len (runtime_outputs [(key ,)]), RAW_DATA_SIZE )
576+
538577 def _gen_random_float_list (self ) -> List [float ]:
539578 return [random .uniform (0 , 10 ) for _ in range (RAW_DATA_SIZE )]
540579
541- def _gen_random_events (self ) -> List [Event ]:
580+ def _gen_random_runtime_output (self ) -> List [Union [None , List [torch .Tensor ], bool , float , int , str , torch .Tensor ]]:
581+ return list (torch .randn (RAW_DATA_SIZE ))
582+
583+ def _gen_events (self ) -> List [Event ]:
542584 events = []
543- for i in range (EVENTS_SIZE ):
585+ for i in range (2 ):
586+ events .append (
587+ # OPERATOR_CALL with debug_hanldes/instruction_id 0 and 2
588+ Event (
589+ name = "OPERATOR_CALL" ,
590+ op_types = [OP_TYPE ],
591+ perf_data = PerfData (self ._gen_random_float_list ()),
592+ debug_handles = i * 2 ,
593+ _instruction_id = i * 2 ,
594+ debug_data = self ._gen_random_runtime_output ()
595+ )
596+ )
597+ events .append (
598+ # op_0/op_1 wiht empty op_types and with debug_hanldes/instruction_id 1 and 3
599+ Event (
600+ name = f"op_{ i } " ,
601+ op_types = [],
602+ perf_data = PerfData (self ._gen_random_float_list ()),
603+ debug_handles = i * 2 + 1 ,
604+ _instruction_id = i * 2 + 1 ,
605+ debug_data = self ._gen_random_runtime_output ()
606+ )
607+ )
608+
609+ # op_2 with debug_hanldes/instruction_id 4
610+ events .append (
611+ Event (
612+ name = f"op_2" ,
613+ op_types = [OP_TYPE ],
614+ perf_data = PerfData (self ._gen_random_float_list ()),
615+ debug_handles = 4 ,
616+ debug_data = [torch .tensor ([1.0 , 2.0 , 3.0 ])],
617+ _instruction_id = 4
618+
619+ )
620+ )
621+ # op_3 also with debug_hanldes 4 but with instruction_id 5
622+ events .append (
623+ Event (
624+ name = f"op_3" ,
625+ op_types = [OP_TYPE ],
626+ perf_data = PerfData (self ._gen_random_float_list ()),
627+ debug_handles = 4 ,
628+ debug_data = [torch .tensor ([4.0 , 5.0 , 6.0 ])],
629+ _instruction_id = 5
630+
631+ )
632+ )
633+
634+ # op_4 to op_7 with debug_hanldes 5 to 8 and instruction_id 6 to 9
635+ for i in range (4 , EVENTS_SIZE - 2 ):
544636 events .append (
545637 Event (
546638 name = f"op_{ i } " ,
547639 op_types = [OP_TYPE ],
548640 perf_data = PerfData (self ._gen_random_float_list ()),
641+ debug_handles = i + 1 ,
642+ debug_data = self ._gen_random_runtime_output (),
643+ _instruction_id = i + 2
549644 )
550645 )
551646 return events
0 commit comments