@@ -198,6 +198,9 @@ def __init__(self, enable_micro_timers, enable_global_timers):
198198 STEP_GLOBAL_TIMER
199199 ]
200200
201+ def active_timers (self ):
202+ return self .micro_timers + self .global_timers
203+
201204
202205class DeepSpeedEngine (Module ):
203206 r"""DeepSpeed engine for training."""
@@ -407,6 +410,8 @@ def __init__(self,
407410 enable_global_timers = self .wall_clock_breakdown ()
408411 or self .flops_profiler_enabled ())
409412
413+ self .engine_timers_cache = {}
414+
410415 if self .global_rank == 0 :
411416 self ._config .print ("DeepSpeedEngine configuration" )
412417 if self .dump_state ():
@@ -2334,7 +2339,10 @@ def _backward_prologue(self):
23342339 if self .zero_optimization ():
23352340 self .optimizer .is_gradient_accumulation_boundary = self .is_gradient_accumulation_boundary ()
23362341
2342+ self ._start_timers (self .engine_timers .backward_inner_timers )
2343+
23372344 def _backward_epilogue (self ):
2345+ self ._stop_timers (self .engine_timers .backward_inner_timers )
23382346 self ._start_timers (self .engine_timers .backward_reduce_timers )
23392347 if self .enable_backward_allreduce and not self .inside_no_sync_ctxt :
23402348 # Traditional code path that allreduces the module parameter grads
@@ -2723,6 +2731,9 @@ def step(self, lr_kwargs=None):
27232731 self ._autotuning_exit ()
27242732
27252733 if self .wall_clock_breakdown ():
2734+ # Update client accessible wall clock timers cache
2735+ self ._update_wall_clock_timers ()
2736+
27262737 # Log micro timing and reset
27272738 self .timers .log (names = self .engine_timers .micro_timers , memory_breakdown = self .memory_breakdown ())
27282739
@@ -2752,6 +2763,17 @@ def _stop_timers(self, timer_names):
27522763 for name in timer_names :
27532764 self .timers (name ).stop (record = record )
27542765
2766+ def _update_wall_clock_timers (self ):
2767+ self .engine_timers_cache = {}
2768+ for name in self .engine_timers .active_timers ():
2769+ self .engine_timers_cache [name ] = self .timers (name ).elapsed (reset = False )
2770+
2771+ def get_wall_clock_timers (self ):
2772+ r"""
2773+ Return a dict snapshot of the Engine's wall clock timers.
2774+ """
2775+ return self .engine_timers_cache
2776+
27552777 def _autotuning_exit (self ):
27562778 if self .global_rank == 0 :
27572779 msg = self .timers .get_mean ([
0 commit comments