
Commit ac73820

updated property access (#1026)
* updated property access
* removed extra function
1 parent 78c8c84 commit ac73820

1 file changed: +8 −8 lines

transformer_lens/model_bridge/bridge.py

Lines changed: 8 additions & 8 deletions
@@ -443,43 +443,43 @@ def to_single_str_token(self, int_token: int) -> str:
     @property
     def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
         """Stack the key weights across all layers."""
-        return torch.stack([block.attn.W_K.weight for block in self.blocks], dim=0)
+        return torch.stack([block.attn.k.weight for block in self.blocks], dim=0)

     @property
     def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
         """Stack the query weights across all layers."""
-        return torch.stack([block.attn.W_Q.weight for block in self.blocks], dim=0)
+        return torch.stack([block.attn.q.weight for block in self.blocks], dim=0)

     @property
     def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
         """Stack the value weights across all layers."""
-        return torch.stack([block.attn.W_V.weight for block in self.blocks], dim=0)
+        return torch.stack([block.attn.v.weight for block in self.blocks], dim=0)

     @property
     def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
         """Stack the attn output weights across all layers."""
-        return torch.stack([block.attn.W_O.weight for block in self.blocks], dim=0)
+        return torch.stack([block.attn.o.weight for block in self.blocks], dim=0)

     @property
     def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
         """Stack the MLP input weights across all layers."""
-        return torch.stack([block.mlp.W_in.weight for block in self.blocks], dim=0)
+        return torch.stack([getattr(block.mlp, "in").weight for block in self.blocks], dim=0)

     @property
     def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
         """Stack the MLP gate weights across all layers.

         Only works for models with gated MLPs.
         """
-        if self.cfg.gated_mlp:
-            return torch.stack([block.mlp.W_gate.weight for block in self.blocks], dim=0)
+        if getattr(self.cfg, "gated_mlp", False):
+            return torch.stack([block.mlp.gate.weight for block in self.blocks], dim=0)
         else:
             return None

     @property
     def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
         """Stack the MLP output weights across all layers."""
-        return torch.stack([block.mlp.W_out.weight for block in self.blocks], dim=0)
+        return torch.stack([block.mlp.out.weight for block in self.blocks], dim=0)

     @property
     def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
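
For context, a minimal runnable sketch of the pattern these properties follow after the change: each one gathers a single weight tensor from every block and stacks them along a new leading n_layers dimension, and W_in has to go through getattr because "in" is a reserved keyword in Python. The module layout below is a hypothetical stand-in (plain nn.Linear submodules), not the actual bridge classes.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the bridged submodules: anything exposing a
# .weight tensor under the new attribute names (attn.k, attn.q, ..., mlp.out).
blocks = nn.ModuleList(
    [nn.ModuleDict({"k": nn.Linear(16, 16, bias=False)}) for _ in range(4)]
)

# Each property gathers one weight per block and stacks them along a new
# leading n_layers dimension, mirroring the list comprehensions in the diff.
W_K = torch.stack([block["k"].weight for block in blocks], dim=0)
print(W_K.shape)  # torch.Size([4, 16, 16]) -> (n_layers, ...)

# "in" is a Python keyword, so `block.mlp.in` would be a syntax error;
# getattr(block.mlp, "in") is how the W_in property reaches that submodule.
mlp = nn.Module()
setattr(mlp, "in", nn.Linear(16, 64, bias=False))
print(getattr(mlp, "in").weight.shape)  # torch.Size([64, 16])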

0 commit comments
