@@ -149,12 +149,12 @@ def get_experimental_attention_variant_module_spec(
 ##########
 
 
-def get_transformer_block_with_experimental_attention_variant_spec(
-    config: TransformerConfig, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None
-) -> TransformerBlockSubmodules:
-    """Build transformer block spec with experimental attention variants (e.g., linear attention).
+def get_transformer_layer_with_experimental_attention_variant_spec(
+    config: TransformerConfig, backend: Optional[BackendSpecProvider] = None
+) -> List[ModuleSpec]:
+    """Build transformer layer specs with experimental attention variants (e.g., linear attention).
 
-    This function constructs a heterogeneous transformer block that supports mixing different
+    This function constructs heterogeneous transformer layer specs that support mixing different
     attention mechanisms (experimental vs standard) and MLP types (MoE vs dense) across layers.
     **Note that this API is an experimental API in the short term, and might be deprecated in the
     future. In the long run, we will move to a new design that better supports hybrid models.**
@@ -170,22 +170,19 @@ def get_transformer_block_with_experimental_attention_variant_spec(
     2. Per-Layer Spec Construction: Iterates through layers, constructing transformer
        layer specs based on attention and MLP patterns.
 
-    3. Pipeline Slicing: Extracts layer specs for the current pipeline stage.
-
     Args:
         config: Transformer configuration containing model hyperparameters and feature flags.
-        vp_stage: Virtual pipeline stage index for interleaved pipeline parallelism.
-        pp_rank: Pipeline model parallel rank.
 
     Returns:
-        TransformerBlockSubmodules containing per-layer specs and final layer norm.
+        List[ModuleSpec] containing per-layer specs.
 
     Note:
         Currently only supports transformer_engine backend. Kitchen backend can be used as a
         wrapper with TE fallback for unsupported operations.
     """
 
-    backend = _get_backend_spec_provider(config=config)
+    if backend is None:
+        backend = _get_backend_spec_provider(config=config)
 
     # Get attention patterns and specs
     experimental_attention_pattern = [0] * config.num_layers
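
For orientation, here is a minimal usage sketch of the refactored helper. The import path for `TransformerConfig` and the exact constructor arguments are assumptions based on typical Megatron-Core layout, not part of this diff.

```python
# Sketch only: assumes these symbols are importable alongside the functions in this file.
from megatron.core.transformer.transformer_config import TransformerConfig  # assumed path

config = TransformerConfig(
    num_layers=4,
    hidden_size=1024,
    num_attention_heads=8,
    normalization="RMSNorm",  # illustrative choice
)

# New default behavior: the backend is resolved from the config when not supplied.
layer_specs = get_transformer_layer_with_experimental_attention_variant_spec(config=config)

# Alternatively, reuse an already-resolved backend, as the block-level wrapper below does.
backend = _get_backend_spec_provider(config=config)
layer_specs = get_transformer_layer_with_experimental_attention_variant_spec(
    config=config, backend=backend
)
assert len(layer_specs) == config.num_layers  # one ModuleSpec per layer, before pipeline slicing
```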
@@ -257,6 +254,42 @@ def get_transformer_block_with_experimental_attention_variant_spec(
            )
        )
 
+    return layer_specs
+
+
+def get_transformer_block_with_experimental_attention_variant_spec(
+    config: TransformerConfig, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None
+) -> TransformerBlockSubmodules:
+    """Build transformer block spec with experimental attention variants (e.g., linear attention).
+
+    This function constructs a heterogeneous transformer block that supports mixing different
+    attention mechanisms (experimental vs standard) and MLP types (MoE vs dense) across layers.
+    **Note that this API is an experimental API in the short term, and might be deprecated in the
+    future. In the long run, we will move to a new design that better supports hybrid models.**
+
+    Constructs transformer layer specs via
+    `get_transformer_layer_with_experimental_attention_variant_spec` and then slices them to
+    include only the layers that are built in this pipeline stage.
+
+    Args:
+        config: Transformer configuration containing model hyperparameters and feature flags.
+        vp_stage: Virtual pipeline stage index for interleaved pipeline parallelism.
+        pp_rank: Pipeline model parallel rank.
+
+    Returns:
+        TransformerBlockSubmodules containing per-layer specs and final layer norm.
+
+    Note:
+        Currently only supports transformer_engine backend. Kitchen backend can be used as a
+        wrapper with TE fallback for unsupported operations.
+    """
+
+    backend = _get_backend_spec_provider(config=config)
+
+    layer_specs = get_transformer_layer_with_experimental_attention_variant_spec(
+        config=config, backend=backend
+    )
+
     # Slice the layer specs to only include the layers that are built in this pipeline stage.
     if config.pipeline_model_parallel_layout is not None:
         local_layer_ids = config.pipeline_model_parallel_layout.get_layer_id_list(
@@ -270,6 +303,7 @@ def get_transformer_block_with_experimental_attention_variant_spec(
         layer_specs = [layer_specs[layer_id] for layer_id in local_layer_ids]
 
     # Get GPT decoder block spec
+    rms_norm = config.normalization == "RMSNorm"
     gpt_decoder_block_spec = TransformerBlockSubmodules(
         layer_specs=layer_specs, layer_norm=backend.layer_norm(rms_norm=rms_norm, for_qk=False)
     )
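
The pipeline-slicing step kept in the block-level function can be pictured with a small standalone sketch. This is a simplified, assumed model (uniform split, no `pipeline_model_parallel_layout`, no virtual stages), not the actual code path above.

```python
from typing import List


def slice_for_pipeline_stage(layer_specs: List[object], pp_rank: int, pp_size: int) -> List[object]:
    """Illustrative only: keep the contiguous chunk of per-layer specs owned by one pipeline
    stage, assuming the layer count divides evenly across pp_size."""
    per_stage = len(layer_specs) // pp_size
    start = pp_rank * per_stage
    return layer_specs[start : start + per_stage]


# Example: 8 per-layer specs split over 4 pipeline stages -> stage 2 keeps layers 4 and 5.
specs = [f"layer_{i}" for i in range(8)]
assert slice_for_pipeline_stage(specs, pp_rank=2, pp_size=4) == ["layer_4", "layer_5"]
```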
@@ -359,7 +393,7 @@ def _get_backend_spec_provider(config: TransformerConfig) -> BackendSpecProvider
         )
     backend: BackendSpecProvider = (
         KitchenSpecProvider(
-            fallback=TESpecProvider(),
+            fallback=TESpecProvider(fallback_to_eager_attn=config.fallback_to_eager_attn),
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
         )
@@ -396,6 +430,7 @@ def _get_self_attention_module_spec(
         qk_l2_norm=config.qk_l2_norm,
         use_kitchen=config.use_kitchen,
         use_te_activation_func=config.use_te_activation_func,
+        fallback_to_eager_attn=config.fallback_to_eager_attn,
         use_kitchen_attention=config.use_kitchen_attention,
         kitchen_attention_backend=config.kitchen_attention_backend,
     )
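
Putting the flag changes together, a hedged end-to-end sketch of how `fallback_to_eager_attn` might be threaded from the config into the block spec. It assumes the flag is settable on the config object used by these builders; only the attribute reads are confirmed by this diff.

```python
# Sketch only: flag placement on TransformerConfig is an assumption; the diff only shows
# the builders reading config.fallback_to_eager_attn.
config = TransformerConfig(
    num_layers=2,
    hidden_size=512,
    num_attention_heads=8,
    normalization="RMSNorm",
)
config.fallback_to_eager_attn = True  # new flag threaded through in this change (assumed attribute)
config.use_kitchen = False            # stay on the plain TE backend path

block_spec = get_transformer_block_with_experimental_attention_variant_spec(
    config=config, vp_stage=None, pp_rank=0
)
# block_spec.layer_specs holds this stage's per-layer specs;
# block_spec.layer_norm is the final norm produced by backend.layer_norm(...).
```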