@@ -862,57 +862,77 @@ def sample_acc_and_loss(model, batch_size=15000):
862
862
# AdamW with an aggressively large weight decay for the fine-tuning run.
_adamw_kwargs = dict(lr=2e-3, betas=(0.9, 0.999), weight_decay=1.0)
optimiser = torch.optim.AdamW(model_1.parameters(), **_adamw_kwargs)
865
+
866
+
865
867
# %%
868
@torch.no_grad()
def metric_tracking(term_dic, l_bound, accuracy_bound):
    """Append the current bounds, sampled metrics, and per-term statistics
    to the history lists in ``term_dic`` (one list per key).

    ``l_bound`` / ``accuracy_bound`` are the proved bounds for this step;
    sampled loss/accuracy come from a fresh 5000-sample batch.
    """
    acc, loss = sample_acc_and_loss(model_1, batch_size=5000)
    t = terms(model_1)  # tuple of term_0 .. term_8

    def rms(x):
        # Root-mean-square over all (possibly masked) entries.
        return (x ** 2).mean().sqrt()

    # One scalar per tracked key: plain means for diagonal/off-diagonal
    # index selections, RMS for the remaining (masked) terms.
    updates = {
        "l_b": l_bound,
        "a_b": accuracy_bound,
        "l_a": loss,
        "a_a": acc,
        "0_d": t[0][index_0_d].mean(),
        "0_o": t[0][index_0_o].mean(),
        "1": rms(t[1][causal_mask]),
        "2": rms(t[2][causal_mask]),
        "3_d": t[3][index_3_d].mean(),
        "3_o": t[3][index_3_o].mean(),
        "4": rms(t[4][causal_mask]),
        "5": rms(t[5]),
        "6": rms(t[6]),
        "7_d": t[7][index_7_d].mean(),
        "7_o": t[7][index_7_o].mean(),
        "8": rms(t[8]),
    }
    for key, value in updates.items():
        term_dic[key].append(value)
890
+
891
+
892
import wandb

wandb.init(project="induction_head_finetune_results")

# One history list per tracked metric; metric_tracking appends to these
# every training step.
_metric_keys = (
    "l_b", "a_b", "l_a", "a_a",
    "0_d", "0_o", "1", "2",
    "3_d", "3_o", "4", "5",
    "6", "7_d", "7_o", "8",
)
term_dic = {key: [] for key in _metric_keys}
913
+
866
914
for step in range(500):
    # Checkpoint the full model object every 100 steps (including step 0).
    if step % 100 == 0:
        torch.save(model_1, f"finetuned_model_{step}.pth")
    print(step)

    # loss_bound returns a tuple whose last two entries are the loss bound
    # and the accuracy bound being optimised/tracked.
    bounds = loss_bound(model_1)
    l_bound = bounds[-2]
    accuracy_bound = bounds[-1]

    print(l_bound)

    # Backprop through the loss bound, then record metrics (detached, on
    # CPU so the history does not pin GPU memory or the autograd graph).
    l_bound.backward()
    metric_tracking(term_dic, l_bound.detach().cpu(), accuracy_bound.detach().cpu())
    optimiser.step()
    optimiser.zero_grad()
915
927
928
torch.save(model_1, "finetuned_model_test.pth")

# BUG FIX: the original loop did `b = torch.tensor(b)`, which only rebinds
# the loop variable and never updates the dict — term_dic was saved/logged
# with raw Python lists. Assign the converted tensor back into the dict.
# (Assigning to existing keys while iterating .items() is safe: the dict's
# size does not change.)
for key, history in term_dic.items():
    term_dic[key] = torch.tensor(history)

torch.save(term_dic, "term_test.pt")
wandb.save("finetuned_model_test.pth")
wandb.log(term_dic)
wandb.finish()
916
936
917
937
# %%
918
938
data_1 = torch .load ("term.pt" )
@@ -1139,4 +1159,53 @@ def add_noise(model, v):
1139
1159
put_in_model (model , new_raw_terms )
1140
1160
1141
1161
1162
+ # %%
1163
def display_model(m):
    """Visualise three of a model's interaction terms as heatmaps.

    Shows the mean of term_0 over dims (1, 3) (query/key position map),
    term_3 over dims (0, 2) (query/key token map) and term_7 over dim 0
    (output/key token map) from ``terms(m)``.

    Refactor: the original repeated the identical plotting boilerplate
    three times; it is extracted into one helper. Also fixes the
    "Ouput Token" -> "Output Token" label typo.
    """
    a = terms(m)

    def _heatmap(matrix, xlabel, ylabel, title):
        # Shared boilerplate for one 2-D heatmap figure.
        plt.figure(figsize=(8, 6))
        plt.imshow(matrix.detach().cpu().numpy(), cmap="viridis", aspect="auto")
        plt.colorbar(label="Value")  # add colour-scale label
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(False)
        plt.tight_layout()
        plt.show()

    _heatmap(
        a[0].mean(dim=(1, 3)), "Key Position", "Query Position",
        "term_0.mean(dim=(1,3))",
    )
    _heatmap(
        a[3].mean(dim=(0, 2)), "Key Token", "Query Token",
        "term_3.mean(dim=(0,2))",
    )
    _heatmap(
        a[7].mean(dim=(0)), "Key Token", "Output Token",
        "term_7.mean(dim=(0))",
    )
1209
+
1210
+
1142
1211
# %%
0 commit comments