
Commit 090f983

Create bridge for every module in Phi 1 (#1055)
1 parent 840dc44 commit 090f983

File tree

2 files changed: +24 -2 lines changed


transformer_lens/model_bridge/sources/transformers.py

Lines changed: 2 additions & 0 deletions

@@ -230,6 +230,7 @@ def boot(
     # Load the tokenizer
     tokenizer = tokenizer
     default_padding_side = getattr(adapter.cfg, "default_padding_side", None)
+    use_fast = getattr(adapter.cfg, "use_fast", True)
 
     if tokenizer is not None:
         tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side)
@@ -239,6 +240,7 @@ def boot(
             AutoTokenizer.from_pretrained(
                 model_name,
                 add_bos_token=True,
+                use_fast=use_fast,
                 token=huggingface_token if len(huggingface_token) > 0 else None,
             ),
             default_padding_side=default_padding_side,
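
The hunk above follows a simple pattern: the architecture adapter may opt out of the fast (Rust-backed) tokenizer, and boot falls back to use_fast=True when the adapter says nothing. A minimal sketch of that pattern, assuming a stand-in load_phi_tokenizer helper and adapter_cfg object (only AutoTokenizer.from_pretrained and its add_bos_token/use_fast arguments are the real Hugging Face API; the helper, the config object, and the model name are illustrative):

from types import SimpleNamespace

from transformers import AutoTokenizer


def load_phi_tokenizer(adapter_cfg, model_name: str):
    # Fall back to the fast tokenizer unless the architecture adapter opts out,
    # mirroring the getattr(..., "use_fast", True) default added above.
    use_fast = getattr(adapter_cfg, "use_fast", True)
    return AutoTokenizer.from_pretrained(
        model_name,
        add_bos_token=True,
        use_fast=use_fast,
    )


# Phi's adapter (see the next file) defaults use_fast to False, so a Phi-style
# config loads the slow Python tokenizer; the checkpoint name is an assumption.
tokenizer = load_phi_tokenizer(SimpleNamespace(use_fast=False), "microsoft/phi-1")
print(tokenizer.is_fast)  # expected: False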

transformer_lens/model_bridge/supported_architectures/phi.py

Lines changed: 22 additions & 2 deletions

@@ -11,6 +11,7 @@
     AttentionBridge,
     BlockBridge,
     EmbeddingBridge,
+    LinearBridge,
     MLPBridge,
     NormalizationBridge,
     UnembeddingBridge,
@@ -28,6 +29,8 @@ def __init__(self, cfg: Any) -> None:
         """
         super().__init__(cfg)
 
+        self.default_cfg = {"use_fast": False}
+
         self.conversion_rules = HookConversionSet(
             {
                 "embed.e": "transformer.wte.weight",
@@ -78,13 +81,30 @@ def __init__(self, cfg: Any) -> None:
         # Set up component mapping
         self.component_mapping = {
             "embed": EmbeddingBridge(name="model.embed_tokens"),
+            "rotary_emb": EmbeddingBridge(name="model.rotary_emb"),
             "blocks": BlockBridge(
                 name="model.layers",
                 submodules={
                     "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
+                    "attn": AttentionBridge(
+                        name="self_attn",
+                        config=self.cfg,
+                        submodules={
+                            "q": LinearBridge(name="q_proj"),
+                            "k": LinearBridge(name="k_proj"),
+                            "v": LinearBridge(name="v_proj"),
+                            "o": LinearBridge(name="dense"),
+                        },
+                    ),
+                    # Layer norm 1 and 2 are tied.
                     "ln2": NormalizationBridge(name="input_layernorm", config=self.cfg),
-                    "attn": AttentionBridge(name="self_attn", config=self.cfg),
-                    "mlp": MLPBridge(name="mlp"),
+                    "mlp": MLPBridge(
+                        name="mlp",
+                        submodules={
+                            "in": LinearBridge(name="fc1"),
+                            "out": LinearBridge(name="fc2"),
+                        },
+                    ),
                 },
             ),
             "ln_final": NormalizationBridge(name="model.final_layernorm", config=self.cfg),
