diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 849718527ed..1880a09f5c6 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -764,6 +764,39 @@ def __init__(
             self.q_norm = torch.nn.Identity()
             self.k_norm = torch.nn.Identity()
 
+    @classmethod
+    def from_attention_mha(
+        cls,
+        other: AttentionMHA,
+        split_mha: bool = True,
+        rms_norm_class=torch.nn.RMSNorm,
+        **kwargs: Any,
+    ) -> "StaticAttention":
+        config = ModelArgs(
+            dim=other.dim,
+            n_layers=1,  # Not used in attention layer
+            n_heads=other.n_heads,
+            n_kv_heads=other.n_kv_heads,
+            head_dim=other.head_dim,
+            max_batch_size=other.max_batch_size,
+            max_context_len=other.max_context_len,
+            attention_qkv_bias=other.attention_qkv_bias,
+            use_qk_norm=other.use_qk_norm,
+            qk_norm_before_rope=other.qk_norm_before_rope,
+            norm_eps=other.q_norm_fn.eps if other.use_qk_norm else 1e-5,
+        )
+
+        instance = cls(
+            config=config,
+            layer_id=other.layer_id,
+            rope=other.rope,
+            split_mha=split_mha,
+            **kwargs,
+        )
+        instance.load_weights_from_attention_mha(other, rms_norm_class=rms_norm_class)
+
+        return instance
+
     def forward(
         self,
         x: torch.Tensor,
@@ -1059,3 +1092,37 @@ def transfer_weight(linear, conv2d):
 class StaticAttentionMHA(StaticAttention):
     def __init__(self, config: ModelArgs, layer_id: int, rope: Rope, **kwargs: Any):
         super().__init__(config, layer_id, rope, split_mha=False, **kwargs)
+
+
+def transform_attention_mha_to_static_attention(
+    model: nn.Module,
+    split_mha: bool = True,
+    inplace: bool = True,
+    use_conv2d: bool = False,
+    use_hf_rope: bool = False,
+    **kwargs: Any,
+) -> nn.Module:
+    if not inplace:
+        import copy
+
+        model = copy.deepcopy(model)
+
+    def helper(m):
+        for name, child in list(m.named_children()):
+            if isinstance(child, AttentionMHA):
+                static_attn = StaticAttention.from_attention_mha(
+                    child, split_mha=split_mha, **kwargs
+                )
+                # Note: HF RoPE needs to be applied before linear to conv2d
+                if use_hf_rope:
+                    static_attn.adopt_hf_rope()
+                if use_conv2d:
+                    static_attn.linear_to_conv2d()
+
+                setattr(m, name, static_attn)
+            else:
+                helper(child)
+
+        return m
+
+    return helper(model)
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index 8786c70da11..0d407968c0e 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -14,6 +14,7 @@
     StaticAttentionMask,
     StaticKCache,
     StaticKVCache,
+    transform_attention_mha_to_static_attention,
 )
 
 
@@ -76,7 +77,6 @@ def test(
             layer_id = 0
             rope = Rope(config)
             attn_mha = AttentionMHA(config, layer_id, rope).eval()
-            static_attn = StaticAttention(config, layer_id, rope).eval()
             if use_qk_norm:
                 with torch.no_grad():
                     attn_mha.q_norm_fn.weight.copy_(
@@ -85,7 +85,9 @@ def test(
                     attn_mha.k_norm_fn.weight.copy_(
                         torch.rand(config.head_dim) * 0.2 + 0.9
                     )
-            static_attn.load_weights_from_attention_mha(attn_mha)
+            static_attn = StaticAttention.from_attention_mha(
+                attn_mha, split_mha=split_mha
+            ).eval()
             if adopt_hf_rope:
                 static_attn.adopt_hf_rope()
             if use_conv2d:
@@ -131,8 +133,7 @@ def test_with_cache(self):
         layer_id = 0
         rope = Rope(config)
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+        static_attn = StaticAttention.from_attention_mha(attn_mha).eval()
         static_attn.adopt_hf_rope()
 
         x = torch.rand(1, config.max_seq_len, config.dim)
@@ -198,17 +199,16 @@ def test_with_style(style):
     def _get_test_transformers(self, config, attention_type="static", use_conv2d=False):
         mha_transformer = construct_transformer(config).eval()
 
+        static_transformer = transform_attention_mha_to_static_attention(
+            mha_transformer,
+            split_mha=(attention_type == "static"),
+            inplace=False,
+            use_conv2d=use_conv2d,
+            use_hf_rope=True,
+        ).eval()
+
         config = copy.copy(config)
         config.attention_type = attention_type
-        static_transformer = construct_transformer(config).eval()
-        static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
-        for mha_layer, static_layer in zip(
-            mha_transformer.layers, static_transformer.layers
-        ):
-            static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
-            static_layer.attention.adopt_hf_rope()
-            if use_conv2d:
-                static_layer.linear_to_conv2d()
         config.use_hf_rope = True
 
         return mha_transformer, static_transformer, config
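Usage sketch (not part of the patch): the snippet below shows how the new transform_attention_mha_to_static_attention helper and StaticAttention.from_attention_mha might be called, mirroring the updated tests. The import paths, ModelArgs values, and layer_id are illustrative assumptions, not prescribed by this diff.

# Illustrative sketch only -- import paths and config values below are assumptions
# mirroring the test file, not something this patch defines.
from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from executorch.examples.models.llama.static_attention import (
    StaticAttention,
    transform_attention_mha_to_static_attention,
)

config = ModelArgs(dim=64, n_heads=4, n_kv_heads=2, max_seq_len=24)  # assumed values

# Whole-model conversion: replace every AttentionMHA submodule with StaticAttention.
# inplace=False deep-copies the model first; the helper applies HF RoPE before any
# linear-to-conv2d conversion, matching the ordering noted in its implementation.
mha_model = construct_transformer(config).eval()
static_model = transform_attention_mha_to_static_attention(
    mha_model,
    split_mha=True,
    inplace=False,
    use_hf_rope=True,
).eval()

# Single-layer conversion: build a StaticAttention directly from one AttentionMHA,
# as the updated unit tests do.
layer_id = 0
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope).eval()
static_attn = StaticAttention.from_attention_mha(attn_mha, split_mha=True).eval()
static_attn.adopt_hf_rope()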