@@ -1009,16 +1009,41 @@ def matrices_to_log(
1009
1009
)
1010
1010
1011
1011
@torch .no_grad ()
1012
- def log_matrices (
1012
+ def plot_matrices (
1013
1013
self ,
1014
- logger : Run ,
1015
- model : HookedTransformer ,
1014
+ W_E : Float [Tensor , "d_vocab d_model" ], # noqa: F722
1015
+ W_pos : Float [Tensor , "n_ctx d_model" ], # noqa: F722
1016
+ W_U : Float [Tensor , "d_model d_vocab_out" ], # noqa: F722
1017
+ W_Q : Float [Tensor , "n_layers n_heads d_model d_head" ], # noqa: F722
1018
+ W_K : Float [Tensor , "n_layers n_heads d_model d_head" ], # noqa: F722
1019
+ W_V : Float [Tensor , "n_layers n_heads d_model d_head" ], # noqa: F722
1020
+ W_O : Float [Tensor , "n_layers n_heads d_head d_model" ], # noqa: F722
1021
+ b_U : Optional [Float [Tensor , "d_vocab_out" ]] = None , # noqa: F821
1022
+ b_Q : Optional [Float [Tensor , "n_layers n_heads d_head" ]] = None , # noqa: F722
1023
+ b_K : Optional [Float [Tensor , "n_layers n_heads d_head" ]] = None , # noqa: F722
1024
+ b_V : Optional [Float [Tensor , "n_layers n_heads d_head" ]] = None , # noqa: F722
1025
+ b_O : Optional [Float [Tensor , "n_layers d_model" ]] = None , # noqa: F722
1016
1026
* ,
1017
- unsafe : bool = False ,
1027
+ attention_dir : Literal [ "bidirectional" , "causal" ] = "causal" ,
1018
1028
default_heatmap_kwargs : dict [str , Any ] = {},
1019
- ** kwargs ,
1020
- ):
1021
- matrices = dict (self .matrices_to_log (model , unsafe = unsafe ))
1029
+ ) -> dict [str , go .Figure ]:
1030
+ matrices = dict (
1031
+ self .matrices_to_plot (
1032
+ W_E = W_E ,
1033
+ W_pos = W_pos ,
1034
+ W_U = W_U ,
1035
+ W_Q = W_Q ,
1036
+ W_K = W_K ,
1037
+ W_V = W_V ,
1038
+ W_O = W_O ,
1039
+ b_U = b_U ,
1040
+ b_Q = b_Q ,
1041
+ b_K = b_K ,
1042
+ b_V = b_V ,
1043
+ b_O = b_O ,
1044
+ attention_dir = attention_dir ,
1045
+ )
1046
+ )
1022
1047
if self .use_subplots :
1023
1048
OVs = tuple (name for name , _ in matrices .items () if "U" in name )
1024
1049
QKs = tuple (name for name , _ in matrices .items () if "U" not in name )
@@ -1060,6 +1085,35 @@ def log_matrices(
1060
1085
)
1061
1086
for name , matrix in matrices .items ()
1062
1087
}
1088
+ return figs
1089
+
1090
+ @torch .no_grad ()
1091
+ def log_matrices (
1092
+ self ,
1093
+ logger : Run ,
1094
+ model : HookedTransformer ,
1095
+ * ,
1096
+ unsafe : bool = False ,
1097
+ default_heatmap_kwargs : dict [str , Any ] = {},
1098
+ ** kwargs ,
1099
+ ):
1100
+ self .assert_model_supported (model , unsafe = unsafe )
1101
+ figs = self .plot_matrices (
1102
+ model .W_E ,
1103
+ model .W_pos ,
1104
+ model .W_U ,
1105
+ model .W_Q ,
1106
+ model .W_K ,
1107
+ model .W_V ,
1108
+ model .W_O ,
1109
+ model .b_U ,
1110
+ model .b_Q ,
1111
+ model .b_K ,
1112
+ model .b_V ,
1113
+ model .b_O ,
1114
+ attention_dir = model .cfg .attention_dir ,
1115
+ default_heatmap_kwargs = default_heatmap_kwargs ,
1116
+ )
1063
1117
logger .log (
1064
1118
{encode_4_byte_unicode (k ): v for k , v in figs .items ()},
1065
1119
commit = False ,
0 commit comments