     AttentionBridge,
     BlockBridge,
     EmbeddingBridge,
+    LinearBridge,
     MLPBridge,
     NormalizationBridge,
     UnembeddingBridge,
@@ -28,6 +29,8 @@ def __init__(self, cfg: Any) -> None:
         """
         super().__init__(cfg)
 
+        self.default_cfg = {"use_fast": False}
+
         self.conversion_rules = HookConversionSet(
             {
                 "embed.e": "transformer.wte.weight",
@@ -78,13 +81,30 @@ def __init__(self, cfg: Any) -> None:
         # Set up component mapping
         self.component_mapping = {
             "embed": EmbeddingBridge(name="model.embed_tokens"),
+            "rotary_emb": EmbeddingBridge(name="model.rotary_emb"),
             "blocks": BlockBridge(
                 name="model.layers",
                 submodules={
                     "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
+                    "attn": AttentionBridge(
+                        name="self_attn",
+                        config=self.cfg,
+                        submodules={
+                            "q": LinearBridge(name="q_proj"),
+                            "k": LinearBridge(name="k_proj"),
+                            "v": LinearBridge(name="v_proj"),
+                            "o": LinearBridge(name="dense"),
+                        },
+                    ),
+                    # Layer norm 1 and 2 are tied.
                     "ln2": NormalizationBridge(name="input_layernorm", config=self.cfg),
-                    "attn": AttentionBridge(name="self_attn", config=self.cfg),
-                    "mlp": MLPBridge(name="mlp"),
+                    "mlp": MLPBridge(
+                        name="mlp",
+                        submodules={
+                            "in": LinearBridge(name="fc1"),
+                            "out": LinearBridge(name="fc2"),
+                        },
+                    ),
                 },
             ),
             "ln_final": NormalizationBridge(name="model.final_layernorm", config=self.cfg),
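
For context, the remote module names referenced above (`self_attn.q_proj`/`k_proj`/`v_proj`/`dense`, `mlp.fc1`/`fc2`, a single `input_layernorm` feeding both branches, and `model.final_layernorm`) match the Phi-family layout in Hugging Face transformers, which is why `ln1` and `ln2` are tied to the same module. The sketch below is a minimal, illustrative check that those dotted paths resolve on such a checkpoint; the checkpoint name and the flat path list are assumptions for illustration, not taken from this commit.

# Minimal sketch, assuming a Phi-style checkpoint with this module layout;
# the checkpoint name and path list are illustrative, not from the commit.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1")

paths = [
    "model.embed_tokens",
    "model.rotary_emb",                # model-level rotary embedding (recent transformers versions)
    "model.layers.0.input_layernorm",  # shared by ln1 and ln2 (tied)
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.self_attn.dense",
    "model.layers.0.mlp.fc1",
    "model.layers.0.mlp.fc2",
    "model.final_layernorm",
]

for path in paths:
    # get_submodule raises AttributeError if the path does not exist
    print(f"{path}: {type(model.get_submodule(path)).__name__}")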