
Commit 9287f6d

Source transform to use static attention
Differential Revision: D84769599
Pull Request resolved: pytorch#15176
1 parent: 4ba2a66 · commit 9287f6d

2 files changed: +80 additions, -13 deletions

examples/models/llama/static_attention.py

Lines changed: 67 additions & 0 deletions
@@ -764,6 +764,39 @@ def __init__(
             self.q_norm = torch.nn.Identity()
             self.k_norm = torch.nn.Identity()
 
+    @classmethod
+    def from_attention_mha(
+        cls,
+        other: AttentionMHA,
+        split_mha: bool = True,
+        rms_norm_class=torch.nn.RMSNorm,
+        **kwargs: Any,
+    ) -> "StaticAttention":
+        config = ModelArgs(
+            dim=other.dim,
+            n_layers=1,  # Not used in attention layer
+            n_heads=other.n_heads,
+            n_kv_heads=other.n_kv_heads,
+            head_dim=other.head_dim,
+            max_batch_size=other.max_batch_size,
+            max_context_len=other.max_context_len,
+            attention_qkv_bias=other.attention_qkv_bias,
+            use_qk_norm=other.use_qk_norm,
+            qk_norm_before_rope=other.qk_norm_before_rope,
+            norm_eps=other.q_norm_fn.eps if other.use_qk_norm else 1e-5,
+        )
+
+        instance = cls(
+            config=config,
+            layer_id=other.layer_id,
+            rope=other.rope,
+            split_mha=split_mha,
+            **kwargs,
+        )
+        instance.load_weights_from_attention_mha(other, rms_norm_class=rms_norm_class)
+
+        return instance
+
     def forward(
         self,
         x: torch.Tensor,
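
For context, a minimal sketch of how the new classmethod could be called on a single layer, mirroring the updated tests further down. The import paths and the ModelArgs values here are assumptions for illustration, not part of this commit:

# Hypothetical usage sketch; import paths assumed from the executorch repo layout.
from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from executorch.examples.models.llama.static_attention import StaticAttention

# Assumed hyperparameters, chosen only to make the example self-contained.
config = ModelArgs(dim=64, n_heads=4, n_kv_heads=2, max_context_len=128)
rope = Rope(config)
attn_mha = AttentionMHA(config, 0, rope).eval()

# Build a StaticAttention with matching hyperparameters and copied weights;
# split_mha is forwarded to the StaticAttention constructor (True by default).
static_attn = StaticAttention.from_attention_mha(attn_mha, split_mha=True).eval()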
@@ -1059,3 +1092,37 @@ def transfer_weight(linear, conv2d):
 class StaticAttentionMHA(StaticAttention):
     def __init__(self, config: ModelArgs, layer_id: int, rope: Rope, **kwargs: Any):
         super().__init__(config, layer_id, rope, split_mha=False, **kwargs)
+
+
+def transform_attention_mha_to_static_attention(
+    model: nn.Module,
+    split_mha: bool = True,
+    inplace: bool = True,
+    use_conv2d: bool = False,
+    use_hf_rope: bool = False,
+    **kwargs: Any,
+) -> nn.Module:
+    if not inplace:
+        import copy
+
+        model = copy.deepcopy(model)
+
+    def helper(m):
+        for name, child in list(m.named_children()):
+            if isinstance(child, AttentionMHA):
+                static_attn = StaticAttention.from_attention_mha(
+                    child, split_mha=split_mha, **kwargs
+                )
+                # Note: HF RoPE needs to be applied before linear to conv2d
+                if use_hf_rope:
+                    static_attn.adopt_hf_rope()
+                if use_conv2d:
+                    static_attn.linear_to_conv2d()
+
+                setattr(m, name, static_attn)
+            else:
+                helper(child)
+
+        return m
+
+    return helper(model)
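
And a hedged sketch of applying the module-level transform to an entire transformer, following the same pattern the reworked _get_test_transformers helper uses below. The construct_transformer import path and the ModelArgs values are assumptions:

# Hypothetical usage sketch; import paths and ModelArgs values are assumed.
from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.static_attention import (
    transform_attention_mha_to_static_attention,
)

config = ModelArgs(dim=64, n_heads=4, n_kv_heads=2, n_layers=2, max_context_len=128)
mha_transformer = construct_transformer(config).eval()

# inplace=False deep-copies the model so the original MHA transformer stays intact.
# HF RoPE adoption runs before any linear-to-conv2d conversion, matching the
# ordering note inside the transform.
static_transformer = transform_attention_mha_to_static_attention(
    mha_transformer,
    split_mha=True,
    inplace=False,
    use_conv2d=False,
    use_hf_rope=True,
).eval()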

examples/models/llama/tests/test_static_attention.py

Lines changed: 13 additions & 13 deletions
@@ -14,6 +14,7 @@
     StaticAttentionMask,
     StaticKCache,
     StaticKVCache,
+    transform_attention_mha_to_static_attention,
 )
 
 
@@ -76,7 +77,6 @@ def test(
             layer_id = 0
             rope = Rope(config)
             attn_mha = AttentionMHA(config, layer_id, rope).eval()
-            static_attn = StaticAttention(config, layer_id, rope).eval()
             if use_qk_norm:
                 with torch.no_grad():
                     attn_mha.q_norm_fn.weight.copy_(
@@ -85,7 +85,9 @@ def test(
                     attn_mha.k_norm_fn.weight.copy_(
                         torch.rand(config.head_dim) * 0.2 + 0.9
                     )
-            static_attn.load_weights_from_attention_mha(attn_mha)
+            static_attn = StaticAttention.from_attention_mha(
+                attn_mha, split_mha=split_mha
+            ).eval()
             if adopt_hf_rope:
                 static_attn.adopt_hf_rope()
             if use_conv2d:
@@ -131,8 +133,7 @@ def test_with_cache(self):
         layer_id = 0
         rope = Rope(config)
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+        static_attn = StaticAttention.from_attention_mha(attn_mha).eval()
         static_attn.adopt_hf_rope()
 
         x = torch.rand(1, config.max_seq_len, config.dim)
@@ -198,17 +199,16 @@ def test_with_style(style):
     def _get_test_transformers(self, config, attention_type="static", use_conv2d=False):
         mha_transformer = construct_transformer(config).eval()
 
+        static_transformer = transform_attention_mha_to_static_attention(
+            mha_transformer,
+            split_mha=(attention_type == "static"),
+            inplace=False,
+            use_conv2d=use_conv2d,
+            use_hf_rope=True,
+        ).eval()
+
         config = copy.copy(config)
         config.attention_type = attention_type
-        static_transformer = construct_transformer(config).eval()
-        static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
-        for mha_layer, static_layer in zip(
-            mha_transformer.layers, static_transformer.layers
-        ):
-            static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
-            static_layer.attention.adopt_hf_rope()
-            if use_conv2d:
-                static_layer.linear_to_conv2d()
         config.use_hf_rope = True
 
         return mha_transformer, static_transformer, config
