Skip to content

Commit 9de153c

Browse files
graphs
1 parent 06254cf commit 9de153c

8 files changed

+109
-40
lines changed
168 KB
Binary file not shown.
168 KB
Binary file not shown.
168 KB
Binary file not shown.
168 KB
Binary file not shown.
168 KB
Binary file not shown.
168 KB
Binary file not shown.

gbmi/exp_indhead/induction_head_results.py

Lines changed: 109 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -862,57 +862,77 @@ def sample_acc_and_loss(model, batch_size=15000):
862862
optimiser = torch.optim.AdamW(
863863
model_1.parameters(), lr=2e-3, betas=(0.9, 0.999), weight_decay=1.0
864864
)
865+
866+
865867
# %%
868+
@torch.no_grad()
def metric_tracking(term_dic, l_bound, accuracy_bound):
    """Append one step's worth of tracked metrics to ``term_dic``.

    Args:
        term_dic: Mapping from metric name to a list; one value is appended
            to each of the sixteen lists named below.
        l_bound: Loss bound for the current step, stored under ``"l_b"``.
        accuracy_bound: Accuracy bound for the current step, stored under
            ``"a_b"``.

    The sampled accuracy/loss come from ``sample_acc_and_loss`` on the
    module-level ``model_1``; the nine terms come from ``terms(model_1)``.
    The ``index_*`` selectors and ``causal_mask`` are module-level names
    defined outside this block (presumably diagonal / off-diagonal index
    sets -- confirm against their definitions).
    """

    def _rms(values):
        # Root-mean-square of the selected entries.
        return (values**2).mean().sqrt()

    sampled_acc, sampled_loss = sample_acc_and_loss(model_1, batch_size=5000)
    t0, t1, t2, t3, t4, t5, t6, t7, t8 = terms(model_1)

    # Keys are listed in the same order in which they were originally appended.
    snapshot = {
        "l_b": l_bound,
        "a_b": accuracy_bound,
        "l_a": sampled_loss,
        "a_a": sampled_acc,
        "0_d": t0[index_0_d].mean(),
        "0_o": t0[index_0_o].mean(),
        "1": _rms(t1[causal_mask]),
        "2": _rms(t2[causal_mask]),
        "3_d": t3[index_3_d].mean(),
        "3_o": t3[index_3_o].mean(),
        "4": _rms(t4[causal_mask]),
        "5": _rms(t5),
        "6": _rms(t6),
        "7_d": t7[index_7_d].mean(),
        "7_o": t7[index_7_o].mean(),
        "8": _rms(t8),
    }
    for name, value in snapshot.items():
        term_dic[name].append(value)
890+
891+
892+
import wandb

# Start the wandb run that the fine-tuning results are reported to.
wandb.init(project="induction_head_finetune_results")

# One independent, initially empty history list per tracked metric.
# A comprehension is used (rather than dict.fromkeys) so that every key
# gets its own list object.
term_dic = {
    metric_name: []
    for metric_name in (
        "l_b",
        "a_b",
        "l_a",
        "a_a",
        "0_d",
        "0_o",
        "1",
        "2",
        "3_d",
        "3_o",
        "4",
        "5",
        "6",
        "7_d",
        "7_o",
        "8",
    )
}
913+
866914
# Fine-tune model_1 on the loss bound itself for 500 optimiser steps,
# recording metrics at every step and checkpointing every 100 steps.
for i in range(500):
    if i % 100 == 0:
        torch.save(model_1, f"finetuned_model_{i}.pth")
    print(i)
    a = loss_bound(model_1)
    # The last two entries returned by loss_bound are used as the loss
    # bound and the accuracy bound respectively.
    l_bound = a[-2]
    accuracy_bound = a[-1]

    print(l_bound)
    l_bound.backward()
    # Record detached CPU copies so the stored history is independent of
    # the autograd graph built for this step.
    metric_tracking(term_dic, l_bound.detach().cpu(), accuracy_bound.detach().cpu())
    optimiser.step()
    optimiser.zero_grad()

torch.save(model_1, "finetuned_model_test.pth")

# Convert every per-step history list into a tensor *in the dict*.
# Rebinding the loop variable (``b = torch.tensor(b)``) would leave
# term_dic holding plain lists, so the result is assigned back by key.
for key in term_dic:
    term_dic[key] = torch.tensor(term_dic[key])

torch.save(term_dic, "term_test.pt")
wandb.save("finetuned_model_test.pth")
wandb.log(term_dic)
wandb.finish()
916936

917937
# %%
918938
data_1 = torch.load("term.pt")
@@ -1139,4 +1159,53 @@ def add_noise(model, v):
11391159
put_in_model(model, new_raw_terms)
11401160

11411161

1162+
# %%
1163+
def _show_heatmap(values, xlabel, ylabel, title):
    """Draw ``values`` as a viridis heatmap in a new 8x6 figure and show it.

    ``values`` is a tensor; it is detached, moved to CPU and converted to a
    NumPy array before being passed to ``plt.imshow``.
    """
    plt.figure(figsize=(8, 6))
    plt.imshow(values.detach().cpu().numpy(), cmap="viridis", aspect="auto")
    plt.colorbar(label="Value")  # Add color scale label

    # Axis labels
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)

    plt.grid(False)
    plt.tight_layout()
    plt.show()


def display_model(m):
    """Plot averaged views of terms 0, 3 and 7 of model ``m``.

    Calls ``terms(m)`` once and shows three heatmaps in sequence:

    * term 0 averaged over dims 1 and 3 (axes labelled by position),
    * term 3 averaged over dims 0 and 2 (axes labelled by token),
    * term 7 averaged over dim 0 (key token vs. output token).

    The axis meanings are taken from the labels used here; the layout of
    the tensors returned by ``terms`` is defined outside this block.
    """
    a = terms(m)

    _show_heatmap(
        a[0].mean(dim=(1, 3)),
        "Key Position",
        "Query Position",
        "term_0.mean(dim=(1,3))",
    )
    _show_heatmap(
        a[3].mean(dim=(0, 2)),
        "Key Token",
        "Query Token",
        "term_3.mean(dim=(0,2))",
    )
    # Label spelling corrected from "Ouput Token".
    _show_heatmap(
        a[7].mean(dim=0),
        "Key Token",
        "Output Token",
        "term_7.mean(dim=(0))",
    )
1209+
1210+
11421211
# %%

gbmi/exp_indhead/term_test.pt

2.01 MB
Binary file not shown.

0 commit comments

Comments
 (0)