@@ -531,14 +531,20 @@ def fold_layer_norm(self, fold_biases=True, center_weights=True):
             # Fold ln2 into MLP
             if not self.cfg.attn_only:
                 if fold_biases:
-                    self.blocks[l].mlp.input.bias.data = self.blocks[l].mlp.input.bias.data + (
-                        self.blocks[l].mlp.input.weight.data * self.blocks[l].ln2.bias.data[:, None]
-                    ).sum(-2)
+                    getattr(self.blocks[l].mlp, "in").bias.data = getattr(
+                        self.blocks[l].mlp, "in"
+                    ).bias.data + (
+                        getattr(self.blocks[l].mlp, "in").weight.data
+                        * self.blocks[l].ln2.bias.data[:, None]
+                    ).sum(
+                        -2
+                    )
 
                     self.blocks[l].ln2.bias.data = torch.zeros_like(self.blocks[l].ln2.bias)
 
-                self.blocks[l].mlp.input.weight.data = (
-                    self.blocks[l].mlp.input.weight.data * self.blocks[l].ln2.weight.data[:, None]
+                getattr(self.blocks[l].mlp, "in").weight.data = (
+                    getattr(self.blocks[l].mlp, "in").weight.data
+                    * self.blocks[l].ln2.weight.data[:, None]
                 )
 
                 if self.cfg.gated_mlp:
@@ -550,10 +556,10 @@ def fold_layer_norm(self, fold_biases=True, center_weights=True):
                 self.blocks[l].ln2.weight.data = torch.zeros_like(self.blocks[l].ln2.weight)
 
                 if center_weights:
-                    self.blocks[l].mlp.input.weight.data = self.blocks[
-                        l
-                    ].mlp.input.weight.data - einops.reduce(
-                        self.blocks[l].mlp.input.weight.data,
+                    getattr(self.blocks[l].mlp, "in").weight.data = getattr(
+                        self.blocks[l].mlp, "in"
+                    ).weight.data - einops.reduce(
+                        getattr(self.blocks[l].mlp, "in").weight.data,
                         "d_model d_mlp -> 1 d_mlp",
                         "mean",
                     )
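
Note: `in` is a reserved keyword in Python, so the renamed MLP input projection cannot be written as `self.blocks[l].mlp.in`; the diff reaches it with `getattr(self.blocks[l].mlp, "in")` throughout.

For reference, a minimal numerical sketch of the identity the fold relies on (not part of the PR; the `[d_model, d_mlp]` weight layout is taken from the einops pattern above, and `ln_w`, `ln_b`, `W_in`, `b_in` are hypothetical stand-ins for `ln2.weight`, `ln2.bias`, and the `"in"` projection's weight and bias):

import torch

d_model, d_mlp = 8, 16
x = torch.randn(4, d_model)

ln_w, ln_b = torch.randn(d_model), torch.randn(d_model)      # ln2.weight, ln2.bias
W_in, b_in = torch.randn(d_model, d_mlp), torch.randn(d_mlp)

def normalize(t):
    # LayerNorm without its affine part: zero mean, unit variance over d_model.
    t = t - t.mean(-1, keepdim=True)
    return t / (t.var(-1, keepdim=True, unbiased=False) + 1e-12).sqrt()

# Before folding: affine LayerNorm, then the MLP input projection.
before = (normalize(x) * ln_w + ln_b) @ W_in + b_in

# After folding, mirroring the diff line by line:
b_folded = b_in + (W_in * ln_b[:, None]).sum(-2)         # fold_biases
W_folded = W_in * ln_w[:, None]                          # fold ln2.weight
# center_weights is safe because normalize(x) has zero mean over d_model,
# so a per-column constant offset in the weight adds nothing to the product.
W_folded = W_folded - W_folded.mean(0, keepdim=True)     # einops.reduce "mean"
after = normalize(x) @ W_folded + b_folded

assert torch.allclose(before, after, atol=1e-4)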