Skip to content

Commit 0b26fd9

Browse files
hand-coded
1 parent bea76cd commit 0b26fd9

File tree

3 files changed

+563
-473
lines changed

3 files changed

+563
-473
lines changed

gbmi/exp_indhead/finetune_ind.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,15 @@ def diff_3(a, i_1, i_2, j, dic, matrices, attn_1):
205205
for i in range(0, j - 1):
206206
c = torch.max(c, term_3[i_2, dic[i_2], i, dic[i]].max())
207207
c = torch.max(c, term_3[i_2, dic[i_2], j, dic[j]].max())
208-
t_3 = t_3 + (1 - attn_1[dic[j], j - 1].min()) * c
208+
t_3 += (1 - attn_1[dic[j], j - 1].min()) * c
209209

210210
# print(t_3)
211211
if j == 1:
212212
c = term_3[i_2, a, j - 1, dic[j - 1]].max()
213213
# new=c.clone()
214214
t_3 = c * attn_1[dic[j], j - 1].min()
215215
c = torch.max(c, term_3[i_2, a, j, dic[j]].max())
216-
t_3 = t_3 + (1 - attn_1[dic[j], j - 1].min()) * c
216+
t_3 += (1 - attn_1[dic[j], j - 1].min()) * c
217217

218218
if j == 0:
219219

@@ -349,10 +349,10 @@ def second_layer_attention(matrices, attn_1):
349349
- torch.inf
350350
)
351351

352-
for a in range(term_0.shape[1]):
352+
for a in range(0, term_0.shape[1]):
353353

354-
for i_2 in range(2, term_0.shape[0] - 1):
355-
for i_1 in range(1, i_2):
354+
for i_2 in range(3, term_0.shape[0] - 1):
355+
for i_1 in range(2, i_2):
356356
for j in range(i_2 + 1):
357357
if (i_1 < i_2) & (i_1 > 0) & (i_2 + 1 > j):
358358
dic = {
@@ -607,8 +607,8 @@ def loss_bound(model):
607607

608608
for b in range(term_0.shape[1]):
609609

610-
for i_2 in range(2, term_0.shape[0] - 1):
611-
for i_1 in range(1, i_2):
610+
for i_2 in range(3, term_0.shape[0] - 1):
611+
for i_1 in range(2, i_2):
612612

613613
if (i_1 < i_2) & (i_1 > 0):
614614
dic = {i_1: b}
@@ -621,7 +621,7 @@ def loss_bound(model):
621621

622622
out_2 = 1 / (1 + ((d_voc - 1) * torch.exp(out)))
623623

624-
return (out, out_2)
624+
return (attn_1, bound_2, out, out_2)
625625

626626

627627
def good_loss_bound(model):
@@ -630,13 +630,13 @@ def good_loss_bound(model):
630630
attn_1 = first_layer_attention(matrices)
631631
bound_2 = second_layer_attention(matrices, attn_1)
632632

633-
out = torch.zeros((d_voc, n_ctx, n_ctx, d_voc)) + torch.inf
633+
out = torch.zeros((d_voc, n_ctx, n_ctx, d_voc))
634634
# b i_2 i_1
635635

636636
for b in range(term_0.shape[1]):
637637
for n in range(term_0.shape[1]):
638-
for i_2 in range(term_0.shape[0] - 1):
639-
for i_1 in range(1, i_2):
638+
for i_2 in range(3, term_0.shape[0] - 1):
639+
for i_1 in range(2, i_2):
640640

641641
if (i_1 < i_2) & (i_1 > 0):
642642
dic = {i_1: b}
@@ -687,8 +687,8 @@ def good_loss_bound(model):
687687
# %%
688688
weights_1 = torch.zeros((d_voc, n_ctx, n_ctx))
689689
for a in range(d_voc):
690-
for i_2 in range(2, n_ctx - 1):
691-
for i_1 in range(1, i_2):
690+
for i_2 in range(3, n_ctx - 1):
691+
for i_1 in range(2, i_2):
692692
weights_1[a, i_2, i_1] = (d_voc - 1) ** (i_2 - 1)
693693
# %%
694694
matrices = terms(model_1)
@@ -718,30 +718,30 @@ def good_loss_bound(model):
718718
# %%
719719
valid = (
720720
ein.array(
721-
lambda i, j, k: where(k > 0, where(j > k, where(j < 7, 1, 0), 0), 0),
721+
lambda i, j, k: where(k > 1, where(j > k, where(j < 7, 1, 0), 0), 0),
722722
sizes=[d_voc, n_ctx, n_ctx],
723723
)
724724
.bool()
725725
.to(device)
726726
)
727727
weights_2 = ein.array(
728-
lambda i, j, k: where(k > 0, where(j > k, where(j < 7, 1, 0), 0), 0)
728+
lambda i, j, k: where(k > 1, where(j > k, where(j < 7, 1, 0), 0), 0)
729729
* ((d_voc - 1) * ((d_voc - 1) ** (j - 2))),
730730
sizes=[d_voc, n_ctx, n_ctx],
731731
).to(device)
732732
# %%
733733
optimiser = torch.optim.AdamW(
734-
model_1.parameters(), lr=1e-2, betas=(0.9, 0.999), weight_decay=1.0
734+
model_1.parameters(), lr=5e-2, betas=(0.9, 0.999), weight_decay=1.0
735735
)
736736
# %%
737737
# optimiser = torch.optim.SGD(model_1.parameters(), lr=100)
738738
# %%
739-
bound = loss_bound(model_1)[1]
739+
bound = loss_bound(model_1)[3]
740740
loss = 1 - (torch.nansum(bound * weights_2) / (weights_2.sum()))
741741
print(bound[valid].min())
742742
print(torch.nansum(bound * weights_2) / (weights_2.sum()))
743743
print(bound[valid].max())
744-
while loss > 0.5:
744+
while loss > 0.05:
745745
# torch.autograd.set_detect_anomaly(True)
746746
loss.backward()
747747
# torch.nn.utils.clip_grad_norm_(model_1.parameters(), max_norm=1.0)
@@ -756,6 +756,19 @@ def good_loss_bound(model):
756756
print(bound[valid].max())
757757

758758

759+
# %%
760+
for i in range(10):
761+
print(i)
762+
a = loss_bound(model_1)
763+
loss = 1 - ((torch.nansum(a[3] * weights_2) / (weights_2.sum())))
764+
print(a[0].min())
765+
print(torch.nansum(a[3] * weights_2) / (weights_2.sum()))
766+
print(torch.nansum(a[1] * weights_1) / (weights_1.sum()))
767+
loss.backward()
768+
optimiser.step()
769+
optimiser.zero_grad()
770+
771+
759772
# %%
760773
ModelMatrixLoggingOptions.all(
761774
use_subplots=True, add_mean={-1: None, 0: "tok_to_pos", 1: None}

0 commit comments

Comments (0)