Skip to content

Commit bca6803

Browse files
formatting
1 parent 9de153c commit bca6803

File tree

1 file changed

+67
-18
lines changed

1 file changed

+67
-18
lines changed

gbmi/exp_indhead/induction_head_results.py

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,7 @@ def metric_tracking(term_dic, l_bound, accuracy_bound):
910910
"7_o": [],
911911
"8": [],
912912
}
913+
# %%
913914

914915
for i in range(500):
915916
if i % 100 == 0:
@@ -925,17 +926,17 @@ def metric_tracking(term_dic, l_bound, accuracy_bound):
925926
optimiser.step()
926927
optimiser.zero_grad()
927928

928-
torch.save(model_1, "finetuned_model_test.pth")
929+
# torch.save(model_1, "finetuned_model_test.pth")
# Convert each logged metric list in term_dic into a tensor, in place.
# The original `b = torch.tensor(b)` only rebound the loop variable and
# left the dict values untouched (a no-op); writing back through the key
# performs the intended conversion.
for a, b in term_dic.items():
    term_dic[a] = torch.tensor(b)
932-
torch.save(term_dic, "term_test.pt")
933-
wandb.save("finetuned_model_test.pth")
934-
wandb.log(term_dic)
933+
# torch.save(term_dic, "term_test.pt")
934+
# wandb.save("finetuned_model_test.pth")
935+
# wandb.log(term_dic)
935936
wandb.finish()
936937

937938
# %%
938-
data_1 = torch.load("term.pt")
939+
data_1 = torch.load("term_test.pt")
939940
import matplotlib.pyplot as plt
940941
import numpy as np
941942

@@ -998,18 +999,20 @@ def plot_zero(data):
998999
plt.figure(figsize=(10, 6))
9991000
x = np.arange(500)
10001001

1001-
plt.plot(x, tr_1, "b-", label="l_2 norm", linewidth=1, marker=".", markersize=2)
1002-
plt.plot(x, tr_2, "g-", label="accuracy", linewidth=1, marker=".", markersize=2)
1003-
plt.plot(x, tr_4, "r-", label="accuracy", linewidth=1, marker=".", markersize=2)
1002+
plt.plot(x, tr_1, "b-", label="term_1", linewidth=1, marker=".", markersize=2)
1003+
plt.plot(x, tr_2, "g-", label="term_2", linewidth=1, marker=".", markersize=2)
1004+
plt.plot(x, tr_4, "r-", label="term_4", linewidth=1, marker=".", markersize=2)
10041005
plt.plot(
1005-
x, tr_5, color="orange", label="accuracy", linewidth=1, marker=".", markersize=2
1006+
x, tr_5, color="orange", label="term_5", linewidth=1, marker=".", markersize=2
10061007
)
1007-
plt.plot(x, tr_6, "y-", label="accuracy", linewidth=1, marker=".", markersize=2)
1008-
plt.plot(x, tr_8, "g-", label="accuracy", linewidth=1, marker=".", markersize=2)
1008+
plt.plot(x, tr_6, "y-", label="term_6", linewidth=1, marker=".", markersize=2)
1009+
plt.plot(x, tr_8, "p-", label="term_8", linewidth=1, marker=".", markersize=2)
10091010

1010-
plt.title("Accuracy as model is finetuned")
1011+
plt.title(
1012+
"Root Mean Square value of terms that are not used in our mechanistic interpretation as model is finetuned"
1013+
)
10111014
plt.xlabel("Gradient Steps")
1012-
plt.ylabel("Accuracy")
1015+
plt.ylabel("RMS value")
10131016
plt.legend()
10141017
plt.grid(True)
10151018
plt.show()
@@ -1022,12 +1025,14 @@ def plot_diag(data, i):
10221025
plt.figure(figsize=(10, 6))
10231026
x = np.arange(500)
10241027

