diff --git a/gpt_builders.py b/gpt_builders.py index d393961bb04..24b5f89d311 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -115,43 +115,42 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_ def _get_transformer_layer_spec(use_te, config): """Get transformer layer specification based on configuration. - + Args: use_te (bool): Whether to use Transformer Engine - args: Training arguments config: Model configuration - + Returns: transformer_layer_spec: The transformer layer specification """ - args = get_args() if use_te: return get_gpt_layer_with_transformer_engine_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - args.experimental_attention_variant, - qk_l2_norm=args.qk_l2_norm, + config.num_moe_experts, + config.moe_grouped_gemm, + config.qk_layernorm, + config.multi_latent_attention, + config.experimental_attention_variant, + qk_l2_norm=config.qk_l2_norm, use_kitchen=config.use_kitchen, use_te_activation_func=config.use_te_activation_func, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, + mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), ) elif config.transformer_impl == "inference_optimized": return get_gpt_layer_with_inference_spec( - args.qk_layernorm, - args.multi_latent_attention, - qk_l2_norm=args.qk_l2_norm, + config.qk_layernorm, + config.multi_latent_attention, + qk_l2_norm=config.qk_l2_norm, ) else: return get_gpt_layer_local_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - args.experimental_attention_variant, - normalization=args.normalization, + config.num_moe_experts, + config.moe_grouped_gemm, + config.qk_layernorm, + config.multi_latent_attention, + config.experimental_attention_variant, + normalization=config.normalization, use_kitchen=config.use_kitchen, use_kitchen_attention=config.use_kitchen_attention, 
kitchen_attention_backend=config.kitchen_attention_backend, diff --git a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py index e20f88ee4d1..6608073136c 100644 --- a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py +++ b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py @@ -397,6 +397,7 @@ def _get_self_attention_module_spec( use_te_activation_func=config.use_te_activation_func, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, + mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), ) attn_spec = layer_spec.submodules.self_attention if config.multi_latent_attention: diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 328d1c2f07f..cf2e0165f1e 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -14,6 +14,7 @@ from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.multi_latent_attention import ( + FusedMLASelfAttention, MLASelfAttention, MLASelfAttentionSubmodules, ) @@ -184,6 +185,7 @@ def get_gpt_layer_with_transformer_engine_submodules( use_te_activation_func: bool = False, use_kitchen_attention: bool = False, kitchen_attention_backend: str = "sdpa", + mla_down_proj_fusion: bool = False, ) -> TransformerLayerSubmodules: """Use these submodules to use lower-level Transformer Engine modules (required for fp8 training). @@ -198,6 +200,9 @@ def get_gpt_layer_with_transformer_engine_submodules( qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False. use_te_op_fuser (bool, optional): Use Transformer Engine's operation-based API, which may enable certain operation fusions. Defaults to False. 
+ mla_down_proj_fusion (bool, optional): Enable fused q/kv down-projection and fused input + layernorm when backend supports. Otherwise fall back + to the unfused MLA. Returns: TransformerLayerSubmodules: TE modules to construct a TransformerLayer @@ -243,6 +248,45 @@ def get_gpt_layer_with_transformer_engine_submodules( if qk_layernorm else backend.column_parallel_linear() ) + + if mla_down_proj_fusion: + fuse_input_layernorm = backend.column_parallel_layer_norm_linear() is not None + input_layernorm = IdentityOp if fuse_input_layernorm else backend.layer_norm() + down_proj_linear = ( + backend.column_parallel_layer_norm_linear() + if fuse_input_layernorm + else backend.linear() + ) + return TransformerLayerSubmodules( + input_layernorm=input_layernorm, + self_attention=ModuleSpec( + module=FusedMLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=backend.column_parallel_linear(), + linear_qkv_down_proj=down_proj_linear, + linear_q_up_proj=linear_q_up_proj, + linear_kv_up_proj=linear_kv_up_proj, + core_attention=backend.core_attention(), + linear_proj=backend.row_parallel_linear(), + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map=( + { + "self_attention.linear_q_down_proj.layer_norm_": "input_layernorm.", + "self_attention.linear_kv_down_proj.layer_norm_": "input_layernorm.", + "self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm.", + } + if fuse_input_layernorm + else {} + ), + ) return TransformerLayerSubmodules( input_layernorm=backend.layer_norm(has_residual=True), self_attention=ModuleSpec( @@ -526,6 +570,7 @@ def get_gpt_decoder_layer_specs( use_te_activation_func=config.use_te_activation_func, use_kitchen_attention=config.use_kitchen_attention, 
kitchen_attention_backend=config.kitchen_attention_backend, + mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), ) moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=config.num_moe_experts, @@ -537,6 +582,7 @@ def get_gpt_decoder_layer_specs( use_te_activation_func=config.use_te_activation_func, use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, + mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False), ) elif config.transformer_impl == "inference_optimized": layer_norm_impl = TENorm diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index a9cdc697cc8..8a7192e7694 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -16,6 +16,7 @@ from megatron.core import tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedObject from megatron.core.extensions.transformer_engine import split_te_layernorm_column_parallel_linear from megatron.core.models.common.embeddings import ( RotaryEmbedding, @@ -39,7 +40,12 @@ from megatron.core.transformer.torch_norm import LayerNormBuilder from megatron.core.transformer.transformer_config import MLATransformerConfig from megatron.core.typed_torch import apply_module -from megatron.core.utils import deprecate_inference_params, get_pg_size, is_te_min_version +from megatron.core.utils import ( + deprecate_inference_params, + get_pg_size, + is_te_min_version, + make_tp_sharded_tensor_for_checkpoint, +) try: from megatron.core.fusions.fused_mla_yarn_rope_apply import ( @@ -54,6 +60,7 @@ try: from megatron.core.extensions.transformer_engine import ( TEColumnParallelLinear, + TELayerNormColumnParallelLinear, TELinear, set_save_original_input, ) @@ -61,7 +68,13 @@ HAVE_TE = True except ImportError: - TEColumnParallelLinear, TELinear, Linear, set_save_original_input = None, 
None, None, None + ( + TEColumnParallelLinear, + TELayerNormColumnParallelLinear, + TELinear, + Linear, + set_save_original_input, + ) = (None, None, None, None, None) HAVE_TE = False @@ -78,6 +91,7 @@ class MLASelfAttentionSubmodules: linear_q_up_proj: Union[ModuleSpec, type] = None linear_kv_down_proj: Union[ModuleSpec, type] = None linear_kv_up_proj: Union[ModuleSpec, type] = None + linear_qkv_down_proj: Union[ModuleSpec, type] = None core_attention: Union[ModuleSpec, type] = None linear_proj: Union[ModuleSpec, type] = None @@ -509,6 +523,34 @@ def __init__( eps=self.config.layernorm_epsilon, ) + def _qkv_down_projection(self, hidden_states): + """Unfused q/kv down projection path.""" + if self.config.q_lora_rank is not None: + # if linear_q_down_proj is ColumnParallelLinear: + # q_compressed: [s, b, q_lora_rank / TP] + # elif linear_q_down_proj is Linear: + # q_compressed: [s / TP, b, q_lora_rank] + q_compressed, _ = self.linear_q_down_proj(hidden_states) + + # When output is sharded (ColumnParallelLinear), two things are needed to be + # identical to a normal Linear. + # 1. Manually gather output to restore output dim q_lora_rank; + # 2. Scatter sequence back to s / TP if sequence-parallel since it was + # gathered by ColumnParallelLinear. 
+ if q_compressed.size(-1) != self.config.q_lora_rank: + q_compressed = gather_from_tensor_model_parallel_region(q_compressed) + if self.config.sequence_parallel: + q_compressed = scatter_to_sequence_parallel_region(q_compressed) + else: + q_compressed = hidden_states + + # if linear_kv_down_proj is ColumnParallelLinear: + # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP] + # elif linear_kv_down_proj is Linear: + # kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)] + kv_combined, _ = self.linear_kv_down_proj(hidden_states) + return q_compressed, kv_combined + def get_query_key_value_tensors( self, hidden_states, @@ -578,30 +620,7 @@ def get_query_key_value_tensors( # ========================================= # QKV down projection and layernorm # ========================================= - if self.config.q_lora_rank is not None: - # if linear_q_down_proj is ColumnParallelLinear: - # q_compressed: [s, b, q_lora_rank / TP] - # elif linear_q_down_proj is Linear: - # q_compressed: [s / TP, b, q_lora_rank] - q_compressed, _ = self.linear_q_down_proj(hidden_states) - - # When output is sharded (ColumnParallelLinear), two things are needed to be - # identical to a normal Linear. - # 1. Manually gather output to restore output dim q_lora_rank; - # 2. Scatter sequence back to s / TP if sequence-parallel since it was - # gathered by ColumnParallelLinear. 
class FusedMLASelfAttention(MLASelfAttention):
    """MLA self-attention whose q and kv down projections share one fused linear layer.

    ``linear_qkv_down_proj`` maps ``hidden_size`` to
    ``q_lora_rank + kv_lora_rank + qk_pos_emb_head_dim`` output features.

    Weight layout: the rows held by one rank are always ``[q rows ; kv rows]``.
    When the fused layer is replicated (or TP size is 1) these are the full q and kv
    blocks. When it is column-parallel, rank ``r`` holds
    ``[q_lora_rank / TP rows of q ; (kv_lora_rank + qk_pos_emb_head_dim) / TP rows of kv]``,
    i.e. exactly the concatenation of the shards the two unfused column-parallel
    layers would hold. ``_load_from_state_dict`` and ``sharded_state_dict`` rely on
    this layout to stay interchangeable with unfused checkpoints, and
    ``_qkv_down_projection`` splits the activations accordingly.
    """

    def __init__(
        self,
        config: MLATransformerConfig,
        submodules: MLASelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
        cp_comm_type: Optional[str] = None,
        pg_collection: Optional[ProcessGroupCollection] = None,
    ):
        """Build the fused down projection plus the q/kv up projections and layernorms.

        Args:
            config: MLA transformer configuration; ``q_lora_rank`` must be set.
            submodules: Module specs; ``linear_qkv_down_proj`` selects the fused layer.
            layer_number: Index of the owning transformer layer.
            attn_mask_type: Attention mask type forwarded to the base class.
            cp_comm_type: Context-parallel communication type forwarded to the base class.
            pg_collection: Process groups; defaults to the MPU process groups.

        Raises:
            AssertionError: If ``config.q_lora_rank`` is ``None``.
            ValueError: If ``submodules.linear_qkv_down_proj`` is not a supported layer type.
        """
        if pg_collection is None:
            pg_collection = ProcessGroupCollection.use_mpu_process_groups()

        # MLASelfAttention.__init__ is bypassed on purpose: this class builds its own
        # (fused) projection layers below instead of the separate q/kv down projections.
        MultiLatentAttention.__init__(
            self,
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
            attention_type="self",
            cp_comm_type=cp_comm_type,
            pg_collection=pg_collection,
        )

        assert self.config.q_lora_rank is not None, (
            "FusedMLASelfAttention requires q_lora_rank to be set; "
            "fallback to MLASelfAttention for q_lora_rank=None."
        )

        # TELinear is built replicated on every TP rank; the column-parallel flavours
        # keep their output sharded and the forward path gathers what it needs.
        qkv_down_proj_kwargs = {}
        if submodules.linear_qkv_down_proj in [TELinear]:
            qkv_down_proj_kwargs['parallel_mode'] = 'duplicated'
        elif submodules.linear_qkv_down_proj in [
            Linear,
            TEColumnParallelLinear,
            ColumnParallelLinear,
            TELayerNormColumnParallelLinear,
        ]:
            qkv_down_proj_kwargs['gather_output'] = False
        else:
            raise ValueError(f"Unsupported linear_qkv_down_proj: {submodules.linear_qkv_down_proj}")

        self.linear_qkv_down_proj = build_module(
            submodules.linear_qkv_down_proj,
            self.config.hidden_size,
            self.config.q_lora_rank + self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim,
            config=self.config,
            init_method=self.config.init_method,
            bias=False,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name='qkv_down_proj',
            skip_weight_param_allocation=False,
            tp_group=(
                pg_collection.tp
                if qkv_down_proj_kwargs.get('parallel_mode') != 'duplicated'
                else None
            ),
            **qkv_down_proj_kwargs,
        )

        self.linear_q_up_proj = build_module(
            submodules.linear_q_up_proj,
            self.config.q_lora_rank,
            self.config.num_attention_heads * self.q_head_dim,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name='q_up_proj',
            tp_group=pg_collection.tp,
        )

        self.linear_kv_up_proj = build_module(
            submodules.linear_kv_up_proj,
            self.config.kv_lora_rank,
            self.config.num_attention_heads * (self.config.qk_head_dim + self.config.v_head_dim),
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name='kv_up_proj',
            tp_group=pg_collection.tp,
        )

        self.q_layernorm = submodules.q_layernorm(
            hidden_size=self.config.q_lora_rank,
            config=self.config,
            eps=self.config.layernorm_epsilon,
        )
        self.kv_layernorm = submodules.kv_layernorm(
            hidden_size=self.config.kv_lora_rank,
            config=self.config,
            eps=self.config.layernorm_epsilon,
        )

    def _down_proj_split_sizes(self, local_out_dim: int):
        """Return the local ``(q, kv)`` sizes of a fused dimension of size ``local_out_dim``.

        If ``local_out_dim`` equals the full fused width the layer is not sharded and the
        full sizes are returned; otherwise the per-rank sizes of a column-parallel layer
        (``full // TP``) are returned. The caller checks that they add up to
        ``local_out_dim``.
        """
        q_full = self.config.q_lora_rank
        kv_full = self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim
        if local_out_dim == q_full + kv_full:
            return q_full, kv_full

        tp_size = get_pg_size(self.tp_group)
        assert q_full % tp_size == 0, "q_lora_rank must be divisible by tensor-parallel size"
        assert (
            kv_full % tp_size == 0
        ), "kv_lora_rank + qk_pos_emb_head_dim must be divisible by tensor-parallel size"
        return q_full // tp_size, kv_full // tp_size

    def _qkv_down_projection(self, hidden_states):
        """Fused q/kv down projection path.

        Returns ``(q_compressed, kv_combined)`` with the same conventions as the unfused
        path: ``q_compressed`` always has its full ``q_lora_rank`` width, while
        ``kv_combined`` may still be tensor-parallel sharded on the last dimension (the
        caller gathers it when its width differs from the full kv width).
        """
        qkv, _ = self.linear_qkv_down_proj(hidden_states)

        # A column-parallel fused layer yields only this rank's [q_r, kv_r] slice, so
        # the split sizes must be the local ones; splitting with the full sizes would
        # raise as soon as TP > 1.
        q_split, kv_split = self._down_proj_split_sizes(qkv.size(-1))
        q_compressed, kv_combined = torch.split(qkv, [q_split, kv_split], dim=-1)

        if q_split != self.config.q_lora_rank:
            # Same post-processing as the unfused path for a sharded q output:
            # restore the full q_lora_rank width, then undo the sequence gather that
            # the column-parallel layer performed under sequence parallelism.
            q_compressed = gather_from_tensor_model_parallel_region(q_compressed.contiguous())
            if self.config.sequence_parallel:
                q_compressed = scatter_to_sequence_parallel_region(q_compressed)

        return q_compressed, kv_combined

    def sharded_state_dict(self, prefix: str = "", sharded_offsets: tuple = (), metadata=None):
        """Return a sharded state dict compatible with pre-fusion checkpoints.

        The fused weight is exported as ``linear_q_down_proj.weight`` and
        ``linear_kv_down_proj.weight``; the fused layer's extra state is exported under
        both unfused names. Any other entry of the fused layer (e.g. the layer-norm
        parameters of a fused layernorm-linear) is kept so that the owning layer's key
        map can still rename it.

        Raises:
            ValueError: If the fused weight's first dimension matches neither the full
                nor the per-rank fused width.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

        fused_prefix = f"{prefix}linear_qkv_down_proj."
        unfused_names = ("linear_q_down_proj", "linear_kv_down_proj")

        for fused_key in [k for k in sharded_state_dict if k.startswith(fused_prefix)]:
            suffix = fused_key[len(fused_prefix) :]
            if "_extra_state" in fused_key:
                # Duplicate the fused extra state under both unfused names.
                fused_obj = sharded_state_dict.pop(fused_key)
                if fused_obj is None:
                    continue
                for name in unfused_names:
                    new_key = f"{prefix}{name}.{suffix}"
                    if isinstance(fused_obj, ShardedObject):
                        sharded_state_dict[new_key] = ShardedObject(
                            key=new_key,
                            data=fused_obj.data,
                            global_shape=fused_obj.global_shape,
                            global_offset=fused_obj.global_offset,
                            replica_id=fused_obj.replica_id,
                        )
                    else:
                        sharded_state_dict[new_key] = fused_obj
            elif suffix in ("weight", "bias"):
                # Replaced by the split q/kv weights below (the layer is bias-free).
                del sharded_state_dict[fused_key]
            # Everything else (layer_norm_weight / layer_norm_bias of a fused
            # layernorm-linear) must survive: deleting it would drop the input
            # layernorm parameters from the checkpoint.

        fused_weight = self.linear_qkv_down_proj.weight
        q_split, kv_split = self._down_proj_split_sizes(fused_weight.size(0))
        if q_split + kv_split != fused_weight.size(0):
            raise ValueError(
                "Unexpected fused qkv-down weight shape: "
                f"got {tuple(fused_weight.size())}, expected dim0 {q_split + kv_split}"
            )

        q_weight, kv_weight = torch.split(fused_weight, [q_split, kv_split], dim=0)

        q_key = f"{prefix}linear_q_down_proj.weight"
        kv_key = f"{prefix}linear_kv_down_proj.weight"

        weight_is_tp_sharded = q_split != self.config.q_lora_rank
        if not weight_is_tp_sharded and get_pg_size(self.tp_group) > 1:
            # Replicated fused layer (e.g. TELinear with parallel_mode='duplicated'):
            # every TP rank holds the complete weight, so it must be registered as a
            # replica, not as one of TP shards along axis 0.
            from megatron.core.utils import make_sharded_tensor_for_checkpoint

            sharded_state_dict[q_key] = make_sharded_tensor_for_checkpoint(
                tensor=q_weight, key=q_key, prepend_offsets=sharded_offsets
            )
            sharded_state_dict[kv_key] = make_sharded_tensor_for_checkpoint(
                tensor=kv_weight, key=kv_key, prepend_offsets=sharded_offsets
            )
        else:
            sharded_state_dict[q_key] = make_tp_sharded_tensor_for_checkpoint(
                tensor=q_weight, key=q_key, tp_axis=0, prepend_offsets=sharded_offsets
            )
            sharded_state_dict[kv_key] = make_tp_sharded_tensor_for_checkpoint(
                tensor=kv_weight, key=kv_key, tp_axis=0, prepend_offsets=sharded_offsets
            )

        return sharded_state_dict

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load state dict with automatic unfused->fused conversion.

        If ``state_dict`` has no fused weight but has both unfused down-projection
        weights, they are concatenated along dim 0 (q first, then kv) into the fused
        key and the unfused weight/bias entries are removed from ``state_dict`` (it is
        modified in place).
        """
        q_key = f"{prefix}linear_q_down_proj.weight"
        kv_key = f"{prefix}linear_kv_down_proj.weight"
        fused_key = f"{prefix}linear_qkv_down_proj.weight"

        def _as_tensor(x):
            return x.data if hasattr(x, 'data') else x

        if fused_key not in state_dict and q_key in state_dict and kv_key in state_dict:
            q_weight = _as_tensor(state_dict[q_key])
            kv_weight = _as_tensor(state_dict[kv_key])
            state_dict[fused_key] = torch.cat([q_weight, kv_weight], dim=0)
            del state_dict[q_key]
            del state_dict[kv_key]
            state_dict.pop(f"{prefix}linear_q_down_proj.bias", None)
            state_dict.pop(f"{prefix}linear_kv_down_proj.bias", None)

        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def get_fused_mla_submodules():
    """Get submodules for FusedMLASelfAttention via the mla_down_proj_fusion spec path."""
    submodules = get_gpt_layer_with_transformer_engine_submodules(
        multi_latent_attention=True, mla_down_proj_fusion=True
    ).self_attention.submodules
    assert isinstance(submodules, MLASelfAttentionSubmodules)
    assert submodules.linear_qkv_down_proj is not None
    return submodules


backend = TESpecProvider()
linear_qkv_down_proj_options = [backend.linear(), backend.column_parallel_linear()]


def _make_fused_mla_test_config(rope_type="rope", q_lora_rank=32):
    """Build the small MLA config shared by every fused-MLA test below."""
    return MLATransformerConfig(
        num_layers=2,
        hidden_size=12,
        num_attention_heads=4,
        use_cpu_initialization=True,
        q_lora_rank=q_lora_rank,
        kv_lora_rank=32,
        qk_head_dim=128,
        v_head_dim=128,
        qk_pos_emb_head_dim=64,
        rope_type=rope_type,
        rotary_base=10000,
        original_max_position_embeddings=32,
    )


def _build_fused_attention(config):
    """Construct a causal FusedMLASelfAttention for layer 1 from the fused spec path."""
    return FusedMLASelfAttention(
        config, get_fused_mla_submodules(), layer_number=1, attn_mask_type=AttnMaskType.causal
    )


def _skip_unless_te_1_10():
    """Skip the calling test when Transformer Engine is older than 1.10.0."""
    if not is_te_min_version("1.10.0"):
        pytest.skip("Requires TE >= 1.10.0")


@pytest.mark.parametrize("rope_type", ('yarn', 'rope'))
class TestFusedMLASelfAttention:
    """Construction, shape and forward checks for the fused module (TP=1, PP=1)."""

    @pytest.fixture(scope='function', autouse=True)
    def setup_and_teardown(self, rope_type):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        self.transformer_config = _make_fused_mla_test_config(rope_type=rope_type)
        self.fused_attention = _build_fused_attention(self.transformer_config)

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_constructor(self):
        assert isinstance(self.fused_attention, FusedMLASelfAttention)
        assert isinstance(self.fused_attention, MLASelfAttention)
        assert self.fused_attention.layer_number == 1
        assert hasattr(self.fused_attention, 'linear_qkv_down_proj')

    def test_fused_weight_shape(self):
        config = self.transformer_config
        expected_out = config.q_lora_rank + config.kv_lora_rank + config.qk_pos_emb_head_dim
        weight = self.fused_attention.linear_qkv_down_proj.weight
        assert weight.shape[0] == expected_out
        assert weight.shape[1] == config.hidden_size

    def test_qkv_down_projection_split(self):
        _skip_unless_te_1_10()
        config = self.transformer_config
        self.fused_attention.cuda()

        seq_len, batch = 16, 2
        hidden = torch.randn(seq_len, batch, config.hidden_size).cuda()
        q_compressed, kv_combined = self.fused_attention._qkv_down_projection(hidden)

        assert q_compressed.shape == (seq_len, batch, config.q_lora_rank)
        assert kv_combined.shape == (
            seq_len,
            batch,
            config.kv_lora_rank + config.qk_pos_emb_head_dim,
        )

    def test_gpu_forward(self):
        _skip_unless_te_1_10()

        config = self.fused_attention.config
        sequence_length = 32
        micro_batch_size = 2

        self.fused_attention.cuda()

        hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size)).cuda()
        attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()

        output, bias = self.fused_attention(hidden_states, attention_mask)

        assert output.shape[0] == sequence_length
        assert output.shape[1] == micro_batch_size
        assert output.shape[2] == config.hidden_size
        assert bias.shape[0] == config.hidden_size

    def test_gpu_forward_bf16(self):
        _skip_unless_te_1_10()

        config = self.fused_attention.config
        sequence_length = 32
        micro_batch_size = 2

        self.fused_attention.cuda().bfloat16()

        hidden_states = (
            torch.ones((sequence_length, micro_batch_size, config.hidden_size)).cuda().bfloat16()
        )
        attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()

        output, _ = self.fused_attention(hidden_states, attention_mask)

        assert output.shape[0] == sequence_length
        assert output.shape[1] == micro_batch_size
        assert output.shape[2] == config.hidden_size
        assert output.dtype == torch.bfloat16


class TestFusedMLAGradientFlow:
    """Checks that gradients reach the fused down-projection weight and the input."""

    @pytest.fixture(scope='function', autouse=True)
    def setup_and_teardown(self):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        self.transformer_config = _make_fused_mla_test_config()

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_backward_pass(self):
        _skip_unless_te_1_10()

        config = self.transformer_config
        fused = _build_fused_attention(config)
        fused.cuda()

        seq_len, batch = 32, 2
        hidden_states = torch.randn(
            seq_len, batch, config.hidden_size, device='cuda', requires_grad=True
        )
        attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=bool, device='cuda')

        output, _ = fused(hidden_states, attention_mask)
        loss = output.sum()
        loss.backward()

        fused_weight = fused.linear_qkv_down_proj.weight
        assert fused_weight.grad is not None
        assert fused_weight.grad.shape == fused_weight.shape
        assert hidden_states.grad is not None


class TestFusedMLALoadFromStateDict:
    """Checks the unfused <-> fused checkpoint conversions."""

    @pytest.fixture(scope='function', autouse=True)
    def setup_and_teardown(self):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        self.transformer_config = _make_fused_mla_test_config()

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_load_unfused_state_dict(self):
        _skip_unless_te_1_10()

        unfused = MLASelfAttention(
            self.transformer_config,
            get_mla_self_attn_submodules(),
            layer_number=1,
            attn_mask_type=AttnMaskType.causal,
        )
        fused = _build_fused_attention(self.transformer_config)

        unfused_sd = unfused.state_dict()

        q_down_keys = [k for k in unfused_sd if 'linear_q_down_proj' in k]
        kv_down_keys = [k for k in unfused_sd if 'linear_kv_down_proj' in k]
        assert len(q_down_keys) > 0, "Expected q_down_proj keys in unfused state dict"
        assert len(kv_down_keys) > 0, "Expected kv_down_proj keys in unfused state dict"

        # strict=False: the unfused dict carries keys the fused module does not own.
        fused.load_state_dict(unfused_sd, strict=False)

        config = self.transformer_config
        expected_out = config.q_lora_rank + config.kv_lora_rank + config.qk_pos_emb_head_dim
        assert fused.linear_qkv_down_proj.weight.shape[0] == expected_out

        q_w = unfused_sd['linear_q_down_proj.weight']
        kv_w = unfused_sd['linear_kv_down_proj.weight']
        expected_fused = torch.cat([q_w, kv_w], dim=0)
        torch.testing.assert_close(fused.linear_qkv_down_proj.weight.data, expected_fused)

    def test_sharded_state_dict_splits_back(self):
        _skip_unless_te_1_10()

        fused = _build_fused_attention(self.transformer_config)

        sharded_sd = fused.sharded_state_dict(prefix="")
        assert any(
            'linear_q_down_proj.weight' in k for k in sharded_sd
        ), f"Expected linear_q_down_proj.weight in sharded state dict, got keys: {list(sharded_sd.keys())}"
        assert any(
            'linear_kv_down_proj.weight' in k for k in sharded_sd
        ), f"Expected linear_kv_down_proj.weight in sharded state dict, got keys: {list(sharded_sd.keys())}"
        assert not any(
            'linear_qkv_down_proj.weight' in k for k in sharded_sd
        ), "Unexpected linear_qkv_down_proj.weight in sharded state dict"


class TestFusedMLARequiresQLora:
    """The fused module must refuse configurations without a q LoRA rank."""

    @pytest.fixture(scope='function', autouse=True)
    def setup_and_teardown(self):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_raises_without_q_lora_rank(self):
        config = _make_fused_mla_test_config(q_lora_rank=None)
        with pytest.raises(AssertionError, match="q_lora_rank"):
            _build_fused_attention(config)