Commit 1c063fc

yuzhongw-nvidia authored and ko3n1g committed

re-apply following commit properly post merge with main

[dev] feat(moe): Cherry-pick NVIDIA#1989 back to dev (NVIDIA#3011)

Signed-off-by: oliver könig <[email protected]>
Co-authored-by: oliver könig <[email protected]>
1 parent 08357d8 commit 1c063fc

File tree

3 files changed: +62 −23 lines changed


gpt_builders.py

Lines changed: 14 additions & 7 deletions
@@ -11,6 +11,7 @@
 )
 from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
     get_transformer_block_with_experimental_attention_variant_spec,
+    get_transformer_layer_with_experimental_attention_variant_spec,
 )
 from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
     get_gpt_heterogeneous_layer_spec,
@@ -76,13 +77,19 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
             mtp_transformer_layer_spec = import_module(args.spec)
         else:
             # Define the decoder block spec
-            decoder_layer_specs = get_gpt_decoder_layer_specs(
-                config,
-                use_transformer_engine=use_te,
-                normalization=args.normalization,
-                qk_l2_norm=args.qk_l2_norm,
-                vp_stage=vp_stage,
-            )
+            if args.experimental_attention_variant is not None:
+                decoder_layer_specs = (
+                    get_transformer_layer_with_experimental_attention_variant_spec(
+                        config=config
+                    )
+                )
+            else:
+                decoder_layer_specs = get_gpt_decoder_layer_specs(
+                    config,
+                    use_transformer_engine=use_te,
+                    normalization=args.normalization,
+                    qk_l2_norm=args.qk_l2_norm,
+                )
             mtp_transformer_layer_spec = decoder_layer_specs[-1]
         # Use spec of the last layer in decoder block as spec of the transformer layer in MTP
         mtp_block_spec = get_gpt_mtp_block_spec(
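
To illustrate the dispatch this hunk introduces, here is a minimal, self-contained sketch. The two spec builders below are hypothetical stand-ins for get_transformer_layer_with_experimental_attention_variant_spec and get_gpt_decoder_layer_specs (not the real Megatron helpers), and "linear_attention" is an illustrative value; the point is only that both branches yield a per-layer list, so the existing decoder_layer_specs[-1] lookup for the MTP layer spec keeps working unchanged.

# Hedged sketch of the branch above; the spec builders are stand-ins, not the
# real Megatron helpers, and "linear_attention" is an illustrative value.
from types import SimpleNamespace

def experimental_layer_specs(config):  # stand-in for the new layer-spec helper
    return [f"experimental_layer_{i}" for i in range(config.num_layers)]

def standard_layer_specs(config):      # stand-in for get_gpt_decoder_layer_specs
    return [f"standard_layer_{i}" for i in range(config.num_layers)]

def build_decoder_layer_specs(args, config):
    # Route to the experimental-attention path only when a variant is requested.
    if args.experimental_attention_variant is not None:
        return experimental_layer_specs(config)
    return standard_layer_specs(config)

config = SimpleNamespace(num_layers=4)
args = SimpleNamespace(experimental_attention_variant="linear_attention")
specs = build_decoder_layer_specs(args, config)
mtp_transformer_layer_spec = specs[-1]  # last layer's spec feeds the MTP block
print(mtp_transformer_layer_spec)       # -> experimental_layer_3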

megatron/core/models/gpt/experimental_attention_variant_module_specs.py

Lines changed: 47 additions & 12 deletions
@@ -149,12 +149,12 @@ def get_experimental_attention_variant_module_spec(
 ##########
 
 
-def get_transformer_block_with_experimental_attention_variant_spec(
-    config: TransformerConfig, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None
-) -> TransformerBlockSubmodules:
-    """Build transformer block spec with experimental attention variants (e.g., linear attention).
+def get_transformer_layer_with_experimental_attention_variant_spec(
+    config: TransformerConfig, backend: BackendSpecProvider = None
+) -> List[ModuleSpec]:
+    """Build transformer layer specs with experimental attention variants (e.g., linear attention).
 
-    This function constructs a heterogeneous transformer block that supports mixing different
+    This function is for constructing a heterogeneous transformer that supports mixing different
     attention mechanisms (experimental vs standard) and MLP types (MoE vs dense) across layers.
     **Note that, this API is a experimental API in the short term, and might be deprecated in the
     future. In the long run, we will move to a new design that better support hybrid models.**
@@ -170,22 +170,19 @@ def get_transformer_block_with_experimental_attention_variant_spec(
     2. Per-Layer Spec Construction: Iterates through layers, constructing transformer
        layer specs based on attention and MLP patterns.
 
-    3. Pipeline Slicing: Extracts layer specs for the current pipeline stage.
-
     Args:
         config: Transformer configuration containing model hyperparameters and feature flags.
-        vp_stage: Virtual pipeline stage index for interleaved pipeline parallelism.
-        pp_rank: Pipeline model parallel rank.
 
     Returns:
-        TransformerBlockSubmodules containing per-layer specs and final layer norm.
+        List[ModuleSpec] containing per-layer specs.
 
     Note:
         Currently only supports transformer_engine backend. Kitchen backend can be used as a
         wrapper with TE fallback for unsupported operations.
     """
 
-    backend = _get_backend_spec_provider(config=config)
+    if backend is None:
+        backend = _get_backend_spec_provider(config=config)
 
     # Get attention patterns and specs
     experimental_attention_pattern = [0] * config.num_layers
@@ -257,6 +254,42 @@ def get_transformer_block_with_experimental_attention_variant_spec(
             )
         )
 
+    return layer_specs
+
+
+def get_transformer_block_with_experimental_attention_variant_spec(
+    config: TransformerConfig, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None
+) -> TransformerBlockSubmodules:
+    """Build transformer block spec with experimental attention variants (e.g., linear attention).
+
+    This function constructs a heterogeneous transformer block that supports mixing different
+    attention mechanisms (experimental vs standard) and MLP types (MoE vs dense) across layers.
+    **Note that, this API is a experimental API in the short term, and might be deprecated in the
+    future. In the long run, we will move to a new design that better support hybrid models.**
+
+    Constructing transformer layer specs by
+    `get_transformer_layer_with_experimental_attention_variant_spec` and then slicing the
+    layer specs to only include the layers that are built in this pipeline stage.
+
+    Args:
+        config: Transformer configuration containing model hyperparameters and feature flags.
+        vp_stage: Virtual pipeline stage index for interleaved pipeline parallelism.
+        pp_rank: Pipeline model parallel rank.
+
+    Returns:
+        TransformerBlockSubmodules containing per-layer specs and final layer norm.
+
+    Note:
+        Currently only supports transformer_engine backend. Kitchen backend can be used as a
+        wrapper with TE fallback for unsupported operations.
+    """
+
+    backend = _get_backend_spec_provider(config=config)
+
+    layer_specs = get_transformer_layer_with_experimental_attention_variant_spec(
+        config=config, backend=backend
+    )
+
     # Slice the layer specs to only include the layers that are built in this pipeline stage.
     if config.pipeline_model_parallel_layout is not None:
         local_layer_ids = config.pipeline_model_parallel_layout.get_layer_id_list(
@@ -270,6 +303,7 @@ def get_transformer_block_with_experimental_attention_variant_spec(
         layer_specs = [layer_specs[layer_id] for layer_id in local_layer_ids]
 
     # Get GPT decoder block spec
+    rms_norm = config.normalization == "RMSNorm"
     gpt_decoder_block_spec = TransformerBlockSubmodules(
         layer_specs=layer_specs, layer_norm=backend.layer_norm(rms_norm=rms_norm, for_qk=False)
     )
@@ -359,7 +393,7 @@ def _get_backend_spec_provider(config: TransformerConfig) -> BackendSpecProvider
         )
     backend: BackendSpecProvider = (
         KitchenSpecProvider(
-            fallback=TESpecProvider(),
+            fallback=TESpecProvider(fallback_to_eager_attn=config.fallback_to_eager_attn),
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
         )
@@ -396,6 +430,7 @@ def _get_self_attention_module_spec(
         qk_l2_norm=config.qk_l2_norm,
         use_kitchen=config.use_kitchen,
         use_te_activation_func=config.use_te_activation_func,
+        fallback_to_eager_attn=config.fallback_to_eager_attn,
         use_kitchen_attention=config.use_kitchen_attention,
         kitchen_attention_backend=config.kitchen_attention_backend,
     )
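
The refactor in this file splits per-layer spec construction (get_transformer_layer_with_experimental_attention_variant_spec, which now accepts an optional pre-resolved backend) from block assembly (get_transformer_block_with_experimental_attention_variant_spec, which slices the layer list down to the current pipeline stage and attaches the final layer norm). Below is a minimal, self-contained sketch of that structure; the names are illustrative stand-ins, not Megatron's real BackendSpecProvider, ModuleSpec, or TransformerBlockSubmodules types.

# Illustrative stand-ins only; the real signatures live in
# experimental_attention_variant_module_specs.py.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class BlockSubmodules:                 # stand-in for TransformerBlockSubmodules
    layer_specs: List[str]
    layer_norm: str

def resolve_backend(config) -> str:    # stand-in for _get_backend_spec_provider
    return "te"

def layer_specs_fn(config, backend: Optional[str] = None) -> List[str]:
    # Layer-level helper: resolve the backend only if the caller did not pass
    # one, then build one spec per layer; no pipeline slicing happens here.
    if backend is None:
        backend = resolve_backend(config)
    return [f"{backend}_layer_{i}" for i in range(config["num_layers"])]

def block_spec_fn(config, local_layer_ids: Optional[List[int]] = None) -> BlockSubmodules:
    # Block-level helper: delegate to the layer-level helper with a shared
    # backend, keep only this pipeline stage's layers, attach the final norm.
    backend = resolve_backend(config)
    specs = layer_specs_fn(config, backend=backend)
    if local_layer_ids is not None:
        specs = [specs[i] for i in local_layer_ids]
    rms_norm = config["normalization"] == "RMSNorm"
    return BlockSubmodules(specs, layer_norm=f"{backend}_norm(rms={rms_norm})")

cfg = {"num_layers": 8, "normalization": "RMSNorm"}
print(block_spec_fn(cfg, local_layer_ids=[4, 5, 6, 7]).layer_specs)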

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 1 addition & 4 deletions
@@ -618,6 +618,7 @@ def get_gpt_decoder_block_spec(
     layer_specs = get_gpt_decoder_layer_specs(
         config, use_transformer_engine, normalization, qk_l2_norm
     )
+
     # Slice the layer specs to only include the layers that are built in this pipeline stage.
     # Note: MCore layer_number starts at 1
     num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage, pp_rank=pp_rank)
@@ -637,10 +638,6 @@ def get_gpt_decoder_block_spec(
     offset = get_transformer_layer_offset(config, vp_stage=vp_stage, pp_rank=pp_rank)
     local_layer_specs = layer_specs[offset : offset + num_layers_to_build]
 
-    if use_transformer_engine:
-        layer_norm_impl = TENorm
-    else:
-        layer_norm_impl = LNImpl
     # Block spec.
     if use_transformer_engine:
         layer_norm_impl = TENorm
