@@ -866,10 +866,10 @@ def sample_acc_and_loss(model, batch_size=15000):
 # %%
 @torch.no_grad()
-def metric_tracking(term_dic, l_bound, accuracy_bound):
-    (acc, loss) = sample_acc_and_loss(model_1, batch_size=5000)
+def metric_tracking(model, term_dic, l_bound, accuracy_bound):
+    (acc, loss) = sample_acc_and_loss(model, batch_size=5000)
     (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8) = terms(
-        model_1
+        model
     )
     term_dic["l_b"].append(l_bound)
     term_dic["a_b"].append(accuracy_bound)
@@ -941,12 +941,12 @@ def metric_tracking(term_dic, l_bound, accuracy_bound):
 import numpy as np


-def plot_loss(data):
+def plot_loss(data, n, m, rand):
     l_b = data["l_b"].detach().cpu()
     l_a = data["l_a"].detach().cpu()

     plt.figure(figsize=(10, 6))
-    x = np.arange(500)
+    x = np.arange(0, n, m)

     # Plot both lines
     plt.plot(x, l_b, "r-", label="loss bound", linewidth=1, marker=".", markersize=2)
@@ -955,8 +955,9 @@ def plot_loss(data):
     )

     # Add horizontal line at ln(26)
-    plt.axhline(y=np.log(26), color="grey", linestyle="--", label="ln(26)")
-    plt.text(x[-1], np.log(26), "ln(26)", verticalalignment="bottom")
+    if rand:
+        plt.axhline(y=np.log(26), color="grey", linestyle="--", label="ln(26)")
+        plt.text(x[-1], np.log(26), "ln(26)", verticalalignment="bottom")

     # Set scale and labels
     plt.yscale("log")
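
The rand flag gates the random-guessing reference line: a uniform predictor over a 26-symbol vocabulary (presumably the 26 letters) has cross-entropy ln(26) ≈ 3.258, and the matching accuracy baseline drawn in plot_accuracy below is 1/26 ≈ 0.038.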
@@ -968,18 +969,22 @@ def plot_loss(data):
     plt.show()


-def plot_accuracy(data):
+def plot_accuracy(data, n, m, rand):
     a_b = data["a_b"].detach().cpu()
     a_a = data["a_a"].detach().cpu()

     plt.figure(figsize=(10, 6))
-    x = np.arange(500)
+    x = np.arange(0, n, m)

     plt.plot(
         x, a_b, "b-", label="accuracy bound", linewidth=1, marker=".", markersize=2
     )
     plt.plot(x, a_a, "g-", label="accuracy", linewidth=1, marker=".", markersize=2)

+    if rand:
+        plt.axhline(y=1 / 26, color="grey", linestyle="--", label="1/26")
+        plt.text(x[-1], 1 / 26, "1/26", verticalalignment="bottom")
+
     plt.title("Accuracy as model is finetuned")
     plt.xlabel("Gradient Steps")
     plt.ylabel("Accuracy")
@@ -988,7 +993,7 @@ def plot_accuracy(data):
     plt.show()


-def plot_zero(data):
+def plot_zero(data, n, m):
     tr_1 = data["1"].detach().cpu()
     tr_2 = data["2"].detach().cpu()
     tr_4 = data["4"].detach().cpu()
@@ -997,7 +1002,7 @@ def plot_zero(data):
     tr_8 = data["8"].detach().cpu()

     plt.figure(figsize=(10, 6))
-    x = np.arange(500)
+    x = np.arange(0, n, m)

     plt.plot(x, tr_1, "b-", label="term_1", linewidth=1, marker=".", markersize=2)
     plt.plot(x, tr_2, "g-", label="term_2", linewidth=1, marker=".", markersize=2)
@@ -1018,12 +1023,12 @@ def plot_zero(data):
     plt.show()


-def plot_diag(data, i):
+def plot_diag(data, i, n, m):
     tr_1_d = data[str(i) + "_d"].detach().cpu()
     tr_1_o = data[str(i) + "_o"].detach().cpu()

     plt.figure(figsize=(10, 6))
-    x = np.arange(500)
+    x = np.arange(0, n, m)

     plt.plot(x, tr_1_d, "b-", label="diagonal", linewidth=1, marker=".", markersize=2)
     plt.plot(
@@ -1078,76 +1083,46 @@ def l_2(model):


 # %%
-def get_graphs(fun, model):
-    t_0_d = []
-    t_0_o = []
-    t_1 = []
-    t_2 = []
-    t_3_d = []
-    t_3_o = []
-    t_4 = []
-    t_5 = []
-    t_6 = []
-    t_7_d = []
-    t_7_o = []
-    t_8 = []
-    loss_b = []
-    acc_b = []
-    loss_a = []
-    acc_a = []
+def get_graphs(fun, model, term_name, model_name):
+    term_dic = {
+        "l_b": [],
+        "a_b": [],
+        "l_a": [],
+        "a_a": [],
+        "0_d": [],
+        "0_o": [],
+        "1": [],
+        "2": [],
+        "3_d": [],
+        "3_o": [],
+        "4": [],
+        "5": [],
+        "6": [],
+        "7_d": [],
+        "7_o": [],
+        "8": [],
+    }
     optimiser = torch.optim.AdamW(
-        model_1.parameters(), lr=2e-3, betas=(0.9, 0.999), weight_decay=1.0
+        model.parameters(), lr=2e-3, betas=(0.9, 0.999), weight_decay=1.0
     )
-    for i in range(500):
+    for i in range(5000):
         print(i)
-        a = loss_bound(model)
-        l_bound = a[-2]
-        accuracy_bound = a[-1]
-        (acc, loss) = sample_acc_and_loss(model, batch_size=5000)
-        print(l_bound)
-
-        (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8) = (
-            terms(model)
-        )
-        loss_b.append(l_bound)
-        acc_b.append(accuracy_bound)
-        loss_a.append(loss)
-        acc_a.append(acc)
-        t_0_d.append(((term_0[index_0_d])).mean())
-        t_0_o.append(((term_0[index_0_o])).mean())
-        t_1.append(((term_1[causal_mask]) ** 2).mean().sqrt())
-        t_2.append(((term_2[causal_mask]) ** 2).mean().sqrt())
-        t_3_d.append(((term_3[index_3_d])).mean())
-        t_3_o.append(((term_3[index_3_o])).mean())
-        t_4.append(((term_4[causal_mask]) ** 2).mean().sqrt())
-        t_5.append(((term_5) ** 2).mean().sqrt())
-        t_6.append(((term_6) ** 2).mean().sqrt())
-        t_7_d.append(((term_7[index_7_d])).mean())
-        t_7_o.append(((term_7[index_7_o])).mean())
-        t_8.append(((term_8) ** 2).mean().sqrt())
-        term_dic = {
-            "l_b": torch.tensor(loss_b),
-            "a_b": torch.tensor(acc_b),
-            "l_a": torch.tensor(loss_a),
-            "a_a": torch.tensor(acc_a),
-            "0_d": torch.tensor(t_0_d),
-            "0_o": torch.tensor(t_0_o),
-            "1": torch.tensor(t_1),
-            "2": torch.tensor(t_2),
-            "3_d": torch.tensor(t_3_d),
-            "3_o": torch.tensor(t_3_o),
-            "4": torch.tensor(t_4),
-            "5": torch.tensor(t_5),
-            "6": torch.tensor(t_6),
-            "7_d": torch.tensor(t_7_d),
-            "7_o": torch.tensor(t_7_o),
-            "8": torch.tensor(t_8),
-        }
         a_loss = fun(model)
         print(a_loss)
         a_loss.backward()
         optimiser.step()
         optimiser.zero_grad()
+        if i % 100 == 0:
+            a = loss_bound(model)
+            l_bound = a[-2]
+            accuracy_bound = a[-1]
+            metric_tracking(
+                model, term_dic, l_bound.detach().cpu(), accuracy_bound.detach().cpu()
+            )
+            print(l_bound)
+            display_model(model)
+            torch.save(term_dic, term_name)
+            torch.save(model, model_name)
     return term_dic
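
get_graphs now allocates the metric dictionary once, trains for 5000 steps, and every 100 steps recomputes the bounds, logs them through metric_tracking, renders the model, and checkpoints both the metrics and the weights. A hypothetical invocation, with l_2 presumably the objective named in the hunk header and the file names as placeholders:

    term_dic = get_graphs(l_2, model_1, "term_dic.pt", "model_1.pt")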
@@ -1168,27 +1143,67 @@ def add_noise(model, v):
     put_in_model(model, new_raw_terms)


-noise_plot = {"noise": [], "acc": [], "loss": [], "l_b": [], "a_b": []}
+@torch.no_grad()
+def noise_metric_tracking(term_dic, l_bound, accuracy_bound, noise):
+    (acc, loss) = sample_acc_and_loss(model_2, batch_size=5000)
+    (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8) = terms(
+        model_2
+    )
+    term_dic["noise"].append(noise)
+    term_dic["l_b"].append(l_bound)
+    term_dic["a_b"].append(accuracy_bound)
+    term_dic["l_a"].append(loss)
+    term_dic["a_a"].append(acc)
+    term_dic["0_d"].append(((term_0[index_0_d])).mean())
+    term_dic["0_o"].append(((term_0[index_0_o])).mean())
+    term_dic["1"].append(((term_1[causal_mask]) ** 2).mean().sqrt())
+    term_dic["2"].append(((term_2[causal_mask]) ** 2).mean().sqrt())
+    term_dic["3_d"].append(((term_3[index_3_d])).mean())
+    term_dic["3_o"].append(((term_3[index_3_o])).mean())
+    term_dic["4"].append(((term_4[causal_mask]) ** 2).mean().sqrt())
+    term_dic["5"].append(((term_5) ** 2).mean().sqrt())
+    term_dic["6"].append(((term_6) ** 2).mean().sqrt())
+    term_dic["7_d"].append(((term_7[index_7_d])).mean())
+    term_dic["7_o"].append(((term_7[index_7_o])).mean())
+    term_dic["8"].append(((term_8) ** 2).mean().sqrt())
+
+
+noise_data = {
+    "noise": [],
+    "l_b": [],
+    "a_b": [],
+    "l_a": [],
+    "a_a": [],
+    "0_d": [],
+    "0_o": [],
+    "1": [],
+    "2": [],
+    "3_d": [],
+    "3_o": [],
+    "4": [],
+    "5": [],
+    "6": [],
+    "7_d": [],
+    "7_o": [],
+    "8": [],
+}
 for i in range(25, 51):
-    loss_b = 0
-    acc_b = 0
-    loss = 0
-    acc = 0
+    print(i)
     for j in range(10):
         add_noise(model_2, i / 1000)
         a = loss_bound(model_2)
-        loss_b += a[-2]
-        acc_b += a[-1]
-        b = sample_acc_and_loss(model_2, batch_size=5000)
-        loss += b[1]
-        acc += b[0]
-    noise_plot["l_b"].append(loss_b / 10)
-    noise_plot["a_b"].append(acc_b / 10)
-    noise_plot["loss"].append(loss / 10)
-    noise_plot["acc"].append(acc / 10)
-    noise_plot["noise"].append(i / 1000)
+        l_bound = a[-2]
+        accuracy_bound = a[-1]
+
+        print(l_bound)
+        noise_metric_tracking(
+            noise_data, l_bound.detach().cpu(), accuracy_bound.detach().cpu(), i / 1000
+        )
+    torch.save(noise_data, "noise_term.pt")


+# %%
-plt.plot(noise_plot["noise"], noise_plot["l_b"], label="Loss Bound")
-plt.plot(noise_plot["noise"], noise_plot["loss"], label="Loss")
+plt.plot(noise_data["noise"], noise_data["l_b"], label="Loss Bound")
+plt.plot(noise_data["noise"], noise_data["l_a"], label="Loss")
 plt.xlabel("Noise Level")
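
The sweep perturbs model_2 at noise levels 0.025 through 0.050 (i / 1000 for i in 25..50), draws ten noised samples per level, and records per-sample bounds and empirical metrics instead of the old pre-averaged values, checkpointing noise_data to noise_term.pt as it goes. A sketch of recovering per-level means for the plots, assuming each logged entry is a scalar tensor:

    d = torch.load("noise_term.pt")
    l_b_mean = torch.stack(d["l_b"]).view(-1, 10).mean(dim=1)  # one mean per noise level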
@@ -1250,7 +1265,7 @@ def display_model(m):
     # Axis labels
     plt.xlabel("Key Token")
     plt.ylabel("Output Token")
-    plt.title("term_7.mean(dim=(0) )")
+    plt.title("term_7.mean(dim=0)")

     plt.grid(False)
     plt.tight_layout()