1212import os
1313from collections import defaultdict
1414from contextlib import contextmanager
15+ from dataclasses import dataclass
16+ from functools import total_ordering
1517from time import perf_counter
1618from typing import (
1719 Any ,
2224 Protocol ,
2325 runtime_checkable ,
2426 Sequence ,
25- Tuple ,
2627)
2728
2829import numpy as np
3334from torch .distributed .distributed_c10d import Work
3435from torchtnt .utils .distributed import PGWrapper
3536
# Module-level logger, named after this module so log output can be filtered per module.
logger: logging.Logger = logging.getLogger(__name__)
4039
4140
4241@contextmanager
@@ -69,6 +68,28 @@ def log_elapsed_time(
6968 logger .info (f"{ action_name } took { interval_time } seconds" )
7069
7170
@dataclass
class TimedActionStats:
    """Aggregated statistics for a single timed action.

    Instances are consumed by report generation, so every metric is expected
    to already be aggregated over all recorded calls of the action.

    Ordering (``<``, ``<=``, ``>``, ``>=``) compares
    ``percentage_of_total_time`` only, whereas the dataclass-generated
    ``__eq__`` compares every field. All four ordering operators are written
    out explicitly instead of being derived with ``functools.total_ordering``:
    the derived operators combine one ordering method with ``__eq__``, which
    made two different actions with the same percentage each compare as
    strictly less than the other.

    Attributes:
        action_name: name of the timed action.
        mean_duration: mean duration of one call of the action.
        num_calls: number of recorded calls of the action.
        total_duration: summed duration over all calls of the action.
        percentage_of_total_time: share of the overall recorded time spent in
            this action, expressed as a percentage.
    """

    action_name: str
    mean_duration: float = 0.0
    num_calls: int = 0
    total_duration: float = 0.0
    percentage_of_total_time: float = 0.0

    def __lt__(self, other: "TimedActionStats") -> bool:
        # Returning NotImplemented lets Python raise a clear TypeError for
        # comparisons against unrelated types.
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time < other.percentage_of_total_time

    def __le__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time <= other.percentage_of_total_time

    def __gt__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time > other.percentage_of_total_time

    def __ge__(self, other: "TimedActionStats") -> bool:
        if not isinstance(other, TimedActionStats):
            return NotImplemented
        return self.percentage_of_total_time >= other.percentage_of_total_time
84+
85+
@dataclass
class TimerReport:
    """Summary of everything a timer recorded, ready for rendering.

    Attributes:
        timed_action_stats: one ``TimedActionStats`` entry per timed action.
        total_calls: number of recorded calls summed over all actions.
        total_duration: duration summed over all actions.
    """

    timed_action_stats: List[TimedActionStats]
    total_calls: int
    total_duration: float
91+
92+
7293@runtime_checkable
7394class TimerProtocol (Protocol ):
7495 """
@@ -194,31 +215,30 @@ def _apply_bounds(self, action_name: str) -> None:
194215 )
195216
196217
def _make_report(self: TimerProtocol) -> TimerReport:
    """Aggregate a timer's recorded durations into a ``TimerReport``.

    Args:
        self: timer whose ``recorded_durations`` mapping (action name to the
            durations recorded for that action) is summarized. The parameter
            keeps the name ``self`` although this is a module-level function,
            so existing call sites are unaffected.

    Returns:
        A ``TimerReport`` holding one ``TimedActionStats`` per action, sorted
        by descending share of the total time, together with the total number
        of calls and the total duration over all actions.
    """
    # Sum every action exactly once; the sums feed the grand total, the
    # per-action totals and the percentages.
    action_totals = {
        name: float(np.sum(durations))
        for name, durations in self.recorded_durations.items()
    }
    total_time = sum(action_totals.values(), 0.0)

    action_stats = []
    for name, durations in self.recorded_durations.items():
        num_calls = len(durations)
        action_total = action_totals[name]
        action_stats.append(
            TimedActionStats(
                action_name=name,
                # np.mean of an empty sequence is NaN (with a warning); report 0.0.
                mean_duration=float(np.mean(durations)) if num_calls else 0.0,
                num_calls=num_calls,
                total_duration=action_total,
                # Guard against 0/0 when nothing measurable was recorded.
                percentage_of_total_time=(
                    100.0 * action_total / total_time if total_time else 0.0
                ),
            )
        )

    # An explicit key keeps the ordering well defined (and stable for ties)
    # regardless of how TimedActionStats implements its comparison operators.
    action_stats.sort(key=lambda stats: stats.percentage_of_total_time, reverse=True)
    total_calls = sum(stats.num_calls for stats in action_stats)
    return TimerReport(
        timed_action_stats=action_stats,
        total_calls=total_calls,
        total_duration=total_time,
    )
222242
223243
224244def get_timer_summary (timer : TimerProtocol ) -> str :
@@ -231,13 +251,16 @@ def get_timer_summary(timer: TimerProtocol) -> str:
231251 ValueError
232252 If the input Timer has no recorded actions
233253 """
254+ report : TimerReport = _make_report (timer )
255+
234256 sep : str = os .linesep
235257 output_string = f"Timer Report{ sep } "
236258
237- if len (timer .recorded_durations ) == 0 :
259+ # Handle empty timer case
260+ if not report .timed_action_stats :
238261 return output_string
239262
240- max_key = max (len (k ) for k in timer . recorded_durations . keys () )
263+ max_key = max (len (a . action_name ) for a in report . timed_action_stats )
241264
242265 # pyre-fixme[53]: Captured variable `max_key` is not annotated.
243266 def log_row (action : str , mean : str , num_calls : str , total : str , per : str ) -> str :
@@ -252,32 +275,23 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str
252275 "Total time (s)" ,
253276 "Percentage %" ,
254277 )
278+
255279 output_string_len = len (header_string .expandtabs ()) - 1
256280 sep_lines = f"{ sep } { '-' * output_string_len } "
257281 output_string += sep_lines + header_string + sep_lines
258- report : _TABLE_DATA
259- (
260- report ,
261- total_calls ,
262- total_duration ,
263- ) = _make_report (timer )
282+
264283 output_string += log_row (
265- "Total" , "-" , f"{ total_calls :} " , f"{ total_duration :.5} " , "100 %"
284+ "Total" , "-" , f"{ report . total_calls :} " , f"{ report . total_duration :.5} " , "100 %"
266285 )
267286 output_string += sep_lines
268- for (
269- action ,
270- mean_duration ,
271- num_calls ,
272- total_duration ,
273- duration_per ,
274- ) in report :
287+
288+ for action in report .timed_action_stats :
275289 output_string += log_row (
276- action ,
277- f"{ mean_duration :.5} " ,
278- f"{ num_calls } " ,
279- f"{ total_duration :.5} " ,
280- f"{ duration_per :.5} " ,
290+ action . action_name ,
291+ f"{ action . mean_duration :.5} " ,
292+ f"{ action . num_calls } " ,
293+ f"{ action . total_duration :.5} " ,
294+ f"{ action . percentage_of_total_time :.5} " ,
281295 )
282296 output_string += sep_lines
283297
0 commit comments