
Commit 92c5b47

Create bridge for every module in Qwen 2 (#1061)
1 parent ff89e67 commit 92c5b47

File tree

3 files changed: +29 -5 lines changed

transformer_lens/model_bridge/generalized_components/base.py

Lines changed: 3 additions & 0 deletions

@@ -288,4 +288,7 @@ def has_bias(self) -> bool:
             raise RuntimeError(
                 f"Original component not set for {self.name}. Call set_original_component() first."
             )
+
+        if not hasattr(self.original_component, "bias"):
+            return False
         return self.original_component.bias is not None

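As a quick illustration of the new guard (a standalone sketch, not the real TransformerLens bridge class; ComponentSketch and WeightOnlyNorm are made up for this example), a component whose original module exposes no bias attribute at all now reports False instead of hitting an AttributeError on the final line:

# Standalone sketch of the patched has_bias() logic; illustrative names only.
import torch
import torch.nn as nn


class WeightOnlyNorm(nn.Module):
    """Toy RMSNorm-style module: it has a weight but no bias attribute."""

    def __init__(self, d: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))


class ComponentSketch:
    def __init__(self, name: str, original_component=None):
        self.name = name
        self.original_component = original_component

    def has_bias(self) -> bool:
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        # New guard from the diff: no bias attribute at all -> no bias.
        if not hasattr(self.original_component, "bias"):
            return False
        return self.original_component.bias is not None


print(ComponentSketch("ln1", WeightOnlyNorm(8)).has_bias())           # False (no bias attribute)
print(ComponentSketch("fc", nn.Linear(8, 8, bias=False)).has_bias())  # False (bias attribute is None)
print(ComponentSketch("fc", nn.Linear(8, 8)).has_bias())              # True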
transformer_lens/model_bridge/generalized_components/normalization.py

Lines changed: 3 additions & 3 deletions

@@ -74,11 +74,11 @@ def forward(
         hidden_states = self.hook_normalized(hidden_states / scale)

         if not self.config.layer_norm_folding:
-            if self.config.uses_rms_norm:
-                # No bias if using RMSNorm
+            if self.config.uses_rms_norm or not self.has_bias():
+                # No bias if using RMSNorm or if the original component has no bias
                 hidden_states = hidden_states * self.weight
             else:
-                # Add bias if using LayerNorm
+                # Add bias if using LayerNorm and the original component has a bias
                 hidden_states = hidden_states * self.weight + self.bias

         output = self.hook_out(hidden_states)

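The branch itself reduces to the following (plain tensors standing in for the bridge's weight, bias, and config; not the real NormalizationBridge):

# Sketch of the updated branch with plain tensors; field names mirror the diff.
import torch

weight = torch.tensor([2.0, 2.0])
bias = torch.tensor([0.5, 0.5])
normalized = torch.tensor([1.0, -1.0])   # stands in for hidden_states / scale

uses_rms_norm = True   # Qwen2 now sets cfg.uses_rms_norm = True (see qwen2.py below)
has_bias = False       # RMSNorm weights carry no bias parameter

if uses_rms_norm or not has_bias:
    out = normalized * weight          # scale only
else:
    out = normalized * weight + bias   # LayerNorm path: scale then shift

print(out)  # tensor([ 2., -2.])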
transformer_lens/model_bridge/supported_architectures/qwen2.py

Lines changed: 23 additions & 2 deletions

@@ -11,6 +11,7 @@
     AttentionBridge,
     BlockBridge,
     EmbeddingBridge,
+    LinearBridge,
     MLPBridge,
     NormalizationBridge,
     UnembeddingBridge,

@@ -24,6 +25,10 @@ def __init__(self, cfg: Any) -> None:
         """Initialize the Qwen2 architecture adapter."""
         super().__init__(cfg)

+        self.cfg.default_prepend_bos = False
+        self.cfg.gated_mlp = True
+        self.cfg.uses_rms_norm = True
+
         self.conversion_rules = HookConversionSet(
             {
                 "embed.e": "model.embed_tokens.weight",

@@ -65,8 +70,24 @@ def __init__(self, cfg: Any) -> None:
                 submodules={
                     "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                     "ln2": NormalizationBridge(name="post_attention_layernorm", config=self.cfg),
-                    "attn": AttentionBridge(name="self_attn", config=self.cfg),
-                    "mlp": MLPBridge(name="mlp"),
+                    "attn": AttentionBridge(
+                        name="self_attn",
+                        config=self.cfg,
+                        submodules={
+                            "q": LinearBridge(name="q_proj"),
+                            "k": LinearBridge(name="k_proj"),
+                            "v": LinearBridge(name="v_proj"),
+                            "o": LinearBridge(name="o_proj"),
+                        },
+                    ),
+                    "mlp": MLPBridge(
+                        name="mlp",
+                        submodules={
+                            "gate": LinearBridge(name="gate_proj"),
+                            "in": LinearBridge(name="up_proj"),
+                            "out": LinearBridge(name="down_proj"),
+                        },
+                    ),
                 },
             ),
             "ln_final": NormalizationBridge(name="model.norm", config=self.cfg),

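For orientation, the new submodule names line up with the Hugging Face Qwen2 module tree roughly as below. The right-hand paths follow the standard transformers Qwen2 implementation and the names in the diff above; the dotted bridge names on the left are illustrative of how the submodules nest, not an exact TransformerLens API:

# Assumed correspondence between bridge submodules and Hugging Face Qwen2 paths
# for layer 0. Right-hand names come from the standard transformers Qwen2 model
# and the diff above; left-hand dotted names are illustrative only.
QWEN2_BRIDGE_TO_HF = {
    "blocks.0.ln1":      "model.layers.0.input_layernorm",
    "blocks.0.ln2":      "model.layers.0.post_attention_layernorm",
    "blocks.0.attn.q":   "model.layers.0.self_attn.q_proj",
    "blocks.0.attn.k":   "model.layers.0.self_attn.k_proj",
    "blocks.0.attn.v":   "model.layers.0.self_attn.v_proj",
    "blocks.0.attn.o":   "model.layers.0.self_attn.o_proj",
    "blocks.0.mlp.gate": "model.layers.0.mlp.gate_proj",
    "blocks.0.mlp.in":   "model.layers.0.mlp.up_proj",
    "blocks.0.mlp.out":  "model.layers.0.mlp.down_proj",
    "ln_final":          "model.norm",
}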