11import functools
2+ import logging
3+ import sys
24import time
35from typing import Callable , Optional
46from mpi4py import MPI
79# TODO (tharitt): later move to env file or something
810ENABLE_BENCHMARK = True
911
12+ logging .basicConfig (level = logging .INFO , force = True )
1013# Stack of active mark functions for nested support
1114_mark_func_stack = []
1215_markers = []
1316
1417
15- def _parse_output_tree (markers ):
18+ def _parse_output_tree (markers : list [ str ] ):
1619 output = []
1720 stack = []
1821 i = 0
1922 while i < len (markers ):
2023 label , time , level = markers [i ]
2124 if label .startswith ("[header]" ):
22- output .append (f"{ "\t " * (level - 1 )} { label } : total runtime: { time :6f} \n " )
25+ output .append (f"{ "\t " * (level - 1 )} { label } : total runtime: { time :6f} s \n " )
2326 else :
2427 if stack :
2528 prev_label , prev_time , prev_level = stack [- 1 ]
2629 if prev_level == level :
27- output .append (f"{ "\t " * level } { prev_label } -->{ label } : { time - prev_time : 6f} \n " )
30+ output .append (f"{ "\t " * level } { prev_label } -->{ label } : { time - prev_time :6f} s \n " )
2831 stack .pop ()
2932
30- # Push to the stack only if it is going deeper or the same level
33+ # Push to the stack only if it is going deeper or still at the same level
3134 if i + 1 < len (markers ) - 1 :
3235 _ , _ , next_level = markers [i + 1 ]
3336 if next_level >= level :
@@ -86,10 +89,6 @@ def decorator(func):
8689 def wrapper (* args , ** kwargs ):
8790 rank = MPI .COMM_WORLD .Get_rank ()
8891
89- # Here we rely on the closure property of Python.
90- # This marks will isolate from (shadow) the marks previously
91- # defined in the function currently on top of the _mark_func_stack.
92-
9392 level = len (_mark_func_stack ) + 1
9493 # The header is needed for later tree parsing. Here it is allocating its spot.
9594 # the tuple at this index will be replaced after elapsed time is calculated.
@@ -114,11 +113,22 @@ def local_mark(label):
114113 # the top of the stack.
115114 _mark_func_stack .pop ()
116115
117- # finish all the calls
116+ # all the calls have fininshed
118117 if not _mark_func_stack :
119118 if rank == 0 :
120119 output = _parse_output_tree (_markers )
121- print ("" .join (output ))
120+ logger = logging .getLogger ()
121+ # remove the stdout
122+ for h in logger .handlers [:]:
123+ logger .removeHandler (h )
124+ handler = logging .FileHandler (file_path , mode = 'w' ) if save_file else logging .StreamHandler (sys .stdout )
125+ handler .setLevel (logging .INFO )
126+ logger .addHandler (handler )
127+ logger .info ("" .join (output ))
128+ logger .removeHandler (handler )
129+ if save_file :
130+ handler .close ()
131+
122132 return result
123133 return wrapper
124134 if func is not None :
0 commit comments