
Commit 20a57dd

bryce13950 and jlarson4 authored
Qwen3 adapter (#1138)
* Removed all attributes that directly mapped keys; these attributes are now handled by the component-mapping Bridge classes
* Removed source keys where they have been made redundant by the bridges
* Formatting update
* Created Qwen3 adapter

Co-authored-by: jlarson <[email protected]>
1 parent 41e6d4a commit 20a57dd
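
The central change, visible in the per-architecture diffs below, is that conversion entries no longer carry a source_key: the component-mapping Bridge classes now resolve the Hugging Face parameter names themselves. An excerpt-style before/after sketch of a single conversion entry, taken from the patterns in the diffs below (the surrounding adapter __init__ is elided, so this is not a standalone module):

# Before: each conversion named its own HF source parameter
"blocks.{i}.attn.q.weight": ParamProcessingConversion(
    tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
    source_key="model.layers.{i}.self_attn.q_proj.weight",
),

# After: the bridge's component mapping supplies the source key,
# so the conversion only describes the tensor transformation
"blocks.{i}.attn.q.weight": ParamProcessingConversion(
    tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
),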

File tree

16 files changed: +195 −92 lines

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,7 @@
     Phi3ArchitectureAdapter,
     PhiArchitectureAdapter,
     Qwen2ArchitectureAdapter,
+    Qwen3ArchitectureAdapter,
     QwenArchitectureAdapter,
     T5ArchitectureAdapter,
 )
@@ -54,6 +55,7 @@
     "Phi3ForCausalLM": Phi3ArchitectureAdapter,
     "QwenForCausalLM": QwenArchitectureAdapter,
     "Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
+    "Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
     "T5ForConditionalGeneration": T5ArchitectureAdapter,
     "NanoGPTForCausalLM": NanogptArchitectureAdapter,
     "MinGPTForCausalLM": MingptArchitectureAdapter,

transformer_lens/model_bridge/sources/transformers.py

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ def determine_architecture_from_hf_config(hf_config):
         "phi3": "Phi3ForCausalLM",
         "qwen": "QwenForCausalLM",
         "qwen2": "Qwen2ForCausalLM",
+        "qwen3": "Qwen3ForCausalLM",
         "t5": "T5ForConditionalGeneration",
     }
     if model_type in model_type_mappings:
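
This model_type mapping lets the bridge route a raw Hugging Face config to the architecture string the factory understands. A quick illustration with transformers (the checkpoint id is just an example of a Qwen3 model):

from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")  # any Qwen3 checkpoint works
print(hf_config.model_type)     # "qwen3" -> "Qwen3ForCausalLM" via the mapping above
print(hf_config.architectures)  # ["Qwen3ForCausalLM"] -> Qwen3ArchitectureAdapter via the factory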

transformer_lens/model_bridge/supported_architectures/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -70,6 +70,9 @@
 from transformer_lens.model_bridge.supported_architectures.qwen2 import (
     Qwen2ArchitectureAdapter,
 )
+from transformer_lens.model_bridge.supported_architectures.qwen3 import (
+    Qwen3ArchitectureAdapter,
+)
 from transformer_lens.model_bridge.supported_architectures.t5 import (
     T5ArchitectureAdapter,
 )
@@ -97,5 +100,6 @@
     "PythiaArchitectureAdapter",
     "QwenArchitectureAdapter",
     "Qwen2ArchitectureAdapter",
+    "Qwen3ArchitectureAdapter",
     "T5ArchitectureAdapter",
 ]
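
The new supported_architectures/qwen3.py is among the 16 changed files, but its diff is not shown in this excerpt. As a rough sketch only, and assuming Qwen3 can reuse Qwen2's component mapping (an assumption about the PR, not its actual contents), the adapter could be as thin as:

from typing import Any

from transformer_lens.model_bridge.supported_architectures.qwen2 import (
    Qwen2ArchitectureAdapter,
)

class Qwen3ArchitectureAdapter(Qwen2ArchitectureAdapter):
    """Hypothetical sketch: inherit the Qwen2 mapping, overriding only what differs."""

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)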

transformer_lens/model_bridge/supported_architectures/bert.py

Lines changed: 0 additions & 7 deletions
@@ -45,37 +45,30 @@ def __init__(self, cfg: Any) -> None:
                 tensor_conversion=RearrangeTensorConversion(
                     "(h d_head) d_model -> h d_head d_model"
                 ),
-                source_key="bert.encoder.layer.{i}.attention.self.query.weight",
             ),
             "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(h d_head) d_model -> h d_head d_model"
                 ),
-                source_key="bert.encoder.layer.{i}.attention.self.key.weight",
             ),
             "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(h d_head) d_model -> h d_head d_model"
                 ),
-                source_key="bert.encoder.layer.{i}.attention.self.value.weight",
             ),
             "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head"),
-                source_key="bert.encoder.layer.{i}.attention.self.query.bias",
             ),
             "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head"),
-                source_key="bert.encoder.layer.{i}.attention.self.key.bias",
             ),
             "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head"),
-                source_key="bert.encoder.layer.{i}.attention.self.value.bias",
             ),
             "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "d_model (h d_head) -> h d_head d_model"
                 ),
-                source_key="bert.encoder.layer.{i}.attention.output.dense.weight",
             ),
         }
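
The RearrangeTensorConversion patterns above read as einops rearrange strings. Assuming that, here is a standalone sketch of what the BERT query-weight pattern does, with bert-base shapes (12 heads of size 64, d_model 768):

import torch
from einops import rearrange

n_heads, d_head, d_model = 12, 64, 768
w_q = torch.randn(n_heads * d_head, d_model)  # HF fuses the head dimension

# "(h d_head) d_model -> h d_head d_model": split the fused heads apart
w_q_split = rearrange(w_q, "(h d_head) d_model -> h d_head d_model", h=n_heads)
print(w_q_split.shape)  # torch.Size([12, 64, 768])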

transformer_lens/model_bridge/supported_architectures/gemma1.py

Lines changed: 4 additions & 8 deletions
@@ -46,21 +46,17 @@ def __init__(self, cfg: Any) -> None:
                 ),
                 source_key="model.embed_tokens.weight",
             ),
-            "blocks.{i}.attn.q": ParamProcessingConversion(
+            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.q_proj.weight",
             ),
-            "blocks.{i}.attn.k": ParamProcessingConversion(
+            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.k_proj.weight",
             ),
-            "blocks.{i}.attn.v": ParamProcessingConversion(
+            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.v_proj.weight",
             ),
-            "blocks.{i}.attn.o": ParamProcessingConversion(
+            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
         }

transformer_lens/model_bridge/supported_architectures/gemma2.py

Lines changed: 4 additions & 8 deletions
@@ -49,27 +49,23 @@ def __init__(self, cfg: Any) -> None:
                 ),
                 source_key="model.embed_tokens.weight",
             ),
-            "blocks.{i}.attn.q": ParamProcessingConversion(
+            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.q_proj.weight",
             ),
-            "blocks.{i}.attn.k": ParamProcessingConversion(
+            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(n h) m -> n m h",
                     n=getattr(self.cfg, "n_key_value_heads", self.cfg.n_heads),
                 ),
-                source_key="model.layers.{i}.self_attn.k_proj.weight",
             ),
-            "blocks.{i}.attn.v": ParamProcessingConversion(
+            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(n h) m -> n m h",
                     n=getattr(self.cfg, "n_key_value_heads", self.cfg.n_heads),
                 ),
-                source_key="model.layers.{i}.self_attn.v_proj.weight",
             ),
-            "blocks.{i}.attn.o": ParamProcessingConversion(
+            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
         }

transformer_lens/model_bridge/supported_architectures/gpt_oss.py

Lines changed: 4 additions & 8 deletions
@@ -39,21 +39,17 @@ def __init__(self, cfg: Any) -> None:
         # Conversion rules for weight processing/folding
         # GPT-OSS uses MoE with batched experts, so we need special handling
         self.weight_processing_conversions = {
-            "blocks.{i}.attn.q": ParamProcessingConversion(
+            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.q_proj.weight",
             ),
-            "blocks.{i}.attn.k": ParamProcessingConversion(
+            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.k_proj.weight",
             ),
-            "blocks.{i}.attn.v": ParamProcessingConversion(
+            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.v_proj.weight",
             ),
-            "blocks.{i}.attn.o": ParamProcessingConversion(
+            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
         }

transformer_lens/model_bridge/supported_architectures/gptj.py

Lines changed: 4 additions & 8 deletions
@@ -33,21 +33,17 @@ def __init__(self, cfg: Any) -> None:
         self.cfg.attn_only = False
 
         self.weight_processing_conversions = {
-            "blocks.{i}.attn.q": ParamProcessingConversion(
+            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="transformer.h.{i}.attn.q_proj.weight",
             ),
-            "blocks.{i}.attn.k": ParamProcessingConversion(
+            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="transformer.h.{i}.attn.k_proj.weight",
             ),
-            "blocks.{i}.attn.v": ParamProcessingConversion(
+            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="transformer.h.{i}.attn.v_proj.weight",
             ),
-            "blocks.{i}.attn.o": ParamProcessingConversion(
+            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
-                source_key="transformer.h.{i}.attn.out_proj.weight",
             ),
         }

transformer_lens/model_bridge/supported_architectures/llama.py

Lines changed: 4 additions & 8 deletions
@@ -71,27 +71,23 @@ def __init__(self, cfg: Any) -> None:
         self.cfg.eps_attr = "variance_epsilon"
 
         self.weight_processing_conversions = {
-            "blocks.{i}.attn.q": ParamProcessingConversion(
+            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.q_proj.weight",
             ),
-            "blocks.{i}.attn.k": ParamProcessingConversion(
+            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(n h) m -> n m h",
                     n=getattr(self.cfg, "n_key_value_heads", self.cfg.n_heads),
                 ),
-                source_key="model.layers.{i}.self_attn.k_proj.weight",
             ),
-            "blocks.{i}.attn.v": ParamProcessingConversion(
+            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(n h) m -> n m h",
                     n=getattr(self.cfg, "n_key_value_heads", self.cfg.n_heads),
                 ),
-                source_key="model.layers.{i}.self_attn.v_proj.weight",
             ),
-            "blocks.{i}.attn.o": ParamProcessingConversion(
+            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
         }
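
The getattr fallback on n_key_value_heads reflects grouped-query attention: Llama-style models may have fewer key/value heads than query heads, and configs without the attribute simply fall back to n_heads. A small sketch of the k-weight conversion under assumed Llama-like shapes:

import torch
from einops import rearrange

n_kv_heads, d_head, d_model = 8, 128, 4096  # illustrative GQA shapes
w_k = torch.randn(n_kv_heads * d_head, d_model)  # K/V projections are smaller than Q under GQA

# Mirrors "(n h) m -> n m h" with n = getattr(cfg, "n_key_value_heads", cfg.n_heads)
w_k_split = rearrange(w_k, "(n h) m -> n m h", n=n_kv_heads)
print(w_k_split.shape)  # torch.Size([8, 4096, 128])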

transformer_lens/model_bridge/supported_architectures/mistral.py

Lines changed: 4 additions & 8 deletions
@@ -45,25 +45,21 @@ def __init__(self, cfg: Any) -> None:
         self.cfg.uses_rms_norm = True
 
         self.weight_processing_conversions = {
-            "blocks.{i}.attn.q": ParamProcessingConversion(
+            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.q_proj.weight",
             ),
-            "blocks.{i}.attn.k": ParamProcessingConversion(
+            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(n h) m -> n m h", n=self.cfg.n_key_value_heads
                 ),
-                source_key="model.layers.{i}.self_attn.k_proj.weight",
             ),
-            "blocks.{i}.attn.v": ParamProcessingConversion(
+            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(n h) m -> n m h", n=self.cfg.n_key_value_heads
                 ),
-                source_key="model.layers.{i}.self_attn.v_proj.weight",
             ),
-            "blocks.{i}.attn.o": ParamProcessingConversion(
+            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
-                source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
         }
