Commit 2a1ccfa

fix logging bug
Signed-off-by: Jieming Zhang <[email protected]>
1 parent: 7b5e0fa


1 file changed: megatron/core/transformer/cuda_graphs.py (19 additions, 20 deletions)
```diff
@@ -281,7 +281,6 @@ def record_bwd_graph(cls, runner):
     def create_cudagraphs(cls):
         """Iterate through 'cudagraph_record' creating graphs in the order in which
         they were recorded."""
-
         # Cudagraphs have already been created, check that no cudagraphed modules ran in eager mode
         if cls.cudagraph_created:
             assert len(cls.cudagraph_record) == 0, (
@@ -303,11 +302,11 @@ def create_cudagraphs(cls):
             [isinstance(m, TransformerEngineBaseModule) for m in base_module.modules()]
         )
 
-        if torch.distributed.get_rank() == 0:
-            time_start = time.time()
-            mem_stats_start = torch.cuda.memory_stats()
+        progress_bar = enumerate(cls.cudagraph_record)
+        time_start = time.time()
+        mem_stats_start = torch.cuda.memory_stats()
 
-            progress_bar = enumerate(cls.cudagraph_record)
+        if torch.distributed.get_rank() == 0:
             if HAVE_TQDM:
                 progress_bar = tqdm(
                     progress_bar, "create cuda graphs", total=len(cls.cudagraph_record)
@@ -361,22 +360,22 @@ def format_mem_bytes(mem_bytes):
             assert fwd_buffer_reuse_ref_count == 0
             runner.create_bwd_graph()
 
-        if torch.distributed.get_rank() == 0:
-            # Memory usage.
-            time_end = time.time()
-            mem_stats_end = torch.cuda.memory_stats()
-            capture_stats = {
-                "time": time_end - time_start,
-                "allocated_bytes": (
-                    mem_stats_end["allocated_bytes.all.current"]
-                    - mem_stats_start["allocated_bytes.all.current"]
-                ),
-                "reserved_bytes": (
-                    mem_stats_end["reserved_bytes.all.current"]
-                    - mem_stats_start["reserved_bytes.all.current"]
-                ),
-            }
+        # Memory usage.
+        time_end = time.time()
+        mem_stats_end = torch.cuda.memory_stats()
+        capture_stats = {
+            "time": time_end - time_start,
+            "allocated_bytes": (
+                mem_stats_end["allocated_bytes.all.current"]
+                - mem_stats_start["allocated_bytes.all.current"]
+            ),
+            "reserved_bytes": (
+                mem_stats_end["reserved_bytes.all.current"]
+                - mem_stats_start["reserved_bytes.all.current"]
+            ),
+        }
 
+        if torch.distributed.get_rank() == 0:
             logger.info(
                 "> built %d cuda graph(s) in %.2f sec, with total memory usage: "
                 "allocated %s, reserved %s."
```
