@@ -910,6 +910,7 @@ def metric_tracking(term_dic, l_bound, accuracy_bound):
910
910
"7_o" : [],
911
911
"8" : [],
912
912
}
913
+ # %%
913
914
914
915
for i in range (500 ):
915
916
if i % 100 == 0 :
@@ -925,17 +926,17 @@ def metric_tracking(term_dic, l_bound, accuracy_bound):
925
926
optimiser .step ()
926
927
optimiser .zero_grad ()
927
928
928
- torch .save (model_1 , "finetuned_model_test.pth" )
929
+ # torch.save(model_1, "finetuned_model_test.pth")
929
930
for a , b in term_dic .items ():
930
931
b = torch .tensor (b )
931
932
932
- torch .save (term_dic , "term_test.pt" )
933
- wandb .save ("finetuned_model_test.pth" )
934
- wandb .log (term_dic )
933
+ # torch.save(term_dic, "term_test.pt")
934
+ # wandb.save("finetuned_model_test.pth")
935
+ # wandb.log(term_dic)
935
936
wandb .finish ()
936
937
937
938
# %%
938
- data_1 = torch .load ("term .pt" )
939
+ data_1 = torch .load ("term_test .pt" )
939
940
import matplotlib .pyplot as plt
940
941
import numpy as np
941
942
@@ -998,18 +999,20 @@ def plot_zero(data):
998
999
plt .figure (figsize = (10 , 6 ))
999
1000
x = np .arange (500 )
1000
1001
1001
- plt .plot (x , tr_1 , "b-" , label = "l_2 norm " , linewidth = 1 , marker = "." , markersize = 2 )
1002
- plt .plot (x , tr_2 , "g-" , label = "accuracy " , linewidth = 1 , marker = "." , markersize = 2 )
1003
- plt .plot (x , tr_4 , "r-" , label = "accuracy " , linewidth = 1 , marker = "." , markersize = 2 )
1002
+ plt .plot (x , tr_1 , "b-" , label = "term_1 " , linewidth = 1 , marker = "." , markersize = 2 )
1003
+ plt .plot (x , tr_2 , "g-" , label = "term_2 " , linewidth = 1 , marker = "." , markersize = 2 )
1004
+ plt .plot (x , tr_4 , "r-" , label = "term_4 " , linewidth = 1 , marker = "." , markersize = 2 )
1004
1005
plt .plot (
1005
- x , tr_5 , color = "orange" , label = "accuracy " , linewidth = 1 , marker = "." , markersize = 2
1006
+ x , tr_5 , color = "orange" , label = "term_5 " , linewidth = 1 , marker = "." , markersize = 2
1006
1007
)
1007
- plt .plot (x , tr_6 , "y-" , label = "accuracy " , linewidth = 1 , marker = "." , markersize = 2 )
1008
- plt .plot (x , tr_8 , "g -" , label = "accuracy " , linewidth = 1 , marker = "." , markersize = 2 )
1008
+ plt .plot (x , tr_6 , "y-" , label = "term_6 " , linewidth = 1 , marker = "." , markersize = 2 )
1009
+ plt .plot (x , tr_8 , "p -" , label = "term_8 " , linewidth = 1 , marker = "." , markersize = 2 )
1009
1010
1010
- plt .title ("Accuracy as model is finetuned" )
1011
+ plt .title (
1012
+ "Root Mean Square value of terms that are not used in our mechanistic interpretation as model is finetuned"
1013
+ )
1011
1014
plt .xlabel ("Gradient Steps" )
1012
- plt .ylabel ("Accuracy " )
1015
+ plt .ylabel ("RMS value " )
1013
1016
plt .legend ()
1014
1017
plt .grid (True )
1015
1018
plt .show ()
@@ -1022,12 +1025,14 @@ def plot_diag(data, i):
1022
1025
plt .figure (figsize = (10 , 6 ))
1023
1026
x = np .arange (500 )
1024
1027
1025
- plt .plot (x , tr_1_d , "b-" , label = "l_2 norm" , linewidth = 1 , marker = "." , markersize = 2 )
1026
- plt .plot (x , tr_1_o , "g-" , label = "accuracy" , linewidth = 1 , marker = "." , markersize = 2 )
1028
+ plt .plot (x , tr_1_d , "b-" , label = "diagonal" , linewidth = 1 , marker = "." , markersize = 2 )
1029
+ plt .plot (
1030
+ x , tr_1_o , "g-" , label = "off diagonal" , linewidth = 1 , marker = "." , markersize = 2
1031
+ )
1027
1032
1028
- plt .title ("Accuracy as model is finetuned" )
1033
+ plt .title ("Mean of off and on diagonal elements of term_" + str ( i ) )
1029
1034
plt .xlabel ("Gradient Steps" )
1030
- plt .ylabel ("Accuracy " )
1035
+ plt .ylabel ("Mean " )
1031
1036
plt .legend ()
1032
1037
plt .grid (True )
1033
1038
plt .show ()
@@ -1148,17 +1153,61 @@ def get_graphs(fun, model):
1148
1153
1149
1154
# %%
1150
1155
def noise(M, v):
    """Return ``M`` plus independent zero-mean Gaussian noise of scale ``v``.

    The noise is sampled with ``torch.randn_like`` so it always has the same
    shape, dtype and device as ``M``. This keeps the helper independent of any
    module-level ``device`` variable and avoids a host-to-device copy on every
    call.

    Args:
        M: Floating-point tensor to perturb. It is not modified in place.
        v: Standard deviation of the added noise; ``0`` returns a tensor equal
            to ``M``.

    Returns:
        A new tensor ``M + eps`` with ``eps ~ N(0, v**2)`` element-wise.
    """
    # Scaling a standard normal sample by v gives N(0, v^2) noise.
    return M + torch.randn_like(M) * v
