# %%
-def loss_bound(model, s, w):
+def loss_bound(model, s):
    W_pos = model.W_pos
    W_E = model.W_E
@@ -53,7 +53,7 @@ def loss_bound(model, s, w):
    e_p = W_E.unsqueeze(dim=0) + W_pos.unsqueeze(dim=1)

-    everything = (
+    term_0 = (
        einops.einsum(
            e_p,
            W_Q_0,
@@ -70,16 +70,16 @@ def loss_bound(model, s, w):
    for p in range(2, n_ctx):  #
        tmp = torch.zeros((p, d_voc))
        for t_q in range(d_voc):
-            tmp[-1, :] = everything[p - 1, t_q, p - 1, t_q]
+            tmp[-1, :] = term_0[p - 1, t_q, p - 1, t_q]

            for t_k in range(d_voc):
-                tmp[-2, :] = everything[p - 1, t_q, p - 2, t_k]
-                tmp[:-2, :] = everything[p - 1, t_q, : p - 2, :]
+                tmp[-2, :] = term_0[p - 1, t_q, p - 2, t_k]
+                tmp[:-2, :] = term_0[p - 1, t_q, : p - 2, :]
                tmp_sm = tmp.softmax(dim=0)
                table[t_q, t_k, p - 2, :] = tmp_sm[-2, :]
    # Table represents the post-softmax attention paid to t_k, if the final entry is spammed everywhere, and t_q is used as the first entry, at the pth position.

-    # everything looks like EQKE, table looks like you're indexing by query, key, position (of the key?), and the other token in the sequence.
+    # term_0 looks like EQKE, table looks like you're indexing by query, key, position (of the key?), and the other token in the sequence.
    # Then you're computing the softmax of d_voc - 2 copies of the other token, one copy of t_k at p-2, and the query at p-1.
    # Then you store the post-softmax attention paid to t_k.
    #
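
Editor's note: the loop above fills `table` one (p, t_q, t_k) slice at a time. A minimal, self-contained sketch of that step, with `scores` as a hypothetical stand-in for `term_0` (the EQKE-style score tensor) and toy sizes:

# Editor's sketch, not part of the PR: toy illustration of how one slice of
# `table` is assembled. `scores` is a hypothetical stand-in for term_0.
import torch

d_voc, n_ctx = 5, 6
scores = torch.randn(n_ctx, d_voc, n_ctx, d_voc)  # [q_pos, q_tok, k_pos, k_tok]

p, t_q, t_k = 4, 2, 3
tmp = torch.zeros((p, d_voc))
tmp[-1, :] = scores[p - 1, t_q, p - 1, t_q]    # the query token itself at position p-1
tmp[-2, :] = scores[p - 1, t_q, p - 2, t_k]    # the key token t_k at position p-2
tmp[:-2, :] = scores[p - 1, t_q, : p - 2, :]   # every earlier position, one column per "other" token
attn_on_t_k = tmp.softmax(dim=0)[-2, :]        # post-softmax attention on t_k, per choice of other token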
@@ -177,6 +177,9 @@ def loss_bound(model, s, w):
        "q_pos q_val k, k l, l m, m n, n p, p q -> q_pos q_val q",
    )

+    if s == -1:
+        return (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8)
+
    if s == 0:
        reduced_3 = einops.einsum(
            term_3, "q_pos q_val k_pos k_val -> q_pos q_val k_pos"
@@ -421,17 +424,27 @@ def least_attention(a, i_1, i_2, j, dic):
    if s == 2:
        return (attn_1, bound, bound_2)

-    def loss_diff_1(b, i_1, i_2, dic):
+    def loss_diff_1(b, i_1, i_2, dic, n=None):
+
+        if n == b:
+            return 0

-        n = torch.arange(d_voc)[torch.arange(d_voc) != b]
+        if n is None:
+
+            n = torch.arange(d_voc)[torch.arange(d_voc) != b]

        return (
-            term_5[i_2, dic[i_2]][..., n] - term_5[i_2, :, b].unsqueeze(dim=-1)
+            term_5[i_2, dic[i_2]][..., n] - term_5[i_2, dic[i_2], b].unsqueeze(dim=-1)
        ).max()

-    def loss_diff_2(b, i_1, i_2, dic):
+    def loss_diff_2(b, i_1, i_2, dic, n=None):
+
+        if n == b:
+            return 0
+
+        if n is None:

-        n = torch.arange(d_voc)[torch.arange(d_voc) != b]
+            n = torch.arange(d_voc)[torch.arange(d_voc) != b]

        c = (term_6[0, dic[0]][..., n] - term_6[0, dic[0], b].unsqueeze(dim=-1)).max()
@@ -460,8 +473,12 @@ def loss_diff_2(b, i_1, i_2, dic):
        )
        return ld_2

-    def loss_diff_3(b, i_1, i_2, dic):
-        n = torch.arange(d_voc)[torch.arange(d_voc) != b]
+    def loss_diff_3(b, i_1, i_2, dic, n=None):
+        if n == b:
+            return 0
+
+        if n is None:
+            n = torch.arange(d_voc)[torch.arange(d_voc) != b]
        c = (term_7[0, dic[0]][..., n] - term_7[0, dic[0], b].unsqueeze(dim=-1)).max()
        for i in range(i_1):
            c = torch.max(
@@ -488,9 +505,14 @@ def loss_diff_3(b, i_1, i_2, dic):
        )
        return ld_3

-    def loss_diff_4(b, i_1, i_2, dic):
+    def loss_diff_4(b, i_1, i_2, dic, n=None):

-        n = torch.arange(d_voc)[torch.arange(d_voc) != b]
+        if n == b:
+            return 0
+
+        if n is None:
+
+            n = torch.arange(d_voc)[torch.arange(d_voc) != b]

        for k in range(i_2 + 1):
            if k != 0 and k != 1:
@@ -546,32 +568,57 @@ def loss_diff_4(b, i_1, i_2, dic):
        )
        return ld_4

-    def total_bound(b, i_1, i_2, dic):
+    def total_bound(b, i_1, i_2, dic, n=None):

        return (
-            loss_diff_1(b, i_1, i_2, dic)
-            + loss_diff_2(b, i_1, i_2, dic)
-            + loss_diff_3(b, i_1, i_2, dic)
-            + loss_diff_4(b, i_1, i_2, dic)
+            loss_diff_1(b, i_1, i_2, dic, n)
+            + loss_diff_2(b, i_1, i_2, dic, n)
+            + loss_diff_3(b, i_1, i_2, dic, n)
+            + loss_diff_4(b, i_1, i_2, dic, n)
        )

-    out = torch.zeros((d_voc, n_ctx, n_ctx)) + torch.inf
+    if s == 3:
+
+        out = torch.zeros((d_voc, n_ctx, n_ctx)) + torch.inf
+        # b i_2 i_1
+
+        for b in range(e_p.shape[1]):
+
+            for i_2 in range(e_p.shape[0] - 1):
+                for i_1 in range(1, i_2):
+
+                    if (i_1 < i_2) & (i_1 > 0):
+                        dic = {i_1: b}
+                        for i in range(8):
+                            dic.setdefault(i, torch.arange(26))
+
+                        out[b, i_2, i_1] = total_bound(b, i_1, i_2, dic)
+
+        out_2 = 1 / (1 + ((d_voc - 1) * torch.exp(out)))
+
+        return (attn_1, bound, bound_2, out, out_2)
+
+    out = torch.zeros((d_voc, n_ctx, n_ctx, d_voc)) + torch.inf
    # b i_2 i_1

    for b in range(e_p.shape[1]):
+        for n in range(e_p.shape[1]):
+            for i_2 in range(e_p.shape[0] - 1):
+                for i_1 in range(1, i_2):
-        for i_2 in range(e_p.shape[0] - 1):
-            for i_1 in range(1, i_2):
+                    if (i_1 < i_2) & (i_1 > 0):
+                        dic = {i_1: b}
+                        for i in range(8):
+                            dic.setdefault(i, torch.arange(26))
-                if (i_1 < i_2) & (i_1 > 0):
-                    dic = {i_1: b}
-                    for i in range(8):
-                        dic.setdefault(i, torch.arange(26))
+                        out[b, i_2, i_1, n] = total_bound(b, i_1, i_2, dic, n)
-                    out[b, i_2, i_1] = total_bound(b, i_1, i_2, dic)

+    out_2 = einops.einsum(out.softmax(dim=-1), "b i_2 i_1 b -> b i_2 i_1")
-    out_2 = 1 / (1 + ((d_voc - 1) * torch.exp(out)))
+    out_3 = einops.einsum(
+        out - out.max(dim=-1).values.unsqueeze(dim=-1), "b i_2 i_1 b -> b i_2 i_1"
+    )

-    return (attn_1, bound, bound_2, out, out_2)
+    return (attn_1, bound, bound_2, out, out_2, out_3)
# %%
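
Editor's note on the `if s == 3:` branch added above: `out[b, i_2, i_1]` bounds how far any incorrect logit can sit above the logit of the correct token `b`, so `out_2 = 1 / (1 + (d_voc - 1) * torch.exp(out))` lower-bounds the softmax probability of `b`. A small worst-case check (sizes and names here are illustrative):

# Editor's sketch, not part of the PR: the bound
#   p(correct) >= 1 / (1 + (d_voc - 1) * exp(gap))
# is tight when every incorrect logit exceeds the correct one by exactly `gap`.
import torch

d_voc = 26
gap = 0.5                                     # stand-in for out[b, i_2, i_1]
logits = torch.full((d_voc,), gap)            # worst case: every wrong logit at the bound
logits[0] = 0.0                               # index 0 plays the role of the correct token b
p_correct = logits.softmax(dim=0)[0]
bound = 1 / (1 + (d_voc - 1) * torch.exp(torch.tensor(gap)))
assert torch.isclose(p_correct, bound)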
@@ -647,22 +694,40 @@ def total_bound(b, i_1, i_2, dic):
    counter += 1
    print(counter)

-
+# %%
+valid = (
+    ein.array(
+        lambda i, j, k: where(k > 0, where(j > k, where(j < 7, 1, 0), 0), 0),
+        sizes=[d_voc, n_ctx, n_ctx],
+    )
+    .bool()
+    .to(device)
+)

optimiser = torch.optim.AdamW(
-    model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
+    model_1.parameters(), lr=0.5, betas=(0.9, 0.999), weight_decay=0
)
+# %%
+a = loss_bound(model_1, 3)[4]
+loss = 1 - a[valid].min()
+print(a[valid].min())
+print(a[valid].mean())
+print(a[valid].max())
+for i in range(1):
+    print(i + 1)

-a = loss_bound(model_1, 3, 8)[4]
-loss = 1 - a[a != 0].mean()
-for i in range(30):
-    print(a[a != 0].mean())
    loss.backward()
    optimiser.step()
    optimiser.zero_grad()
-    a = loss_bound(model_1, 3, 8)[4][5]
-    loss = 1 - a[a != 0].mean()
-    counter += 1
-    print(counter)
+    a = loss_bound(model_1, 3)[4]
+    loss = 1 - a[valid].min()
+    print(a[valid].min())
+    print(a[valid].mean())
+    print(a[valid].max())
+    if i % 10 == 1:
+        r = loss_bound(model_1, 4)[5]
+        print(r[valid].min())
+        print(r[valid].mean())
+        print(r[valid].max())
# %%
'''
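
Editor's note on the `valid` mask from the last hunk: it selects exactly the (b, i_2, i_1) cells the bound covers, i.e. 0 < i_1 < i_2 < 7, and broadcasts that over the token axis. Assuming `ein.array(lambda i, j, k: ..., sizes=[d_voc, n_ctx, n_ctx])` evaluates the lambda elementwise over index grids (with `where` acting as an elementwise select), a plain-torch equivalent would be:

# Editor's sketch, not part of the PR: a plain-torch equivalent of the `valid`
# mask, assuming the index mapping i -> b, j -> i_2, k -> i_1 and the
# d_voc = 26, n_ctx = 8 that the surrounding code suggests.
import torch

d_voc, n_ctx = 26, 8
j = torch.arange(n_ctx).view(1, n_ctx, 1)   # i_2 index
k = torch.arange(n_ctx).view(1, 1, n_ctx)   # i_1 index
valid_torch = ((k > 0) & (j > k) & (j < 7)).expand(d_voc, n_ctx, n_ctx)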