77# TODO (tharitt): later move to env file or something 
88ENABLE_BENCHMARK  =  True 
99
10+ # Stack of active mark functions for nested support 
11+ _mark_func_stack  =  []
12+ _markers  =  []
13+ 
14+ 
15+ def  _parse_output_tree (markers ):
16+     output  =  []
17+     stack  =  []
18+     i  =  0 
19+     while  i  <  len (markers ):
20+         label , time , level  =  markers [i ]
21+         if  label .startswith ("[header]" ):
22+             output .append (f"{ "\t "  *  (level  -  1 )} { label } { time :6f} \n " )
23+         else :
24+             if  stack :
25+                 prev_label , prev_time , prev_level  =  stack [- 1 ]
26+                 if  prev_level  ==  level :
27+                     output .append (f"{ "\t "  *  level } { prev_label } { label } { time  -  prev_time : 6f} \n " )
28+                     stack .pop ()
29+ 
30+             # Push to the stack only if it is going deeper or the same level 
31+             if  i  +  1  <  len (markers ) -  1 :
32+                 _ , _  , next_level  =  markers [i  +  1 ]
33+                 if  next_level  >=  level :
34+                     stack .append (markers [i ])
35+         i  +=  1 
36+     return  output 
37+ 
1038
11- # This function allows users to measure time arbitary lines of the function 
def mark(label):
    """Record a timing mark at an arbitrary line of a benchmarked function.

    Parameters
    ----------
    label: :obj:`str`
        A label of the mark. This signifies both 1) the end of the
        previous mark 2) the beginning of the new mark

    Raises
    ------
    RuntimeError
        If no mark function is currently registered on ``_mark_func_stack``.
    """
    if _mark_func_stack:
        # Dispatch to the mark function on top of the stack.
        record = _mark_func_stack[-1]
        record(label)
        return
    raise RuntimeError("mark() called outside of a benchmarked region")
1651
1752
18- # Stack of active mark functions for nested support 
19- _mark_func_stack  =  []
20- 
21- 
2253def  benchmark (func : Optional [Callable ] =  None ,
2354              description = "" ,
2455              save_file = False ,
@@ -28,9 +59,9 @@ def benchmark(func: Optional[Callable] = None,
2859
2960    This wrapper measures the start-to-end time of the wrapped function when 
3061    decorated without any argument. 
31-     It also allows users to put a call to mark() anywhere inside the wrapped function   
62+     It also allows users to put a call to mark() anywhere inside the wrapped function 
3263    for fine-grain time benchmark. This wrapper defines the local_mark() and push it 
33-     the the _mark_func_stack for isolation in case of nested call.   
64+     to the _mark_func_stack for isolation in case of nested calls. 
3465    The user-facing mark() will always call the function at the top of the _mark_func_stack. 
3566
3667    Parameters 
@@ -58,38 +89,35 @@ def wrapper(*args, **kwargs):
5889            # Here we rely on the closure property of Python. 
5990            # This marks will isolate from (shadow) the marks previously 
6091            # defined in the function currently on top of the _mark_func_stack. 
61-             marks  =  []
92+ 
93+             level  =  len (_mark_func_stack ) +  1 
94+             # The header is needed for later tree parsing. Here it is allocating its spot. 
95+             # the tuple at this index will be replaced after elapsed time is calculated. 
96+             _markers .append ((f"[header]{ description  or  func .__name__ }  , None , level ))
97+             header_index  =  len (_markers ) -  1 
6298
6399            def  local_mark (label ):
64-                 marks .append ((label , time .perf_counter ()))
100+                 _markers .append ((label , time .perf_counter (), level ))
101+ 
65102            _mark_func_stack .append (local_mark )
66103
67104            start_time  =  time .perf_counter ()
68105            # the mark() called in wrapped function will now call local_mark 
69106            result  =  func (* args , ** kwargs )
70107            end_time  =  time .perf_counter ()
71108
72-             _mark_func_stack .pop ()
73- 
74-             output  =  []
75-             # start-to-end time 
76109            elapsed  =  end_time  -  start_time 
110+             _markers [header_index ] =  (f"[header]{ description  or  func .__name__ }  , elapsed , level )
111+ 
112+             # In case of nesting, the wrapped callee must pop its closure from stack so that 
113+             # when the callee returns, the wrapped caller operates on its closure (and its level label), which now becomes 
114+             # the top of the stack. 
115+             _mark_func_stack .pop ()
77116
78-             # TODO (tharitt): Both MPI + NCCL collective calls have implicit synchronization inside the stream 
79-             # So, output only from rank=0 should suffice. We can add per-rank output later on if makes sense. 
80-             if  rank  ==  0 :
81-                 level  =  len (_mark_func_stack )
82-                 output .append (f"{ '---'  *  level } { description  or  func .__name__ } { elapsed :6f} { rank } \n " )
83-                 if  marks :
84-                     prev_label , prev_t  =  marks [0 ]
85-                     for  label , t  in  marks [1 :]:
86-                         output .append (f"{ '---'  *  level } { prev_label } { label } { t  -  prev_t :.6f} \n " )
87-                         prev_label , prev_t  =  label , t 
88- 
89-                 if  save_file :
90-                     with  open (file_path , "a" ) as  f :
91-                         f .write ("" .join (output ))
92-                 else :
117+             # finish all the calls 
118+             if  not  _mark_func_stack :
119+                 if  rank  ==  0 :
120+                     output  =  _parse_output_tree (_markers )
93121                    print ("" .join (output ))
94122            return  result 
95123        return  wrapper 
0 commit comments