1157
+
1158
+
1159
+ model_2 .to (device )
1152
1160
1153
1161
1154
1162
def add_noise(model, v):
    """Load noisy copies of the module-level ``raw_terms`` into ``model``.

    Every entry of ``raw_terms`` is detached and cloned, perturbed by
    ``noise`` with scale ``v``, and marked as requiring gradients. The
    resulting list is then passed to ``put_in_model``. Only the copies are
    altered here.

    Args:
        model: Model object forwarded to ``put_in_model``.
        v: Noise scale forwarded to ``noise``.
    """
    new_raw_terms = []
    for term in raw_terms:
        noisy_term = noise(term.detach().clone(), v)
        noisy_term.requires_grad = True
        new_raw_terms.append(noisy_term)
    put_in_model(model, new_raw_terms)
1160
1169
1161
1170
1171
# Sweep the noise scale over i / 1000 for i = 25..50. For each scale, re-noise
# model_2 ten times and record the average of the last two entries returned by
# loss_bound and the first two entries returned by sample_acc_and_loss.
noise_plot = {key: [] for key in ("noise", "acc", "loss", "l_b", "a_b")}
for i in range(25, 51):
    # Running totals over the ten noisy draws at this noise scale.
    loss_b = acc_b = loss = acc = 0
    for j in range(10):
        add_noise(model_2, i / 1000)
        a = loss_bound(model_2)
        loss_b += a[-2]
        acc_b += a[-1]
        b = sample_acc_and_loss(model_2, batch_size=5000)
        loss += b[1]
        acc += b[0]
    # Store the per-scale means alongside the scale itself.
    for key, total in (("l_b", loss_b), ("a_b", acc_b), ("loss", loss), ("acc", acc)):
        noise_plot[key].append(total / 10)
    noise_plot["noise"].append(i / 1000)


# One figure per metric: the bound and the sampled value against noise scale.
for bound_key, value_key, bound_label, value_label in (
    ("l_b", "loss", "Loss Bound", "Loss"),
    ("a_b", "acc", "Accuracy Bound", "Accuracy"),
):
    plt.plot(noise_plot["noise"], noise_plot[bound_key], label=bound_label)
    plt.plot(noise_plot["noise"], noise_plot[value_key], label=value_label)
    plt.xlabel("Noise Level")
    plt.ylabel(value_label)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
1209
+
1210
+
1162
1211
# %%
1163
1212
def display_model (m ):
1164
1213
a = terms (m )
0 commit comments