Commit 0cf9ab8

cleanup
Signed-off-by: Jieming Zhang <jiemingz@nvidia.com>
1 parent 33f18a0 commit 0cf9ab8

1 file changed: +36 -57 lines changed

megatron/core/transformer/cuda_graphs.py

Lines changed: 36 additions & 57 deletions
@@ -131,27 +131,6 @@ def _check_supported_type(meta):
     ), f"Cudagraphs recieved an arg of type {meta.type} which is not supported."


-def _determine_if_transformer_decoder_layer(base_module):
-    """Determine if the given module is a transformer decoder layer."""
-    # import modules here to avoid a circular import
-    from megatron.core.ssm.mamba_layer import MambaLayer
-    from megatron.core.transformer.transformer_layer import BaseTransformerLayer, TransformerLayer
-
-    is_potential_decoder_layer = isinstance(
-        base_module, (TransformerLayer, BaseTransformerLayer, MambaLayer)
-    )
-    if not is_potential_decoder_layer:
-        return False
-    if isinstance(base_module, TransformerLayer) and not isinstance(
-        base_module.cross_attention, IdentityOp
-    ):
-        # If the layer has a cross attention, it is not a decoder layer
-        return False
-    else:
-        # Otherwise it is a decoder layer
-        return True
-
-
 def _determine_if_first_last_layer_of_this_vp_chunk(base_module):
     """Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to.
     Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk.
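The deleted helper existed mainly to gate MoE aux-loss caching, which a later hunk rewrites as a plain attribute check. As a minimal sketch (the classes below are hypothetical, not Megatron's), the duck-typed check needs no layer-type imports and therefore avoids the circular-import workaround the helper carried:

# Toy sketch (hypothetical classes) contrasting the removed isinstance-based
# detection with the attribute check this commit keeps.

class MoeLayer:
    is_moe_layer = True

class DenseLayer:
    pass

def is_moe(module):
    # Duck typing: no layer-type imports, hence no circular-import risk.
    return getattr(module, "is_moe_layer", False)

assert is_moe(MoeLayer()) and not is_moe(DenseLayer())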
@@ -242,10 +221,6 @@ def create_cudagraphs(cls):
         gc.collect()
         torch.cuda.empty_cache()

-        _set_capture_start()
-        if has_te_modules:
-            te_set_capture_start()
-
         def format_mem_bytes(mem_bytes):
             for power, suffix in [(4, "tb"), (3, "gb"), (2, "mb"), (1, "kb"), (0, "bytes")]:
                 suffix_bytes = 1024**power
@@ -279,9 +254,8 @@ def format_mem_bytes(mem_bytes):
                 runner.create_bwd_graph(global_tensor_pool)

         global bwd_buffer_reuse_ref_count, fwd_buffer_reuse_ref_count
-        # assert bwd_buffer_reuse_ref_count == 0
-        # assert fwd_buffer_reuse_ref_count == 0
-
+        assert bwd_buffer_reuse_ref_count == 0
+        assert fwd_buffer_reuse_ref_count == 0

         # Memory usage.
         time_end = time.time()
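The ref counts appear to track capture buffers handed out for reuse, so re-enabling the asserts turns a silent leak into a hard failure at the end of graph creation. A toy sketch of the invariant being checked (names hypothetical):

# Toy sketch: borrows and returns of pooled capture buffers must balance,
# leaving the reuse ref count at exactly zero when capture finishes.

buffer_reuse_ref_count = 0

def borrow_buffer():
    global buffer_reuse_ref_count
    buffer_reuse_ref_count += 1

def return_buffer():
    global buffer_reuse_ref_count
    buffer_reuse_ref_count -= 1

borrow_buffer()
return_buffer()
assert buffer_reuse_ref_count == 0  # fails if any borrow was never returned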
@@ -317,11 +291,6 @@ def format_mem_bytes(mem_bytes):
         cls.cudagraph_created = True
         cls.cudagraph_record = []

-        # Finished capturing.
-        _set_capture_end()
-        if has_te_modules:
-            te_set_capture_end()
-
         # Return capture time and memory usage.
         return capture_stats

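Together with the previous removal, this moves the capture start/end flags from one global bracket around all of create_cudagraphs to per-graph brackets added in later hunks. A sketch of that bracketing pattern; the context manager below is illustrative only, and the setter callables stand in for _set_capture_start/_set_capture_end and TE's te_set_capture_start/te_set_capture_end:

# Sketch of per-graph capture bracketing. The try/finally guarantees the end
# flag is always cleared, even if capture raises partway through.
from contextlib import contextmanager

@contextmanager
def capture_window(set_start, set_end, te_start=None, te_end=None):
    set_start()
    if te_start is not None:
        te_start()
    try:
        yield
    finally:
        set_end()
        if te_end is not None:
            te_end()

# Usage: bracket one graph's capture instead of the whole creation pass.
with capture_window(lambda: print("start"), lambda: print("end")):
    pass  # a single runner's graph capture would happen here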
@@ -547,8 +516,8 @@ def __init__(
         self.fp8_enabled = False
         self.fp4_enabled = False
         self.deallocate_pipeline_outputs = False
+        self.num_warmup_steps = 1

-        self.is_transformer_decoder_layer = _determine_if_transformer_decoder_layer(base_module)
         self.grad_enabled = need_backward and torch.is_grad_enabled()
         self.func = super(MegatronModule, self.base_module).__call__ if func is None else func
         self.is_first_layer, self.is_last_layer = (
@@ -571,14 +540,17 @@ def __init__(
         self.fp8_runtime_enabled = None
         self.fp4_runtime_enabled = None

-        if self.fp8_enabled:
-            self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
-            FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
+        if HAVE_TE_GRAPHS:
+            self.has_te_modules = any(
+                [isinstance(m, TransformerEngineBaseModule) for m in self.base_module.modules()]
+            )

-        if self.fp4_enabled:
-            from megatron.core.fp4_utils import get_fp4_recipe  # to avoid circular import
+        if self.fp8_enabled:
+            self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+        if self.fp4_enabled:
+            from megatron.core.fp4_utils import get_fp4_recipe  # to avoid circular import
+            self.fp4_recipe = get_fp4_recipe(self.base_module.config)

-            self.fp4_recipe = get_fp4_recipe(self.base_module.config)
         FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)

     def __str__(self):
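Module.modules() walks the full submodule tree, so the new has_te_modules flag is computed once per wrapped module. A minimal sketch of the same scan, using a plain torch.nn type in place of TransformerEngineBaseModule:

# Minimal sketch of the submodule scan; nn.Linear stands in for
# TransformerEngineBaseModule. modules() yields the module itself plus all
# descendants, so nested layers are found at any depth.
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
has_linear = any(isinstance(m, nn.Linear) for m in model.modules())
assert has_linear  # found despite the nesting

A generator expression (as above) short-circuits on the first hit, whereas the list comprehension in the diff materializes the full list first; both compute the same flag.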
@@ -669,7 +641,7 @@ def get_fwd_input_buffer(ten):

         # cache the moe aux loss if needed, this is needed because the moe aux loss is accumulated inside
         # the transformer layer forward pass:
-        is_moe = self.is_transformer_decoder_layer and hasattr(self.base_module, "is_moe_layer") and self.base_module.is_moe_layer
+        is_moe = hasattr(self.base_module, "is_moe_layer") and self.base_module.is_moe_layer
         if is_moe:
             from megatron.core.transformer.moe.moe_utils import get_moe_layer_wise_logging_tracker
             tracker = get_moe_layer_wise_logging_tracker()
@@ -710,7 +682,7 @@ def get_fwd_input_buffer(ten):

         with ctx:
             # warmup again as case graph capture mode may execute a different codepath
-            for _ in range(1):
+            for _ in range(num_warmup_steps):
                 with self.get_quantization_context():
                     def clone_ten(ten):
                         clone = torch.zeros_like(ten)
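The warmup loop now runs a configurable number of iterations before capture. Warmup matters because first-call work (autotuning, lazy workspace allocation) must complete before kernels are recorded into the graph. A self-contained sketch of the standard PyTorch idiom, independent of Megatron's wrapper; num_warmup_steps here mirrors the new setting:

# Warm up on a side stream, then capture the steady-state kernel sequence.
import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    static_in = torch.randn(8, 16, device="cuda")

    num_warmup_steps = 1  # mirrors the new default in __init__
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(num_warmup_steps):
            model(static_in)  # lazy init / autotuning happens outside the graph
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = model(static_in)  # only steady-state kernels are recorded

    static_in.copy_(torch.randn(8, 16, device="cuda"))
    g.replay()  # reruns the captured kernels; static_out is updated in place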
@@ -744,6 +716,8 @@ def clone_ten(ten):

         with self.get_quantization_context():
             _set_capture_start()
+            if self.has_te_modules: te_set_capture_start()
+
             # Freeze GC, to speed up capture time ~15-20x.
             if FREEZE_GC:
                 gc.freeze()
@@ -763,7 +737,9 @@ def clone_ten(ten):
             # per-device to avoid slowing down graph creation.
             if self.is_last_layer:
                 gc.collect()
+
             _set_capture_end()
+            if self.has_te_modules: te_set_capture_end()

             # save cudagraph output buffer
             if isinstance(fwd_graph_outputs, torch.Tensor):
@@ -843,6 +819,9 @@ def create_bwd_graph(self, global_tensor_pool):
         input_tensors = self.get_tensors(self.fwd_graph_input_args, self.fwd_graph_input_kwargs)
         fwd_input_surface = input_tensors + tuple(self.params_to_backprop)

+        _set_capture_start()
+        if self.has_te_modules: te_set_capture_start()
+
         # Freeze GC, to speed up capture time ~15-20x.
         if FREEZE_GC:
             gc.freeze()
@@ -861,6 +840,9 @@ def create_bwd_graph(self, global_tensor_pool):
         if FREEZE_GC:
             gc.unfreeze()

+        _set_capture_end()
+        if self.has_te_modules: te_set_capture_end()
+
         grad_inputs = list(grad_inputs)

         self.static_grad_outputs = static_grad_outputs
@@ -966,19 +948,17 @@ def record_graph_capture(self, args, kwargs):
                 o.is_cudagraph_output = True

         if not self.fwd_graph_recorded:
-            if HAVE_TE_GRAPHS:
-                if FP8GlobalStateManager.is_fp8_enabled():
-                    # check if the low precision recipe is either fp4 or fp8
-                    if is_te_min_version("2.7.0.dev0"):
-                        from transformer_engine.common.recipe import NVFP4BlockScaling
-                        recipe = FP8GlobalStateManager.get_fp8_recipe()
-                        if isinstance(recipe, NVFP4BlockScaling):
-                            self.fp4_runtime_enabled = True
-                        else:
-                            self.fp8_runtime_enabled = True
-                    else:
+            if self.fp8_enabled or self.fp4_enabled:
+                # check if any low precision recipe is enabled
+                if is_te_min_version("2.7.0.dev0"):
+                    from transformer_engine.common.recipe import NVFP4BlockScaling
+                    recipe = FP8GlobalStateManager.get_fp8_recipe()
+                    if isinstance(recipe, NVFP4BlockScaling):
+                        self.fp4_runtime_enabled = True
+                    else:
                         self.fp8_runtime_enabled = True
-
+                else:
+                    self.fp8_runtime_enabled = True

             logger.debug(f"Recording forward graph creation...")
             m_args, m_kwargs = self.replace_tensors_with_weak_refs(args, kwargs, cache_refs=True)
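The rewritten branch keys off the flags set in __init__ and the recipe object's type, rather than querying FP8GlobalStateManager.is_fp8_enabled() at record time. A toy sketch of that type-based dispatch; the recipe classes below are hypothetical stand-ins for transformer_engine's recipe types:

# Toy sketch of type-based recipe dispatch.

class Fp8Recipe:
    pass

class NVFP4BlockScaling(Fp8Recipe):
    pass

def classify(recipe):
    # NVFP4 is tested first: it subclasses the generic recipe here, so the
    # isinstance order matters.
    if isinstance(recipe, NVFP4BlockScaling):
        return "fp4"
    return "fp8"

assert classify(NVFP4BlockScaling()) == "fp4"
assert classify(Fp8Recipe()) == "fp8"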
@@ -1233,9 +1213,8 @@ def wrapped_func(*args, **kwargs):
         self.inference_cudagraphs_lookup_table = defaultdict(lambda: None)
         self.is_first_microbatch = False

-        # Without pipeline parallelism, microbatches execute one at a time.
-        # Therefore modules will always execute in the same order, so cudagraphs
-        # can both be reused and share a single mempool.
+        # Without pipeline parallelism, modules execute one at a time in the same order, so cudagraphs
+        # may be reused across microbatches
         self.reuse_cudagraphs = parallel_state.get_pipeline_model_parallel_world_size() == 1
         if CudaGraphManager.global_mempool is None:
             CudaGraphManager.global_mempool = torch.cuda.graph_pool_handle()
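torch.cuda.graph_pool_handle() returns an opaque pool id; graphs captured into the same pool can share memory, which is only safe when replays always happen in one fixed order, exactly the property the rewritten comment states for the non-pipeline-parallel case. A standalone sketch of the pattern:

# Two graphs captured into one shared mempool, replayed in capture order.
import torch

if torch.cuda.is_available():
    pool = torch.cuda.graph_pool_handle()  # opaque id for a shared mempool

    x = torch.ones(4, device="cuda")
    (x * 2 + 1).sum().item()  # throwaway warmup so lazy init precedes capture

    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1, pool=pool):
        y = x * 2
    with torch.cuda.graph(g2, pool=pool):
        z = y + 1  # consumes g1's output; same pool, fixed replay order

    g1.replay()
    g2.replay()
    print(z)  # tensor([3., 3., 3., 3.])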
