Skip to content

Commit 69b4e6f

Browse files
square root
1 parent abfaf34 commit 69b4e6f

File tree

3 files changed

+967
-4
lines changed

3 files changed

+967
-4
lines changed

gbmi/exp_indhead/finetune_ind.py

Lines changed: 234 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,132 @@ def diff_2_4(a, i_1, i_2, j, dic, matrices, attn_1):
325325
return t_4
326326

327327

328+
def diff_2_3_4(a, i_1, i_2, j, dic, matrices, attn_1):
329+
330+
(term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8) = matrices
331+
332+
if j == i_1:
333+
return 0
334+
for k in range(i_2 + 1):
335+
if j != 0 and j != 1:
336+
c = (
337+
term_4[k, dic[k], j - 1][..., dic[j - 1]].max()
338+
+ term_3[i_2, a, j - 1, dic[j - 1]].max()
339+
)
340+
# new = c.clone()
341+
d = c * attn_1[dic[j], j - 1].min()
342+
343+
for i in range(0, j - 1):
344+
345+
c = torch.max(
346+
c,
347+
term_4[k, dic[k], i][..., dic[i]].max()
348+
+ term_3[i_2, dic[i_2], i, dic[i]].max(),
349+
)
350+
c = torch.max(
351+
c,
352+
term_4[k, dic[k], j][..., dic[j]].max()
353+
+ term_3[i_2, dic[i_2], j, dic[j]].max(),
354+
)
355+
d = d + (1 - attn_1[dic[j], j - 1].min()) * c
356+
357+
if j == 0:
358+
359+
d = (
360+
term_4[k, dic[k], j][..., dic[j]].max()
361+
+ term_3[i_2, a, j, dic[j]].max()
362+
)
363+
364+
if j == 1:
365+
c = (
366+
term_4[k, dic[k], j - 1][..., dic[j - 1]].max()
367+
+ term_3[i_2, a, j - 1, dic[j - 1]].max()
368+
)
369+
# new=c.clone()
370+
d = c * attn_1[dic[j], j - 1].min()
371+
c = torch.max(
372+
c,
373+
term_4[k, dic[k], j][..., dic[j]].max()
374+
+ term_3[i_2, a, j, dic[j]].max(),
375+
)
376+
d = d + (1 - attn_1[dic[j], j - 1].min()) * c
377+
378+
# print(d)
379+
if i_1 != 1:
380+
c = term_4[k, dic[k], i_1 - 1, a].min() + term_3[i_2, a, i_1 - 1, a]
381+
# new=c.clone()
382+
d = d - attn_1[dic[i_1], i_1 - 1].min() * c
383+
384+
for i in range(0, i_1 - 1):
385+
386+
c = torch.min(
387+
c,
388+
term_4[k, dic[k], i][..., dic[i]].min()
389+
+ term_3[i_2, dic[i_2], i, dic[i]].min(),
390+
)
391+
c = torch.min(
392+
c,
393+
term_4[k, dic[k], i_1][..., dic[i_1]].min()
394+
+ term_3[i_2, dic[i_2], i_1, dic[i_1]].min(),
395+
)
396+
d = d - (1 - attn_1[dic[i_1], i_1 - 1].min()) * c
397+
398+
if i_1 == 1:
399+
c = term_4[k, dic[k], i_1 - 1, a].min() + term_3[i_2, a, i_1 - 1, a]
400+
# new=c.clone()
401+
d = d - attn_1[dic[i_1], i_1 - 1].min() * c
402+
403+
c = torch.min(
404+
c,
405+
term_4[k, dic[k], i_1][..., dic[i_1]].min()
406+
+ term_3[i_2, a, i_1, dic[i_1]].min(),
407+
)
408+
d = d - (1 - attn_1[dic[i_1], i_1 - 1].min()) * c
409+
410+
# print(d)
411+
412+
if type(dic[j]) == int:
413+
d = (
414+
d
415+
+ (
416+
term_2[k, dic[k], j][..., dic[j]]
417+
- term_2[k, dic[k], i_1][..., dic[i_1]].min(dim=-1).values
418+
).max()
419+
)
420+
421+
else:
422+
d = (
423+
d
424+
+ (
425+
term_2[k, dic[k], j][..., dic[j]].max(dim=-1).values
426+
- term_2[k, dic[k], i_1][..., dic[i_1]].min(dim=-1).values
427+
).max()
428+
)
429+
430+
if k == 0:
431+
432+
f = d
433+
434+
if k != 0:
435+
f = torch.max(f, d)
436+
437+
if k == i_2 - 1:
438+
439+
g = d.clone()
440+
441+
t_4 = g * attn_1[dic[i_2], i_2 - 1]
442+
t_4 = t_4 + (1 - attn_1[dic[i_2], i_2 - 1]) * f
443+
444+
return t_4
445+
446+
328447
def least_attention(a, i_1, i_2, j, dic, matrices, attn_1):
329448
e = diff_2_4(a, i_1, i_2, j, dic, matrices, attn_1)
330449

331450
return (
332451
diff_1(a, i_1, i_2, j, dic, matrices)
333-
+ diff_3(a, i_1, i_2, j, dic, matrices, attn_1)
334452
+ e
453+
+ diff_3(a, i_1, i_2, j, dic, matrices, attn_1)
335454
)
336455

337456

