@@ -443,43 +443,43 @@ def to_single_str_token(self, int_token: int) -> str:
     @property
     def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
         """Stack the key weights across all layers."""
-        return torch.stack([block.attn.W_K.weight for block in self.blocks], dim=0)
+        return torch.stack([block.attn.k.weight for block in self.blocks], dim=0)
 
     @property
     def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
         """Stack the query weights across all layers."""
-        return torch.stack([block.attn.W_Q.weight for block in self.blocks], dim=0)
+        return torch.stack([block.attn.q.weight for block in self.blocks], dim=0)
 
     @property
     def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
         """Stack the value weights across all layers."""
-        return torch.stack([block.attn.W_V.weight for block in self.blocks], dim=0)
+        return torch.stack([block.attn.v.weight for block in self.blocks], dim=0)
 
     @property
     def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
         """Stack the attn output weights across all layers."""
-        return torch.stack([block.attn.W_O.weight for block in self.blocks], dim=0)
+        return torch.stack([block.attn.o.weight for block in self.blocks], dim=0)
 
     @property
     def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
         """Stack the MLP input weights across all layers."""
-        return torch.stack([block.mlp.W_in.weight for block in self.blocks], dim=0)
+        return torch.stack([getattr(block.mlp, "in").weight for block in self.blocks], dim=0)
 
     @property
     def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
         """Stack the MLP gate weights across all layers.
 
         Only works for models with gated MLPs.
         """
-        if self.cfg.gated_mlp:
-            return torch.stack([block.mlp.W_gate.weight for block in self.blocks], dim=0)
+        if getattr(self.cfg, "gated_mlp", False):
+            return torch.stack([block.mlp.gate.weight for block in self.blocks], dim=0)
         else:
             return None
 
     @property
     def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
         """Stack the MLP output weights across all layers."""
-        return torch.stack([block.mlp.W_out.weight for block in self.blocks], dim=0)
+        return torch.stack([block.mlp.out.weight for block in self.blocks], dim=0)
 
     @property
     def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
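Every property in the hunk above follows the same pattern: stack one per-block weight tensor along a new leading n_layers dimension, address the attention and MLP submodules by their renamed attributes (attn.k instead of attn.W_K, mlp.out instead of mlp.W_out), and fall back to getattr where the new name ("in") collides with a Python keyword. The sketch below reproduces that pattern on a toy model so the shapes can be checked end to end; the ToyModel/ToyBlock/ToyAttention/ToyMLP classes, the plain nn.Linear submodules, and all sizes are assumptions made purely for illustration, not the modules this commit actually touches.

# Minimal, self-contained sketch of the stacking pattern (assumed names and
# sizes for illustration only; not this repository's real module structure).
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def __init__(self, d_model: int, d_head: int) -> None:
        super().__init__()
        # Projections addressed by short names, mirroring attn.q / attn.k / attn.v / attn.o.
        self.q = nn.Linear(d_model, d_head, bias=False)
        self.k = nn.Linear(d_model, d_head, bias=False)
        self.v = nn.Linear(d_model, d_head, bias=False)
        self.o = nn.Linear(d_head, d_model, bias=False)


class ToyMLP(nn.Module):
    def __init__(self, d_model: int, d_mlp: int) -> None:
        super().__init__()
        # "in" is a Python keyword, so the submodule must be registered and
        # read back with setattr/getattr instead of ordinary dot access.
        setattr(self, "in", nn.Linear(d_model, d_mlp, bias=False))
        self.out = nn.Linear(d_mlp, d_model, bias=False)


class ToyBlock(nn.Module):
    def __init__(self, d_model: int, d_head: int, d_mlp: int) -> None:
        super().__init__()
        self.attn = ToyAttention(d_model, d_head)
        self.mlp = ToyMLP(d_model, d_mlp)


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 2, d_model: int = 8, d_head: int = 4, d_mlp: int = 16) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [ToyBlock(d_model, d_head, d_mlp) for _ in range(n_layers)]
        )

    @property
    def W_K(self) -> torch.Tensor:
        # Stack one key-projection weight per block along a new leading
        # n_layers dimension, exactly as the properties in the diff do.
        return torch.stack([block.attn.k.weight for block in self.blocks], dim=0)

    @property
    def W_in(self) -> torch.Tensor:
        # getattr is needed because the MLP input projection is literally named "in".
        return torch.stack([getattr(block.mlp, "in").weight for block in self.blocks], dim=0)


model = ToyModel()
print(model.W_K.shape)   # torch.Size([2, 4, 8])  == (n_layers, d_head, d_model) for nn.Linear weights
print(model.W_in.shape)  # torch.Size([2, 16, 8]) == (n_layers, d_mlp, d_model)

Note that with plain nn.Linear submodules the stacked result has shape (n_layers, out_features, in_features); the repository's attention modules are presumably shaped so that the stacked tensors match the annotated "n_layers n_heads d_model d_head" layout.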