@@ -205,15 +205,15 @@ def diff_3(a, i_1, i_2, j, dic, matrices, attn_1):
205
205
for i in range (0 , j - 1 ):
206
206
c = torch .max (c , term_3 [i_2 , dic [i_2 ], i , dic [i ]].max ())
207
207
c = torch .max (c , term_3 [i_2 , dic [i_2 ], j , dic [j ]].max ())
208
- t_3 = t_3 + (1 - attn_1 [dic [j ], j - 1 ].min ()) * c
208
+ t_3 += (1 - attn_1 [dic [j ], j - 1 ].min ()) * c
209
209
210
210
# print(t_3)
211
211
if j == 1 :
212
212
c = term_3 [i_2 , a , j - 1 , dic [j - 1 ]].max ()
213
213
# new=c.clone()
214
214
t_3 = c * attn_1 [dic [j ], j - 1 ].min ()
215
215
c = torch .max (c , term_3 [i_2 , a , j , dic [j ]].max ())
216
- t_3 = t_3 + (1 - attn_1 [dic [j ], j - 1 ].min ()) * c
216
+ t_3 += (1 - attn_1 [dic [j ], j - 1 ].min ()) * c
217
217
218
218
if j == 0 :
219
219
@@ -349,10 +349,10 @@ def second_layer_attention(matrices, attn_1):
349
349
- torch .inf
350
350
)
351
351
352
- for a in range (term_0 .shape [1 ]):
352
+ for a in range (0 , term_0 .shape [1 ]):
353
353
354
- for i_2 in range (2 , term_0 .shape [0 ] - 1 ):
355
- for i_1 in range (1 , i_2 ):
354
+ for i_2 in range (3 , term_0 .shape [0 ] - 1 ):
355
+ for i_1 in range (2 , i_2 ):
356
356
for j in range (i_2 + 1 ):
357
357
if (i_1 < i_2 ) & (i_1 > 0 ) & (i_2 + 1 > j ):
358
358
dic = {
@@ -607,8 +607,8 @@ def loss_bound(model):
607
607
608
608
for b in range (term_0 .shape [1 ]):
609
609
610
- for i_2 in range (2 , term_0 .shape [0 ] - 1 ):
611
- for i_1 in range (1 , i_2 ):
610
+ for i_2 in range (3 , term_0 .shape [0 ] - 1 ):
611
+ for i_1 in range (2 , i_2 ):
612
612
613
613
if (i_1 < i_2 ) & (i_1 > 0 ):
614
614
dic = {i_1 : b }
@@ -621,7 +621,7 @@ def loss_bound(model):
621
621
622
622
out_2 = 1 / (1 + ((d_voc - 1 ) * torch .exp (out )))
623
623
624
- return (out , out_2 )
624
+ return (attn_1 , bound_2 , out , out_2 )
625
625
626
626
627
627
def good_loss_bound (model ):
@@ -630,13 +630,13 @@ def good_loss_bound(model):
630
630
attn_1 = first_layer_attention (matrices )
631
631
bound_2 = second_layer_attention (matrices , attn_1 )
632
632
633
- out = torch .zeros ((d_voc , n_ctx , n_ctx , d_voc )) + torch . inf
633
+ out = torch .zeros ((d_voc , n_ctx , n_ctx , d_voc ))
634
634
# b i_2 i_1
635
635
636
636
for b in range (term_0 .shape [1 ]):
637
637
for n in range (term_0 .shape [1 ]):
638
- for i_2 in range (term_0 .shape [0 ] - 1 ):
639
- for i_1 in range (1 , i_2 ):
638
+ for i_2 in range (3 , term_0 .shape [0 ] - 1 ):
639
+ for i_1 in range (2 , i_2 ):
640
640
641
641
if (i_1 < i_2 ) & (i_1 > 0 ):
642
642
dic = {i_1 : b }
@@ -687,8 +687,8 @@ def good_loss_bound(model):
687
687
# %%
688
688
weights_1 = torch .zeros ((d_voc , n_ctx , n_ctx ))
689
689
for a in range (d_voc ):
690
- for i_2 in range (2 , n_ctx - 1 ):
691
- for i_1 in range (1 , i_2 ):
690
+ for i_2 in range (3 , n_ctx - 1 ):
691
+ for i_1 in range (2 , i_2 ):
692
692
weights_1 [a , i_2 , i_1 ] = (d_voc - 1 ) ** (i_2 - 1 )
693
693
# %%
694
694
matrices = terms (model_1 )
@@ -718,30 +718,30 @@ def good_loss_bound(model):
718
718
# %%
719
719
valid = (
720
720
ein .array (
721
- lambda i , j , k : where (k > 0 , where (j > k , where (j < 7 , 1 , 0 ), 0 ), 0 ),
721
+ lambda i , j , k : where (k > 1 , where (j > k , where (j < 7 , 1 , 0 ), 0 ), 0 ),
722
722
sizes = [d_voc , n_ctx , n_ctx ],
723
723
)
724
724
.bool ()
725
725
.to (device )
726
726
)
727
727
weights_2 = ein .array (
728
- lambda i , j , k : where (k > 0 , where (j > k , where (j < 7 , 1 , 0 ), 0 ), 0 )
728
+ lambda i , j , k : where (k > 1 , where (j > k , where (j < 7 , 1 , 0 ), 0 ), 0 )
729
729
* ((d_voc - 1 ) * ((d_voc - 1 ) ** (j - 2 ))),
730
730
sizes = [d_voc , n_ctx , n_ctx ],
731
731
).to (device )
732
732
# %%
733
733
optimiser = torch .optim .AdamW (
734
- model_1 .parameters (), lr = 1e -2 , betas = (0.9 , 0.999 ), weight_decay = 1.0
734
+ model_1 .parameters (), lr = 5e -2 , betas = (0.9 , 0.999 ), weight_decay = 1.0
735
735
)
736
736
# %%
737
737
# optimiser = torch.optim.SGD(model_1.parameters(), lr=100)
738
738
# %%
739
- bound = loss_bound (model_1 )[1 ]
739
+ bound = loss_bound (model_1 )[3 ]
740
740
loss = 1 - (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
741
741
print (bound [valid ].min ())
742
742
print (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
743
743
print (bound [valid ].max ())
744
- while loss > 0.5 :
744
+ while loss > 0.05 :
745
745
# torch.autograd.set_detect_anomaly(True)
746
746
loss .backward ()
747
747
# torch.nn.utils.clip_grad_norm_(model_1.parameters(), max_norm=1.0)
@@ -756,6 +756,19 @@ def good_loss_bound(model):
756
756
print (bound [valid ].max ())
757
757
758
758
759
+ # %%
760
+ for i in range (10 ):
761
+ print (i )
762
+ a = loss_bound (model_1 )
763
+ loss = 1 - ((torch .nansum (a [3 ] * weights_2 ) / (weights_2 .sum ())))
764
+ print (a [0 ].min ())
765
+ print (torch .nansum (a [3 ] * weights_2 ) / (weights_2 .sum ()))
766
+ print (torch .nansum (a [1 ] * weights_1 ) / (weights_1 .sum ()))
767
+ loss .backward ()
768
+ optimiser .step ()
769
+ optimiser .zero_grad ()
770
+
771
+
759
772
# %%
760
773
ModelMatrixLoggingOptions .all (
761
774
use_subplots = True , add_mean = {- 1 : None , 0 : "tok_to_pos" , 1 : None }
0 commit comments