Commit 1129f0a

Weight access with getattr in layer norm folding
1 parent 3a6f596 commit 1129f0a

File tree

1 file changed (+15 −9 lines)


transformer_lens/model_bridge/bridge.py

Lines changed: 15 additions & 9 deletions
@@ -531,14 +531,20 @@ def fold_layer_norm(self, fold_biases=True, center_weights=True):
             # Fold ln2 into MLP
             if not self.cfg.attn_only:
                 if fold_biases:
-                    self.blocks[l].mlp.input.bias.data = self.blocks[l].mlp.input.bias.data + (
-                        self.blocks[l].mlp.input.weight.data * self.blocks[l].ln2.bias.data[:, None]
-                    ).sum(-2)
+                    getattr(self.blocks[l].mlp, "in").bias.data = getattr(
+                        self.blocks[l].mlp, "in"
+                    ).bias.data + (
+                        getattr(self.blocks[l].mlp, "in").weight.data
+                        * self.blocks[l].ln2.bias.data[:, None]
+                    ).sum(
+                        -2
+                    )

                     self.blocks[l].ln2.bias.data = torch.zeros_like(self.blocks[l].ln2.bias)

-                self.blocks[l].mlp.input.weight.data = (
-                    self.blocks[l].mlp.input.weight.data * self.blocks[l].ln2.weight.data[:, None]
+                getattr(self.blocks[l].mlp, "in").weight.data = (
+                    getattr(self.blocks[l].mlp, "in").weight.data
+                    * self.blocks[l].ln2.weight.data[:, None]
                 )

                 if self.cfg.gated_mlp:
@@ -550,10 +556,10 @@ def fold_layer_norm(self, fold_biases=True, center_weights=True):
                 self.blocks[l].ln2.weight.data = torch.zeros_like(self.blocks[l].ln2.weight)

                 if center_weights:
-                    self.blocks[l].mlp.input.weight.data = self.blocks[
-                        l
-                    ].mlp.input.weight.data - einops.reduce(
-                        self.blocks[l].mlp.input.weight.data,
+                    getattr(self.blocks[l].mlp, "in").weight.data = getattr(
+                        self.blocks[l].mlp, "in"
+                    ).weight.data - einops.reduce(
+                        getattr(self.blocks[l].mlp, "in").weight.data,
                         "d_model d_mlp -> 1 d_mlp",
                         "mean",
                     )
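Note: the switch from self.blocks[l].mlp.input to getattr(self.blocks[l].mlp, "in") works around a Python restriction rather than changing the folding math: "in" is a reserved keyword, so a submodule registered under that name cannot be reached with dot notation and has to be looked up by string. A minimal sketch of the pattern, using a stub module instead of the actual bridge classes (the names below are illustrative only, not part of this commit):

import torch.nn as nn

class MLPStub(nn.Module):
    """Illustrative stand-in for the bridge MLP; not the real class."""

    def __init__(self, d_model: int = 4, d_mlp: int = 8):
        super().__init__()
        # Register the input projection under the reserved name "in".
        self.add_module("in", nn.Linear(d_model, d_mlp))

mlp = MLPStub()

# mlp.in.weight  -> SyntaxError: "in" is a keyword, so dot access cannot be parsed
w_in = getattr(mlp, "in").weight   # string-based attribute lookup works
print(w_in.shape)                  # torch.Size([8, 4])

String-based getattr keeps the submodule's public name unchanged while still making the weights and biases reachable for folding.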
