Commit fb5b181
more graphs
1 parent bca6803 commit fb5b181

File tree: 5 files changed, +105 -90 lines changed

gbmi/exp_indhead/induction_head_results.py

Lines changed: 105 additions & 90 deletions
@@ -866,10 +866,10 @@ def sample_acc_and_loss(model, batch_size=15000):
 
 # %%
 @torch.no_grad()
-def metric_tracking(term_dic, l_bound, accuracy_bound):
-    (acc, loss) = sample_acc_and_loss(model_1, batch_size=5000)
+def metric_tracking(model, term_dic, l_bound, accuracy_bound):
+    (acc, loss) = sample_acc_and_loss(model, batch_size=5000)
     (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8) = terms(
-        model_1
+        model
     )
     term_dic["l_b"].append(l_bound)
     term_dic["a_b"].append(accuracy_bound)
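`metric_tracking` is now parameterized over the model rather than closing over the global `model_1`, so the same tracker can serve any fine-tuning run. A minimal sketch of a call site, mirroring how `get_graphs` below invokes it (the standalone use here is illustrative):

```python
# Illustrative one-off snapshot; model_1, loss_bound and term_dic are
# defined elsewhere in this file. loss_bound returns a tuple whose last
# two entries are the loss and accuracy bounds.
a = loss_bound(model_1)
metric_tracking(model_1, term_dic, a[-2].detach().cpu(), a[-1].detach().cpu())
```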
@@ -941,12 +941,12 @@ def metric_tracking(term_dic, l_bound, accuracy_bound):
 import numpy as np
 
 
-def plot_loss(data):
+def plot_loss(data, n, m, rand):
     l_b = data["l_b"].detach().cpu()
     l_a = data["l_a"].detach().cpu()
 
     plt.figure(figsize=(10, 6))
-    x = np.arange(500)
+    x = np.arange(0, n, m)
 
     # Plot both lines
     plt.plot(x, l_b, "r-", label="loss bound", linewidth=1, marker=".", markersize=2)
@@ -955,8 +955,9 @@ def plot_loss(data):
     )
 
     # Add horizontal line at ln(26)
-    plt.axhline(y=np.log(26), color="grey", linestyle="--", label="ln(26)")
-    plt.text(x[-1], np.log(26), "ln(26)", verticalalignment="bottom")
+    if rand:
+        plt.axhline(y=np.log(26), color="grey", linestyle="--", label="ln(26)")
+        plt.text(x[-1], np.log(26), "ln(26)", verticalalignment="bottom")
 
     # Set scale and labels
     plt.yscale("log")
@@ -968,18 +969,22 @@ def plot_loss(data):
     plt.show()
 
 
-def plot_accuracy(data):
+def plot_accuracy(data, n, m, rand):
     a_b = data["a_b"].detach().cpu()
     a_a = data["a_a"].detach().cpu()
 
     plt.figure(figsize=(10, 6))
-    x = np.arange(500)
+    x = np.arange(0, n, m)
 
     plt.plot(
         x, a_b, "b-", label="accuracy bound", linewidth=1, marker=".", markersize=2
     )
     plt.plot(x, a_a, "g-", label="accuracy", linewidth=1, marker=".", markersize=2)
 
+    if rand:
+        plt.axhline(y=1 / 26, color="grey", linestyle="--", label="1/26")
+        plt.text(x[-1], 1 / 26, "1/26", verticalalignment="bottom")
+
     plt.title("Accuracy as model is finetuned")
     plt.xlabel("Gradient Steps")
     plt.ylabel("Accuracy")
@@ -988,7 +993,7 @@ def plot_accuracy(data):
     plt.show()
 
 
-def plot_zero(data):
+def plot_zero(data, n, m):
     tr_1 = data["1"].detach().cpu()
     tr_2 = data["2"].detach().cpu()
     tr_4 = data["4"].detach().cpu()
@@ -997,7 +1002,7 @@ def plot_zero(data):
     tr_8 = data["8"].detach().cpu()
 
     plt.figure(figsize=(10, 6))
-    x = np.arange(500)
+    x = np.arange(0, n, m)
 
     plt.plot(x, tr_1, "b-", label="term_1", linewidth=1, marker=".", markersize=2)
     plt.plot(x, tr_2, "g-", label="term_2", linewidth=1, marker=".", markersize=2)
@@ -1018,12 +1023,12 @@ def plot_zero(data):
     plt.show()
 
 
-def plot_diag(data, i):
+def plot_diag(data, i, n, m):
     tr_1_d = data[str(i) + "_d"].detach().cpu()
     tr_1_o = data[str(i) + "_o"].detach().cpu()
 
     plt.figure(figsize=(10, 6))
-    x = np.arange(500)
+    x = np.arange(0, n, m)
 
     plt.plot(x, tr_1_d, "b-", label="diagonal", linewidth=1, marker=".", markersize=2)
     plt.plot(
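All four plotting helpers now take the sweep length `n` and logging stride `m` in place of the hard-coded `np.arange(500)`, plus (for the loss and accuracy plots) a `rand` flag that draws the uniform-over-26-tokens baseline: ln(26) for loss, 1/26 for accuracy. A hedged usage sketch; the file name comes from this commit's added binaries, and the tensor conversion assumes the tracked values are stored as lists of scalars:

```python
import torch

# Load the metrics dict written by get_graphs and convert each list of
# scalars to a tensor, since the plot_* helpers call .detach().cpu().
data = torch.load("l_2_2_term.pt")
data = {k: torch.tensor(v) for k, v in data.items()}

# 5000 gradient steps with metrics logged every 100 steps, matching the
# training loop below, gives 50 points per curve.
n, m = 5000, 100
plot_loss(data, n, m, rand=True)      # loss and its bound, ln(26) baseline
plot_accuracy(data, n, m, rand=True)  # accuracy and its bound, 1/26 baseline
plot_zero(data, n, m)                 # RMS of the terms bounded toward zero
plot_diag(data, 0, n, m)              # diagonal vs off-diagonal of term_0
```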
@@ -1078,76 +1083,46 @@ def l_2(model):
 
 
 # %%
-def get_graphs(fun, model):
-    t_0_d = []
-    t_0_o = []
-    t_1 = []
-    t_2 = []
-    t_3_d = []
-    t_3_o = []
-    t_4 = []
-    t_5 = []
-    t_6 = []
-    t_7_d = []
-    t_7_o = []
-    t_8 = []
-    loss_b = []
-    acc_b = []
-    loss_a = []
-    acc_a = []
+def get_graphs(fun, model, term_name, model_name):
+    term_dic = {
+        "l_b": [],
+        "a_b": [],
+        "l_a": [],
+        "a_a": [],
+        "0_d": [],
+        "0_o": [],
+        "1": [],
+        "2": [],
+        "3_d": [],
+        "3_o": [],
+        "4": [],
+        "5": [],
+        "6": [],
+        "7_d": [],
+        "7_o": [],
+        "8": [],
+    }
     optimiser = torch.optim.AdamW(
         model_1.parameters(), lr=2e-3, betas=(0.9, 0.999), weight_decay=1.0
     )
-    for i in range(500):
+    for i in range(5000):
         print(i)
-        a = loss_bound(model)
-        l_bound = a[-2]
-        accuracy_bound = a[-1]
-        (acc, loss) = sample_acc_and_loss(model, batch_size=5000)
-        print(l_bound)
-
-        (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8) = (
-            terms(model)
-        )
-        loss_b.append(l_bound)
-        acc_b.append(accuracy_bound)
-        loss_a.append(loss)
-        acc_a.append(acc)
-        t_0_d.append(((term_0[index_0_d])).mean())
-        t_0_o.append(((term_0[index_0_o])).mean())
-        t_1.append(((term_1[causal_mask]) ** 2).mean().sqrt())
-        t_2.append(((term_2[causal_mask]) ** 2).mean().sqrt())
-        t_3_d.append(((term_3[index_3_d])).mean())
-        t_3_o.append(((term_3[index_3_o])).mean())
-        t_4.append(((term_4[causal_mask]) ** 2).mean().sqrt())
-        t_5.append(((term_5) ** 2).mean().sqrt())
-        t_6.append(((term_6) ** 2).mean().sqrt())
-        t_7_d.append(((term_7[index_7_d])).mean())
-        t_7_o.append(((term_7[index_7_o])).mean())
-        t_8.append(((term_8) ** 2).mean().sqrt())
-        term_dic = {
-            "l_b": torch.tensor(loss_b),
-            "a_b": torch.tensor(acc_b),
-            "l_a": torch.tensor(loss_a),
-            "a_a": torch.tensor(acc_a),
-            "0_d": torch.tensor(t_0_d),
-            "0_o": torch.tensor(t_0_o),
-            "1": torch.tensor(t_1),
-            "2": torch.tensor(t_2),
-            "3_d": torch.tensor(t_3_d),
-            "3_o": torch.tensor(t_3_o),
-            "4": torch.tensor(t_4),
-            "5": torch.tensor(t_5),
-            "6": torch.tensor(t_6),
-            "7_d": torch.tensor(t_7_d),
-            "7_o": torch.tensor(t_7_o),
-            "8": torch.tensor(t_8),
-        }
         a_loss = fun(model)
         print(a_loss)
         a_loss.backward()
         optimiser.step()
         optimiser.zero_grad()
+        if i % 100 == 0:
+            a = loss_bound(model)
+            l_bound = a[-2]
+            accuracy_bound = a[-1]
+            metric_tracking(
+                model, term_dic, l_bound.detach().cpu(), accuracy_bound.detach().cpu()
+            )
+            print(l_bound)
+            display_model(model)
+            torch.save(term_dic, term_name)
+            torch.save(model, model_name)
     return term_dic
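`get_graphs` now checkpoints as it trains: every 100 steps it recomputes the bounds, logs metrics through the shared `metric_tracking`, and saves both the metrics dict and the model under the caller-supplied names. An illustrative invocation; the `l_2` objective and the file names match the binaries added in this commit, but the exact call site is not part of the diff:

```python
# Sketch of a run, assuming l_2 and model_1 from elsewhere in this file.
term_dic = get_graphs(l_2, model_1, "l_2_2_term.pt", "l_2_2_model.pth")
```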
@@ -1168,27 +1143,67 @@ def add_noise(model, v):
     put_in_model(model, new_raw_terms)
 
 
-noise_plot = {"noise": [], "acc": [], "loss": [], "l_b": [], "a_b": []}
+@torch.no_grad()
+def noise_metric_tracking(term_dic, l_bound, accuracy_bound, noise):
+    (acc, loss) = sample_acc_and_loss(model_2, batch_size=5000)
+    (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8) = terms(
+        model_2
+    )
+    term_dic["noise"].append(noise)
+    term_dic["l_b"].append(l_bound)
+    term_dic["a_b"].append(accuracy_bound)
+    term_dic["l_a"].append(loss)
+    term_dic["a_a"].append(acc)
+    term_dic["0_d"].append(((term_0[index_0_d])).mean())
+    term_dic["0_o"].append(((term_0[index_0_o])).mean())
+    term_dic["1"].append(((term_1[causal_mask]) ** 2).mean().sqrt())
+    term_dic["2"].append(((term_2[causal_mask]) ** 2).mean().sqrt())
+    term_dic["3_d"].append(((term_3[index_3_d])).mean())
+    term_dic["3_o"].append(((term_3[index_3_o])).mean())
+    term_dic["4"].append(((term_4[causal_mask]) ** 2).mean().sqrt())
+    term_dic["5"].append(((term_5) ** 2).mean().sqrt())
+    term_dic["6"].append(((term_6) ** 2).mean().sqrt())
+    term_dic["7_d"].append(((term_7[index_7_d])).mean())
+    term_dic["7_o"].append(((term_7[index_7_o])).mean())
+    term_dic["8"].append(((term_8) ** 2).mean().sqrt())
+
+
+noise_data = {
+    "noise": [],
+    "l_b": [],
+    "a_b": [],
+    "l_a": [],
+    "a_a": [],
+    "0_d": [],
+    "0_o": [],
+    "1": [],
+    "2": [],
+    "3_d": [],
+    "3_o": [],
+    "4": [],
+    "5": [],
+    "6": [],
+    "7_d": [],
+    "7_o": [],
+    "8": [],
+}
 for i in range(25, 51):
-    loss_b = 0
-    acc_b = 0
-    loss = 0
-    acc = 0
+    print(i)
     for j in range(10):
         add_noise(model_2, i / 1000)
         a = loss_bound(model_2)
-        loss_b += a[-2]
-        acc_b += a[-1]
-        b = sample_acc_and_loss(model_2, batch_size=5000)
-        loss += b[1]
-        acc += b[0]
-    noise_plot["l_b"].append(loss_b / 10)
-    noise_plot["a_b"].append(acc_b / 10)
-    noise_plot["loss"].append(loss / 10)
-    noise_plot["acc"].append(acc / 10)
-    noise_plot["noise"].append(i / 1000)
+        l_bound = a[-2]
+        accuracy_bound = a[-1]
+
+        print(l_bound)
+        noise_metric_tracking(
+            noise_data, l_bound.detach().cpu(), accuracy_bound.detach().cpu(), i / 1000
+        )
+    torch.save(noise_data, "noise_term.pt")
 
 
+# %%
 plt.plot(noise_plot["noise"], noise_plot["l_b"], label="Loss Bound")
 plt.plot(noise_plot["noise"], noise_plot["loss"], label="Loss")
 plt.xlabel("Noise Level")
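The noise sweep now records per-sample metrics via `noise_metric_tracking` instead of 10-sample running averages, and persists them to `noise_term.pt` at each noise level, so the plotting cell above, which still reads the removed `noise_plot` dict, would need updating. A sketch of reloading and plotting the new format; the per-level averaging is an assumption, since each noise level now contributes ten raw samples under the keys `l_a`/`a_a` (replacing the old `loss`/`acc`):

```python
import torch
import matplotlib.pyplot as plt

noise_data = torch.load("noise_term.pt")
noise = torch.as_tensor(noise_data["noise"])  # floats i / 1000
l_b = torch.stack(noise_data["l_b"])          # 0-d tensors from loss_bound
l_a = torch.as_tensor(noise_data["l_a"])      # use torch.stack if stored as tensors

# Ten repeats were recorded at each noise level; average them per level.
levels = noise.unique()
mean_lb = torch.stack([l_b[noise == v].mean() for v in levels])
mean_la = torch.stack([l_a[noise == v].mean() for v in levels])

plt.plot(levels, mean_lb, label="Loss Bound")
plt.plot(levels, mean_la, label="Loss")
plt.xlabel("Noise Level")
plt.legend()
plt.show()
```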
@@ -1250,7 +1265,7 @@ def display_model(m):
     # Axis labels
     plt.xlabel("Key Token")
     plt.ylabel("Output Token")
-    plt.title("term_7.mean(dim=(0))")
+    plt.title("term_7.mean(dim=0)")
 
     plt.grid(False)
     plt.tight_layout()

gbmi/exp_indhead/l_2_2_model.pth: 168 KB (binary file not shown)
gbmi/exp_indhead/l_2_2_term.pt: 8.03 KB (binary file not shown)
gbmi/exp_indhead/noise_term.pt: 22.3 KB (binary file not shown)
gbmi/exp_indhead/term_test.pt: -1.98 MB (binary file not shown)
