@@ -131,27 +131,6 @@ def _check_supported_type(meta):
     ), f"Cudagraphs recieved an arg of type {meta.type} which is not supported."
 
 
-def _determine_if_transformer_decoder_layer(base_module):
-    """Determine if the given module is a transformer decoder layer."""
-    # import modules here to avoid a circular import
-    from megatron.core.ssm.mamba_layer import MambaLayer
-    from megatron.core.transformer.transformer_layer import BaseTransformerLayer, TransformerLayer
-
-    is_potential_decoder_layer = isinstance(
-        base_module, (TransformerLayer, BaseTransformerLayer, MambaLayer)
-    )
-    if not is_potential_decoder_layer:
-        return False
-    if isinstance(base_module, TransformerLayer) and not isinstance(
-        base_module.cross_attention, IdentityOp
-    ):
-        # If the layer has a cross attention, it is not a decoder layer
-        return False
-    else:
-        # Otherwise it is a decoder layer
-        return True
-
-
 def _determine_if_first_last_layer_of_this_vp_chunk(base_module):
     """Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to.
     Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk.
@@ -242,10 +221,6 @@ def create_cudagraphs(cls):
         gc.collect()
         torch.cuda.empty_cache()
 
-        _set_capture_start()
-        if has_te_modules:
-            te_set_capture_start()
-
         def format_mem_bytes(mem_bytes):
             for power, suffix in [(4, "tb"), (3, "gb"), (2, "mb"), (1, "kb"), (0, "bytes")]:
                 suffix_bytes = 1024 ** power
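Only the loop header of `format_mem_bytes` is visible in this hunk. Below is a self-contained sketch with the same unit table; the threshold-and-divide body and the f-string formatting are assumptions, since the real return statement sits outside the hunk.

```python
# Illustrative sketch only: the unit table matches the hunk above, but the
# body (threshold check and f-string) is assumed, not taken from the source.
def format_mem_bytes(mem_bytes):
    for power, suffix in [(4, "tb"), (3, "gb"), (2, "mb"), (1, "kb"), (0, "bytes")]:
        suffix_bytes = 1024 ** power
        if mem_bytes >= suffix_bytes:
            return f"{mem_bytes / suffix_bytes:.1f} {suffix}"
    return f"{mem_bytes} bytes"


print(format_mem_bytes(3 * 1024 ** 3))  # 3.0 gb
print(format_mem_bytes(512))            # 512.0 bytes
```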
@@ -279,9 +254,8 @@ def format_mem_bytes(mem_bytes):
             runner.create_bwd_graph(global_tensor_pool)
 
         global bwd_buffer_reuse_ref_count, fwd_buffer_reuse_ref_count
-        # assert bwd_buffer_reuse_ref_count == 0
-        # assert fwd_buffer_reuse_ref_count == 0
-
+        assert bwd_buffer_reuse_ref_count == 0
+        assert fwd_buffer_reuse_ref_count == 0
 
         # Memory usage.
         time_end = time.time()
@@ -317,11 +291,6 @@ def format_mem_bytes(mem_bytes):
         cls.cudagraph_created = True
         cls.cudagraph_record = []
 
-        # Finished capturing.
-        _set_capture_end()
-        if has_te_modules:
-            te_set_capture_end()
-
         # Return capture time and memory usage.
         return capture_stats
 
@@ -547,8 +516,8 @@ def __init__(
         self.fp8_enabled = False
         self.fp4_enabled = False
         self.deallocate_pipeline_outputs = False
+        self.num_warmup_steps = 1
 
-        self.is_transformer_decoder_layer = _determine_if_transformer_decoder_layer(base_module)
         self.grad_enabled = need_backward and torch.is_grad_enabled()
         self.func = super(MegatronModule, self.base_module).__call__ if func is None else func
         self.is_first_layer, self.is_last_layer = (
@@ -571,14 +540,17 @@ def __init__(
         self.fp8_runtime_enabled = None
         self.fp4_runtime_enabled = None
 
-        if self.fp8_enabled:
-            self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
-            FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
+        if HAVE_TE_GRAPHS:
+            self.has_te_modules = any(
+                [isinstance(m, TransformerEngineBaseModule) for m in self.base_module.modules()]
+            )
 
-        if self.fp4_enabled:
-            from megatron.core.fp4_utils import get_fp4_recipe  # to avoid circular import
+        if self.fp8_enabled:
+            self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+        if self.fp4_enabled:
+            from megatron.core.fp4_utils import get_fp4_recipe  # to avoid circular import
+            self.fp4_recipe = get_fp4_recipe(self.base_module.config)
 
-            self.fp4_recipe = get_fp4_recipe(self.base_module.config)
         FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
 
     def __str__(self):
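The new `has_te_modules` flag is computed by walking the wrapped module's submodule tree and checking for TransformerEngine layers. A minimal sketch of that scan follows, using a stand-in marker class in place of `TransformerEngineBaseModule` so it runs without TransformerEngine installed:

```python
# Minimal sketch of the has_te_modules scan. FakeTEModule stands in for
# TransformerEngine's TransformerEngineBaseModule; the traversal pattern
# (module.modules() + isinstance) is the same.
import torch.nn as nn


class FakeTEModule(nn.Module):
    """Stand-in for a TransformerEngine layer in this sketch."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


block = nn.Sequential(nn.Linear(8, 8), FakeTEModule(), nn.ReLU())

# .modules() yields the module itself plus every nested submodule, so a single
# pass is enough to detect a TE-style layer anywhere in the tree.
has_te_modules = any(isinstance(m, FakeTEModule) for m in block.modules())
print(has_te_modules)  # True
```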
@@ -669,7 +641,7 @@ def get_fwd_input_buffer(ten):
 
         # cache the moe aux loss if needed, this is needed because the moe aux loss is accumulated inside
         # the transformer layer forward pass:
-        is_moe = self.is_transformer_decoder_layer and hasattr(self.base_module, "is_moe_layer") and self.base_module.is_moe_layer
+        is_moe = hasattr(self.base_module, "is_moe_layer") and self.base_module.is_moe_layer
         if is_moe:
             from megatron.core.transformer.moe.moe_utils import get_moe_layer_wise_logging_tracker
             tracker = get_moe_layer_wise_logging_tracker()
@@ -710,7 +682,7 @@ def get_fwd_input_buffer(ten):
 
         with ctx:
             # warmup again as case graph capture mode may execute a different codepath
-            for _ in range(1):
+            for _ in range(num_warmup_steps):
                 with self.get_quantization_context():
                     def clone_ten(ten):
                         clone = torch.zeros_like(ten)
@@ -744,6 +716,8 @@ def clone_ten(ten):
 
         with self.get_quantization_context():
             _set_capture_start()
+            if self.has_te_modules: te_set_capture_start()
+
             # Freeze GC, to speed up capture time ~15-20x.
             if FREEZE_GC:
                 gc.freeze()
@@ -763,7 +737,9 @@ def clone_ten(ten):
             # per-device to avoid slowing down graph creation.
             if self.is_last_layer:
                 gc.collect()
+
             _set_capture_end()
+            if self.has_te_modules: te_set_capture_end()
 
         # save cudagraph output buffer
         if isinstance(fwd_graph_outputs, torch.Tensor):
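These hunks make the warmup count configurable (`num_warmup_steps`) and bracket the forward capture with the capture-mode flags, toggling TransformerEngine's flag only when the module actually contains TE layers. Below is a hedged sketch of the same warmup-then-capture-then-replay sequence in plain PyTorch, with the Megatron/TE flag setters omitted and an arbitrary model:

```python
# Plain-PyTorch sketch of the warmup + capture sequence; no Megatron or
# TransformerEngine specifics, model and shapes are arbitrary.
import torch

model = torch.nn.Linear(1024, 1024).cuda()
static_input = torch.randn(32, 1024, device="cuda")

num_warmup_steps = 1  # the PR makes this a runner attribute instead of a hard-coded range(1)

# Warm up on a side stream so one-time work (cuBLAS handles, autotuning)
# happens outside the capture.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(num_warmup_steps):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture: kernels launched inside this context are recorded, not executed.
fwd_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(fwd_graph):
    static_output = model(static_input)

# Replay: refill the static input buffer in place, then replay the graph.
static_input.copy_(torch.randn(32, 1024, device="cuda"))
fwd_graph.replay()
print(static_output.shape)  # torch.Size([32, 1024])
```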
@@ -843,6 +819,9 @@ def create_bwd_graph(self, global_tensor_pool):
         input_tensors = self.get_tensors(self.fwd_graph_input_args, self.fwd_graph_input_kwargs)
         fwd_input_surface = input_tensors + tuple(self.params_to_backprop)
 
+        _set_capture_start()
+        if self.has_te_modules: te_set_capture_start()
+
         # Freeze GC, to speed up capture time ~15-20x.
         if FREEZE_GC:
             gc.freeze()
@@ -861,6 +840,9 @@ def create_bwd_graph(self, global_tensor_pool):
         if FREEZE_GC:
             gc.unfreeze()
 
+        _set_capture_end()
+        if self.has_te_modules: te_set_capture_end()
+
         grad_inputs = list(grad_inputs)
 
         self.static_grad_outputs = static_grad_outputs
@@ -966,19 +948,17 @@ def record_graph_capture(self, args, kwargs):
                 o.is_cudagraph_output = True
 
         if not self.fwd_graph_recorded:
-            if HAVE_TE_GRAPHS:
-                if FP8GlobalStateManager.is_fp8_enabled():
-                    # check if the low precision recipe is either fp4 or fp8
-                    if is_te_min_version("2.7.0.dev0"):
-                        from transformer_engine.common.recipe import NVFP4BlockScaling
-                        recipe = FP8GlobalStateManager.get_fp8_recipe()
-                        if isinstance(recipe, NVFP4BlockScaling):
-                            self.fp4_runtime_enabled = True
-                        else:
-                            self.fp8_runtime_enabled = True
-                    else:
+            if self.fp8_enabled or self.fp4_enabled:
+                # check if any low precision recipe is enabled
+                if is_te_min_version("2.7.0.dev0"):
+                    from transformer_engine.common.recipe import NVFP4BlockScaling
+                    recipe = FP8GlobalStateManager.get_fp8_recipe()
+                    if isinstance(recipe, NVFP4BlockScaling):
+                        self.fp4_runtime_enabled = True
+                    else:
                         self.fp8_runtime_enabled = True
-
+                else:
+                    self.fp8_runtime_enabled = True
 
             logger.debug(f"Recording forward graph creation...")
             m_args, m_kwargs = self.replace_tensors_with_weak_refs(args, kwargs, cache_refs=True)
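The rewritten branch decides, at record time, whether this runner should later replay under an FP4 or an FP8 autocast, based on the active recipe and the installed TE version. Below is a self-contained sketch of that dispatch with stand-in recipe classes; the real code imports `NVFP4BlockScaling` from TransformerEngine (TE >= 2.7.0.dev0), while the stand-ins keep the sketch runnable anywhere.

```python
# Stand-in recipe classes; only the dispatch shape mirrors the diff above.
class DelayedScaling:  # stands in for a TE fp8 recipe
    pass


class NVFP4BlockScaling:  # stands in for TE's fp4 recipe
    pass


def classify_recipe(recipe, te_supports_fp4):
    """Return (fp8_runtime_enabled, fp4_runtime_enabled) for a runner."""
    if not te_supports_fp4:
        # Older TE has no NVFP4BlockScaling, so any low-precision recipe is fp8.
        return True, False
    if isinstance(recipe, NVFP4BlockScaling):
        return False, True
    return True, False


print(classify_recipe(DelayedScaling(), te_supports_fp4=True))     # (True, False)
print(classify_recipe(NVFP4BlockScaling(), te_supports_fp4=True))  # (False, True)
print(classify_recipe(DelayedScaling(), te_supports_fp4=False))    # (True, False)
```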
@@ -1233,9 +1213,8 @@ def wrapped_func(*args, **kwargs):
         self.inference_cudagraphs_lookup_table = defaultdict(lambda: None)
         self.is_first_microbatch = False
 
-        # Without pipeline parallelism, microbatches execute one at a time.
-        # Therefore modules will always execute in the same order, so cudagraphs
-        # can both be reused and share a single mempool.
+        # Without pipeline parallelism, modules execute one at a time in the same order, so cudagraphs
+        # may be reused across microbatches
         self.reuse_cudagraphs = parallel_state.get_pipeline_model_parallel_world_size() == 1
         if CudaGraphManager.global_mempool is None:
             CudaGraphManager.global_mempool = torch.cuda.graph_pool_handle()
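`CudaGraphManager.global_mempool` is a single `torch.cuda.graph_pool_handle()` handed to every capture, and `reuse_cudagraphs` is only enabled without pipeline parallelism, where layers run in a fixed order each microbatch. A hedged sketch of sharing one pool across several captured graphs in plain PyTorch follows; layer sizes are arbitrary, and replay follows capture order, which is the safe pattern for a shared pool.

```python
# Plain-PyTorch sketch of several graphs sharing one memory pool.
import torch

layers = [torch.nn.Linear(256, 256).cuda() for _ in range(3)]
static_in = torch.randn(16, 256, device="cuda")

# Warm up every layer on a side stream before any capture.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for layer in layers:
        for _ in range(3):
            layer(static_in)
torch.cuda.current_stream().wait_stream(side_stream)

# One pool handle, passed to every capture, lets the graphs share the same
# cudagraph-private memory instead of each reserving its own.
mempool = torch.cuda.graph_pool_handle()
graphs, static_outs = [], []
for layer in layers:
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=mempool):
        static_outs.append(layer(static_in))
    graphs.append(g)

# Replay in capture order.
for g in graphs:
    g.replay()
print([o.shape for o in static_outs])
```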