@@ -281,7 +281,6 @@ def record_bwd_graph(cls, runner):
281281 def create_cudagraphs (cls ):
282282 """Iterate through 'cudagraph_record' creating graphs in the order in which
283283 they were recorded."""
284-
285284 # Cudagraphs have already been created, check that no cudagraphed modules ran in eager mode
286285 if cls .cudagraph_created :
287286 assert len (cls .cudagraph_record ) == 0 , (
@@ -303,11 +302,11 @@ def create_cudagraphs(cls):
303302 [isinstance (m , TransformerEngineBaseModule ) for m in base_module .modules ()]
304303 )
305304
306- if torch . distributed . get_rank () == 0 :
307- time_start = time .time ()
308- mem_stats_start = torch .cuda .memory_stats ()
305+ progress_bar = enumerate ( cls . cudagraph_record )
306+ time_start = time .time ()
307+ mem_stats_start = torch .cuda .memory_stats ()
309308
310- progress_bar = enumerate ( cls . cudagraph_record )
309+ if torch . distributed . get_rank () == 0 :
311310 if HAVE_TQDM :
312311 progress_bar = tqdm (
313312 progress_bar , "create cuda graphs" , total = len (cls .cudagraph_record )
@@ -361,22 +360,22 @@ def format_mem_bytes(mem_bytes):
361360 assert fwd_buffer_reuse_ref_count == 0
362361 runner .create_bwd_graph ()
363362
364- if torch .distributed .get_rank () == 0 :
365- # Memory usage.
366- time_end = time .time ()
367- mem_stats_end = torch .cuda .memory_stats ()
368- capture_stats = {
369- "time" : time_end - time_start ,
370- "allocated_bytes" : (
371- mem_stats_end ["allocated_bytes.all.current" ]
372- - mem_stats_start ["allocated_bytes.all.current" ]
373- ),
374- "reserved_bytes" : (
375- mem_stats_end ["reserved_bytes.all.current" ]
376- - mem_stats_start ["reserved_bytes.all.current" ]
377- ),
378- }
363+ # Memory usage.
364+ time_end = time .time ()
365+ mem_stats_end = torch .cuda .memory_stats ()
366+ capture_stats = {
367+ "time" : time_end - time_start ,
368+ "allocated_bytes" : (
369+ mem_stats_end ["allocated_bytes.all.current" ]
370+ - mem_stats_start ["allocated_bytes.all.current" ]
371+ ),
372+ "reserved_bytes" : (
373+ mem_stats_end ["reserved_bytes.all.current" ]
374+ - mem_stats_start ["reserved_bytes.all.current" ]
375+ ),
376+ }
379377
378+ if torch .distributed .get_rank () == 0 :
380379 logger .info (
381380 "> built %d cuda graph(s) in %.2f sec, with total memory usage: "
382381 "allocated %s, reserved %s."
0 commit comments