 import torch

 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
+from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
+from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
+from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
+from megatron.core.ssm.mlp_layer import MLPLayer
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
-from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
-from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
-from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
-from megatron.core.ssm.mlp_layer import MLPLayer
+from megatron.core.typed_torch import not_none

 try:
     from megatron.core.extensions.transformer_engine import (

     HAVE_TE = True
 except ImportError:
+    (
+        TEColumnParallelLinear,
+        TEDotProductAttention,
+        TELayerNormColumnParallelLinear,
+        TENorm,
+        TERowParallelLinear,
+    ) = (None, None, None, None, None)
     HAVE_TE = False

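Note on the new `not_none` import and the `None` fallbacks added above: the pattern is to bind the Transformer Engine names to `None` when the import fails, then assert at spec-construction time that they are actually available. A minimal sketch of such a helper is shown below; this is an illustration only, and the real `megatron.core.typed_torch.not_none` may differ in detail.

    # Hypothetical sketch of a not_none-style guard (not the actual
    # megatron.core.typed_torch implementation).
    from typing import Optional, TypeVar

    T = TypeVar("T")

    def not_none(value: Optional[T]) -> T:
        # Fail fast, and narrow Optional[T] to T for static type checkers,
        # when a fallback None (e.g. a missing TE class) slips through.
        assert value is not None
        return value
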
 try:
@@ -54,12 +62,8 @@ def get_layer_spec(is_vit, normalization) -> ModuleSpec:
             norm = TENorm
         else:
             version = torch.__version__.split('.')
-            version_geq_2_4 = (
-                int(TORCH_VERSION[0]) > 2
-                or (
-                    int(TORCH_VERSION[0]) == 2
-                    and int(TORCH_VERSION[1]) >= 4
-                )
+            version_geq_2_4 = int(TORCH_VERSION[0]) > 2 or (
+                int(TORCH_VERSION[0]) == 2 and int(TORCH_VERSION[1]) >= 4
             )
             assert version_geq_2_4, "Torch version >= 2.4.0 is required for RMSNorm"
             if HAVE_APEX:
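Aside: the guard above compares the major and minor components of `torch.__version__` numerically. A standalone equivalent, assuming only that the version string begins with "major.minor", can be written as a tuple comparison:

    # Illustrative sketch of the same torch >= 2.4 requirement.
    import torch

    major, minor = (int(part) for part in torch.__version__.split('.')[:2])
    assert (major, minor) >= (2, 4), "Torch version >= 2.4.0 is required for RMSNorm"
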
@@ -108,8 +112,8 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
                 module=SelfAttention,
                 params={"attn_mask_type": attn_mask_type},
                 submodules=SelfAttentionSubmodules(
-                    linear_qkv=TELayerNormColumnParallelLinear,
-                    core_attention=TEDotProductAttention,
+                    linear_qkv=not_none(TELayerNormColumnParallelLinear),
+                    core_attention=not_none(TEDotProductAttention),
                     linear_proj=TERowParallelLinear,
                     q_layernorm=IdentityOp,
                     k_layernorm=IdentityOp,
@@ -122,6 +126,7 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
         ),
     )

+
 def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
     attn_mask_type = AttnMaskType.causal
     # Padding mask is needed for e.g. Context Parallel.
@@ -153,8 +158,8 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
                         module=SelfAttention,
                         params={"attn_mask_type": attn_mask_type},
                         submodules=SelfAttentionSubmodules(
-                            linear_qkv=TELayerNormColumnParallelLinear,
-                            core_attention=TEDotProductAttention,
+                            linear_qkv=not_none(TELayerNormColumnParallelLinear),
+                            core_attention=not_none(TEDotProductAttention),
                             linear_proj=TERowParallelLinear,
                         ),
                     ),
@@ -170,7 +175,8 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
                     mlp=ModuleSpec(
                         module=MLP,
                         submodules=MLPSubmodules(
-                            linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
+                            linear_fc1=TELayerNormColumnParallelLinear,
+                            linear_fc2=TERowParallelLinear,
                         ),
                     ),
                     mlp_bda=get_bias_dropout_add,
@@ -179,6 +185,7 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
         ),
     )

+
 def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
     # Dense MLP w/ or w/o TE modules.
     return ModuleSpec(
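For orientation, these `get_*_spec*` factories only build declarative `ModuleSpec` descriptions; the actual modules are instantiated later, for example via `megatron.core.transformer.spec_utils.build_module` or by a model constructor that consumes the spec. A rough usage sketch follows; the `layer_specs` import path is an assumption made purely for illustration.

    # Illustrative only: inspect a spec produced by one of the factories above.
    from layer_specs import HAVE_TE, get_mlp_module_spec  # hypothetical import path

    spec = get_mlp_module_spec(use_te=HAVE_TE)
    print(spec.module)      # the MLP class that will be instantiated
    print(spec.submodules)  # MLPSubmodules(linear_fc1=..., linear_fc2=...)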