Commit 5b6a68e

finetuning
1 parent edbda5a commit 5b6a68e

2 files changed: +276 -44 lines changed

gbmi/exp_indhead/finetunebound.py

Lines changed: 105 additions & 40 deletions
@@ -36,7 +36,7 @@


 # %%
-def loss_bound(model, s, w):
+def loss_bound(model, s):

     W_pos = model.W_pos
     W_E = model.W_E
@@ -53,7 +53,7 @@ def loss_bound(model, s, w):

     e_p = W_E.unsqueeze(dim=0) + W_pos.unsqueeze(dim=1)

-    everything = (
+    term_0 = (
         einops.einsum(
             e_p,
             W_Q_0,
@@ -70,16 +70,16 @@ def loss_bound(model, s, w):
     for p in range(2, n_ctx): #
         tmp = torch.zeros((p, d_voc))
         for t_q in range(d_voc):
-            tmp[-1, :] = everything[p - 1, t_q, p - 1, t_q]
+            tmp[-1, :] = term_0[p - 1, t_q, p - 1, t_q]

             for t_k in range(d_voc):
-                tmp[-2, :] = everything[p - 1, t_q, p - 2, t_k]
-                tmp[:-2, :] = everything[p - 1, t_q, : p - 2, :]
+                tmp[-2, :] = term_0[p - 1, t_q, p - 2, t_k]
+                tmp[:-2, :] = term_0[p - 1, t_q, : p - 2, :]
                 tmp_sm = tmp.softmax(dim=0)
                 table[t_q, t_k, p - 2, :] = tmp_sm[-2, :]
     # Table represents post-softmax attention paid to t_k, if the final entry is spammed everywhere, and t_q is used as the first entry, at the pth position

-    # everything looks like EQKE, table looks like you're indexing by query, key, position (of key?), and other token in the sequence.
+    # term_0 looks like EQKE, table looks like you're indexing by query, key, position (of key?), and other token in the sequence.
     # Then you're computing softmax of d_voc - 2 copies of the other token, one copy of t_k in p-2, and the query in p-1.
     # Then you store the post-softmax attention paid to t_k.
     #
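
The loop above builds a worst-case softmax: for a query token t_q at position p - 1 and a key token t_k at position p - 2, every earlier slot is filled with a third token, and only the post-softmax weight on t_k is kept. A minimal, self-contained sketch of that pattern, with a random tensor standing in for term_0 and illustrative sizes and token choices:

import torch

# Minimal sketch of the table construction above; `scores` is a random stand-in for
# term_0, indexed [q_pos, q_tok, k_pos, k_tok], and the values below are illustrative.
n_ctx, d_voc = 8, 26
scores = torch.randn(n_ctx, d_voc, n_ctx, d_voc)

p, t_q, t_k, other = 5, 3, 7, 0  # example position and token choices
col = torch.empty(p)
col[-1] = scores[p - 1, t_q, p - 1, t_q]       # the query token attends to itself at p - 1
col[-2] = scores[p - 1, t_q, p - 2, t_k]       # the key token of interest sits at p - 2
col[:-2] = scores[p - 1, t_q, : p - 2, other]  # every earlier slot holds the "other" token
attn_on_t_k = col.softmax(dim=0)[-2]           # post-softmax weight paid to t_k

The committed loop vectorises this over all d_voc choices of the other token at once, which is why tmp carries a second axis of width d_voc.
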
@@ -177,6 +177,9 @@ def loss_bound(model, s, w):
         "q_pos q_val k, k l, l m, m n, n p, p q -> q_pos q_val q",
     )

+    if s == -1:
+        return (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8)
+
     if s == 0:
         reduced_3 = einops.einsum(
             term_3, "q_pos q_val k_pos k_val -> q_pos q_val k_pos"
@@ -421,17 +424,27 @@ def least_attention(a, i_1, i_2, j, dic):
     if s == 2:
         return (attn_1, bound, bound_2)

-    def loss_diff_1(b, i_1, i_2, dic):
+    def loss_diff_1(b, i_1, i_2, dic, n=None):
+
+        if n == b:
+            return 0

-        n = torch.arange(d_voc)[torch.arange(d_voc) != b]
+        if n is None:
+
+            n = torch.arange(d_voc)[torch.arange(d_voc) != b]

         return (
-            term_5[i_2, dic[i_2]][..., n] - term_5[i_2, :, b].unsqueeze(dim=-1)
+            term_5[i_2, dic[i_2]][..., n] - term_5[i_2, dic[i_2], b].unsqueeze(dim=-1)
         ).max()

-    def loss_diff_2(b, i_1, i_2, dic):
+    def loss_diff_2(b, i_1, i_2, dic, n=None):
+
+        if n == b:
+            return 0
+
+        if n is None:

-        n = torch.arange(d_voc)[torch.arange(d_voc) != b]
+            n = torch.arange(d_voc)[torch.arange(d_voc) != b]

         c = (term_6[0, dic[0]][..., n] - term_6[0, dic[0], b].unsqueeze(dim=-1)).max()
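
The new n=None default keeps the original behaviour of loss_diff_1 and loss_diff_2 (a maximum over every token other than b), while a single token n measures the gap against that one competitor, and n == b contributes nothing. A simplified, self-contained illustration of the calling pattern; margin and scores are hypothetical stand-ins, not the term_5/term_6 expressions above:

import torch

# Hypothetical stand-in for the loss_diff_* helpers: `scores` plays the role of a row of
# term_5/term_6, b is the correct token, n optionally selects a single competitor.
def margin(scores: torch.Tensor, b: int, n=None):
    if n == b:
        return torch.tensor(0.0)  # no margin against the correct token itself
    if n is None:
        idx = torch.arange(scores.shape[-1])
        n = idx[idx != b]  # all competitors, as in the original code path
    return (scores[..., n] - scores[..., b]).max()

scores = torch.randn(26)
print(margin(scores, b=3))        # worst case over all 25 competitors
print(margin(scores, b=3, n=7))   # gap against competitor token 7 only
print(margin(scores, b=3, n=3))   # n == b short-circuits to 0
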

@@ -460,8 +473,12 @@ def loss_diff_2(b, i_1, i_2, dic):
         )
         return ld_2

-    def loss_diff_3(b, i_1, i_2, dic):
-        n = torch.arange(d_voc)[torch.arange(d_voc) != b]
+    def loss_diff_3(b, i_1, i_2, dic, n=None):
+        if n == b:
+            return 0
+
+        if n is None:
+            n = torch.arange(d_voc)[torch.arange(d_voc) != b]
         c = (term_7[0, dic[0]][..., n] - term_7[0, dic[0], b].unsqueeze(dim=-1)).max()
         for i in range(i_1):
             c = torch.max(
@@ -488,9 +505,14 @@ def loss_diff_3(b, i_1, i_2, dic):
         )
         return ld_3

-    def loss_diff_4(b, i_1, i_2, dic):
+    def loss_diff_4(b, i_1, i_2, dic, n=None):

-        n = torch.arange(d_voc)[torch.arange(d_voc) != b]
+        if n == b:
+            return 0
+
+        if n is None:
+
+            n = torch.arange(d_voc)[torch.arange(d_voc) != b]

         for k in range(i_2 + 1):
             if k != 0 and k != 1:
@@ -546,32 +568,57 @@ def loss_diff_4(b, i_1, i_2, dic):
         )
         return ld_4

-    def total_bound(b, i_1, i_2, dic):
+    def total_bound(b, i_1, i_2, dic, n=None):
         return (
-            loss_diff_1(b, i_1, i_2, dic)
-            + loss_diff_2(b, i_1, i_2, dic)
-            + loss_diff_3(b, i_1, i_2, dic)
-            + loss_diff_4(b, i_1, i_2, dic)
+            loss_diff_1(b, i_1, i_2, dic, n)
+            + loss_diff_2(b, i_1, i_2, dic, n)
+            + loss_diff_3(b, i_1, i_2, dic, n)
+            + loss_diff_4(b, i_1, i_2, dic, n)
         )

-    out = torch.zeros((d_voc, n_ctx, n_ctx)) + torch.inf
+    if s == 3:
+
+        out = torch.zeros((d_voc, n_ctx, n_ctx)) + torch.inf
+        # b i_2 i_1
+
+        for b in range(e_p.shape[1]):
+
+            for i_2 in range(e_p.shape[0] - 1):
+                for i_1 in range(1, i_2):
+
+                    if (i_1 < i_2) & (i_1 > 0):
+                        dic = {i_1: b}
+                        for i in range(8):
+                            dic.setdefault(i, torch.arange(26))
+
+                        out[b, i_2, i_1] = total_bound(b, i_1, i_2, dic)
+
+        out_2 = 1 / (1 + ((d_voc - 1) * torch.exp(out)))
+
+        return (attn_1, bound, bound_2, out, out_2)
+
+    out = torch.zeros((d_voc, n_ctx, n_ctx, d_voc)) + torch.inf
     # b i_2 i_1

     for b in range(e_p.shape[1]):
+        for n in range(e_p.shape[1]):
+            for i_2 in range(e_p.shape[0] - 1):
+                for i_1 in range(1, i_2):

-        for i_2 in range(e_p.shape[0] - 1):
-            for i_1 in range(1, i_2):
+                    if (i_1 < i_2) & (i_1 > 0):
+                        dic = {i_1: b}
+                        for i in range(8):
+                            dic.setdefault(i, torch.arange(26))

-                if (i_1 < i_2) & (i_1 > 0):
-                    dic = {i_1: b}
-                    for i in range(8):
-                        dic.setdefault(i, torch.arange(26))
+                        out[b, i_2, i_1, n] = total_bound(b, i_1, i_2, dic, n)

-                    out[b, i_2, i_1] = total_bound(b, i_1, i_2, dic)
+    out_2 = einops.einsum(out.softmax(dim=-1), "b i_2 i_1 b -> b i_2 i_1")

-    out_2 = 1 / (1 + ((d_voc - 1) * torch.exp(out)))
+    out_3 = einops.einsum(
+        out - out.max(dim=-1).values.unsqueeze(dim=-1), "b i_2 i_1 b -> b i_2 i_1"
+    )

-    return (attn_1, bound, bound_2, out, out_2)
+    return (attn_1, bound, bound_2, out, out_2, out_3)


 # %%
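
The s == 3 branch keeps the original probability bound: out[b, i_2, i_1] reads as an upper bound on how far any wrong logit can sit above the correct token b, so the softmax probability of b is at least 1 / (1 + (d_voc - 1) * exp(out)), which is what out_2 stores. The fallthrough path instead bounds the gap against each competitor n separately and reads the diagonal of a softmax over that axis. A quick numerical check of the scalar inequality behind out_2, with illustrative values:

import torch

# If every one of the (d_voc - 1) wrong logits exceeds the correct logit by at most m,
# the correct token's softmax probability is at least 1 / (1 + (d_voc - 1) * exp(m)).
d_voc, m = 26, 0.3
logits = torch.zeros(d_voc)
logits[1:] = m * torch.rand(d_voc - 1)  # competitors exceed token 0 by at most m
p_correct = logits.softmax(dim=0)[0]
lower_bound = 1 / (1 + (d_voc - 1) * torch.exp(torch.tensor(m)))
assert p_correct >= lower_bound
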
@@ -647,22 +694,40 @@ def total_bound(b, i_1, i_2, dic):
     counter += 1
     print(counter)

-
+# %%
+valid = (
+    ein.array(
+        lambda i, j, k: where(k > 0, where(j > k, where(j < 7, 1, 0), 0), 0),
+        sizes=[d_voc, n_ctx, n_ctx],
+    )
+    .bool()
+    .to(device)
+)
 optimiser = torch.optim.AdamW(
-    model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
+    model_1.parameters(), lr=0.5, betas=(0.9, 0.999), weight_decay=0
 )
+# %%
+a = loss_bound(model_1, 3)[4]
+loss = 1 - a[valid].min()
+print(a[valid].min())
+print(a[valid].mean())
+print(a[valid].max())
+for i in range(1):
+    print(i + 1)

-a = loss_bound(model_1, 3, 8)[4]
-loss = 1 - a[a != 0].mean()
-for i in range(30):
-    print(a[a != 0].mean())
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
-    a = loss_bound(model_1, 3, 8)[4][5]
-    loss = 1 - a[a != 0].mean()
-    counter += 1
-    print(counter)
+    a = loss_bound(model_1, 3)[4]
+    loss = 1 - a[valid].min()
+    print(a[valid].min())
+    print(a[valid].mean())
+    print(a[valid].max())
+    if i % 10 == 1:
+        r = loss_bound(model_1, 4)[5]
+        print(r[valid].min())
+        print(r[valid].mean())
+        print(r[valid].max())

 # %%
 '''
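
The new cell restricts the bound to valid (b, i_2, i_1) triples (0 < i_1 < i_2 < 7) and fine-tunes model_1 directly against it: the loss is one minus the worst valid entry of out_2, optimised with AdamW at lr 0.5 and no weight decay. A generic, self-contained sketch of that "train against a certified bound" pattern; bound_fn, the linear model, and the step count are stand-ins, not the code in this file:

import torch

# Generic version of the pattern used above: raise the worst-case certified bound by
# minimising 1 - min. `bound_fn` plays the role of loss_bound(model_1, 3)[4][valid].
model = torch.nn.Linear(4, 4)
optimiser = torch.optim.AdamW(model.parameters(), lr=0.5, betas=(0.9, 0.999), weight_decay=0)

def bound_fn(m: torch.nn.Module) -> torch.Tensor:
    # Stand-in for the certified probability bounds; must be differentiable in m's parameters.
    return torch.sigmoid(m(torch.eye(4))).flatten()

for step in range(1):
    a = bound_fn(model)
    loss = 1 - a.min()      # push up the single worst-case entry
    loss.backward()
    optimiser.step()
    optimiser.zero_grad()

In the committed cell the bound is a[valid] from loss_bound(model_1, 3)[4], and min/mean/max diagnostics are printed after each update.
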