@@ -587,12 +706,124 @@ def loss_diff_4(b, i_1, i_2, dic, matrices, attn_1, bound_2, n=None):
587706
return ld_4
588707

589708

709+
def loss_diff_3_4(b, i_1, i_2, dic, matrices, attn_1, bound_2, n=None):
710+
711+
(term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8) = matrices
712+
713+
if n == b:
714+
return 0
715+
716+
if n is None:
717+
718+
n = torch.arange(d_voc)[torch.arange(d_voc) != b]
719+
720+
for k in range(i_2 + 1):
721+
if k != 0 and k != 1:
722+
c = (
723+
term_8[k - 1, dic[k - 1]][..., n]
724+
- term_8[k - 1, dic[k - 1], b].unsqueeze(dim=-1)
725+
).max()
726+
d = c * attn_1[dic[k], k - 1].min()
727+
for i in range(k - 1):
728+
c = torch.max(
729+
c,
730+
(
731+
term_8[i, dic[i]][..., n]
732+
- term_8[i, dic[i], b].unsqueeze(dim=-1)
733+
).max(),
734+
)
735+
c = torch.max(
736+
c,
737+
(
738+
term_8[k, dic[k]][..., n]
739+
- term_8[k, dic[k], b].unsqueeze(dim=-1)
740+
).max(),
741+
)
742+
d += (1 - attn_1[dic[k], k - 1].min()) * c
743+
744+
if k == 0:
745+
d = (
746+
term_8[0, dic[0]][..., n] - term_8[0, dic[0], b].unsqueeze(dim=-1)
747+
).max()
748+
749+
if k == 1:
750+
c = (
751+
term_8[0, dic[0]][..., n] - term_8[0, dic[0], b].unsqueeze(dim=-1)
752+
).max()
753+
d = c * attn_1[dic[k], k - 1].min()
754+
c = torch.max(
755+
c,
756+
(
757+
term_8[1, dic[1]][..., n]
758+
- term_8[1, dic[1], b].unsqueeze(dim=-1)
759+
).max(),
760+
)
761+
d += (1 - attn_1[dic[k], k - 1].min()) * c
762+
763+
d = (
764+
d
765+
+ (
766+
term_7[k, dic[k]][..., n] - term_7[k, dic[k], b].unsqueeze(dim=-1)
767+
).max()
768+
)
769+
770+
if k == 0:
771+
f = d
772+
if k != 0:
773+
f = torch.max(f, d)
774+
if k == i_1:
775+
g = d
776+
ld_4 = g * (bound_2[dic[i_2], i_2, i_1].min())
777+
ld_4 += (1 - bound_2[dic[i_2], i_2, i_1].min()) * f
778+
return ld_4
779+
780+
for k in range(i_2 + 1):
781+
if k != 0 and k != 1:
782+
c = (term_8[k - 1, dic[k - 1], n] - term_8[k - 1, dic[k - 1], b]).max()
783+
d = c * attn_1[dic[k], k - 1].min()
784+
for i in range(k - 1):
785+
c = torch.max(
786+
c,
787+
(term_8[i, dic[i], n] - term_8[i, dic[i], b]).max(),
788+
)
789+
c = torch.max(
790+
c,
791+
(term_8[k, dic[k], n] - term_8[k, dic[k], b]).max(),
792+
)
793+
d += (1 - attn_1[dic[k], k - 1].min()) * c
794+
795+
if k == 0:
796+
d = (term_8[0, dic[0], n] - term_8[0, dic[0], b]).max()
797+
798+
if k == 1:
799+
c = (term_8[0, dic[0], n] - term_8[0, dic[0], b]).max()
800+
d = c * attn_1[dic[k], k - 1].min()
801+
c = torch.max(
802+
c,
803+
(term_8[1, dic[1], n] - term_8[1, dic[1], b]).max(),
804+
)
805+
d += (1 - attn_1[dic[k], k - 1].min()) * c
806+
807+
d = d + (term_7[k, dic[k], n] - term_7[k, dic[k], b]).max()
808+
809+
if k == 0:
810+
f = d
811+
if k != 0:
812+
f = torch.max(f, d)
813+
if k == i_1:
814+
g = d
815+
ld_4 = g * (bound_2[dic[i_2], i_2, i_1].min())
816+
ld_4 += (1 - bound_2[dic[i_2], i_2, i_1].min()) * f
817+
return ld_4
818+
819+
590820
def total_bound(b, i_1, i_2, dic, matrices, attn_1, bound_2, n=None):
591821
return (
592822
loss_diff_1(b, i_1, i_2, dic, matrices, attn_1, bound_2, n)
593823
+ loss_diff_2(b, i_1, i_2, dic, matrices, attn_1, bound_2, n)
594-
+ loss_diff_3(b, i_1, i_2, dic, matrices, attn_1, bound_2, n)
595-
+ loss_diff_4(b, i_1, i_2, dic, matrices, attn_1, bound_2, n)
824+
+ loss_diff_3_4(b, i_1, i_2, dic, matrices, attn_1, bound_2, n)
825+
# + loss_diff_3(b, i_1, i_2, dic, matrices, attn_1, bound_2, n)
826+
# + loss_diff_4(b, i_1, i_2, dic, matrices, attn_1, bound_2, n)
596827
)
597828

598829

0 commit comments

Comments
 (0)