1025-
plt.plot(x, tr_1_d, "b-", label="l_2 norm", linewidth=1, marker=".", markersize=2)
1026-
plt.plot(x, tr_1_o, "g-", label="accuracy", linewidth=1, marker=".", markersize=2)
1028+
plt.plot(x, tr_1_d, "b-", label="diagonal", linewidth=1, marker=".", markersize=2)
1029+
plt.plot(
1030+
x, tr_1_o, "g-", label="off diagonal", linewidth=1, marker=".", markersize=2
1031+
)
10271032

1028-
plt.title("Accuracy as model is finetuned")
1033+
plt.title("Mean of off and on diagonal elements of term_" + str(i))
10291034
plt.xlabel("Gradient Steps")
1030-
plt.ylabel("Accuracy")
1035+
plt.ylabel("Mean")
10311036
plt.legend()
10321037
plt.grid(True)
10331038
plt.show()
@@ -1148,17 +1153,61 @@ def get_graphs(fun, model):
11481153

11491154
# %%
11501155
def noise(M, v):
    """Return M plus i.i.d. Gaussian noise with mean 0 and std v.

    Uses torch.randn_like so the noise automatically matches M's device
    and dtype; torch.normal(...).to(device) assumed M lived on the global
    `device` and could silently change dtype for non-float32 tensors.
    """
    return M + torch.randn_like(M) * v
# Move the model onto the target device before the noise-robustness sweep.
model_2.to(device)
11521160

11531161

11541162
def add_noise(model, v):
    """Install a noise-perturbed copy of the cached weight terms into `model`.

    NOTE(review): reads the module-level `raw_terms` list — presumably the
    factored weight matrices captured earlier; confirm against the caller.
    Each term is detached, cloned, perturbed with Gaussian noise of std `v`
    via `noise`, re-marked as requiring grad, and written back with
    `put_in_model`.
    """
    noisy_terms = []
    for term in raw_terms:
        perturbed = noise(term.detach().clone(), v)
        perturbed.requires_grad = True
        noisy_terms.append(perturbed)
    put_in_model(model, noisy_terms)
11601169

11611170

# Sweep noise levels 0.025–0.050 and compare the proof bounds against
# empirically sampled loss/accuracy, averaging 10 noise draws per level.
noise_plot = {"noise": [], "acc": [], "loss": [], "l_b": [], "a_b": []}
for step in range(25, 51):
    level = step / 1000
    total_loss_bound = 0
    total_acc_bound = 0
    total_loss = 0
    total_acc = 0
    for _ in range(10):
        add_noise(model_2, level)
        bounds = loss_bound(model_2)
        # loss_bound's last two entries are the loss and accuracy bounds.
        total_loss_bound += bounds[-2]
        total_acc_bound += bounds[-1]
        sampled = sample_acc_and_loss(model_2, batch_size=5000)
        total_loss += sampled[1]
        total_acc += sampled[0]
    noise_plot["l_b"].append(total_loss_bound / 10)
    noise_plot["a_b"].append(total_acc_bound / 10)
    noise_plot["loss"].append(total_loss / 10)
    noise_plot["acc"].append(total_acc / 10)
    noise_plot["noise"].append(level)


def _plot_noise_sweep(bound_key, empirical_key, bound_label, empirical_label, ylabel):
    """Plot one proof bound against its empirically sampled counterpart."""
    plt.plot(noise_plot["noise"], noise_plot[bound_key], label=bound_label)
    plt.plot(noise_plot["noise"], noise_plot[empirical_key], label=empirical_label)
    plt.xlabel("Noise Level")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


_plot_noise_sweep("l_b", "loss", "Loss Bound", "Loss", "Loss")
_plot_noise_sweep("a_b", "acc", "Accuracy Bound", "Accuracy", "Accuracy")
1209+
1210+
11621211
# %%
11631212
def display_model(m):
11641213
a = terms(m)

0 commit comments

Comments
 (0)