@@ -623,6 +623,8 @@ def total_bound(b, i_1, i_2, dic, n=None):
     return (attn_1, bound, bound_2, out, out_2, out_3)
 
 
+# %%
+runtime_model_1, model_1 = train_or_load_model(ABCAB8_1H, force="load")
 # %%
 optimiser = torch.optim.AdamW(
     model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
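(For orientation: `train_or_load_model(ABCAB8_1H, force="load")` presumably loads the cached ABCAB8_1H run rather than retraining, so the cells below start from trained weights and tune `model_1` against the bound.)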
@@ -656,10 +658,25 @@ def total_bound(b, i_1, i_2, dic, n=None):
 optimiser = torch.optim.AdamW(
     model_1.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1.0
 )
+# %%
+a = loss_bound(model_1, 2)[2]
+loss = 1 - a[~torch.isnan(a)].mean()
+while loss > 0.5:
+    print(a[~torch.isnan(a)].min())
+    print(a[~torch.isnan(a)].mean())
+    print(a[~torch.isnan(a)].max())
+    loss.backward()
+    optimiser.step()
+    optimiser.zero_grad()
+    a = loss_bound(model_1, 2)[2]
+    loss = 1 - a[~torch.isnan(a)].mean()
+    counter += 1
+    print(counter)
 
+# %%
 a = loss_bound(model_1, 2)[2]
 loss = 1 - a[~torch.isnan(a)].min()
-while loss > 0.1:
+while loss > 0.5:
     print(a[~torch.isnan(a)].min())
     print(a[~torch.isnan(a)].mean())
     print(a[~torch.isnan(a)].max())
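The added cell is a straightforward bound-ascent loop: `loss_bound(model_1, 2)[2]` is treated as a differentiable tensor of per-case bounds, NaN entries (invalid cases) are masked out, and AdamW pushes the mean of the rest toward 1 until it exceeds 0.5. Note that `counter` is not initialised in the added lines, so it presumably carries over from an earlier cell. A minimal self-contained sketch of the same pattern, with a toy sigmoid bound standing in for `loss_bound` (the toy function, its initialisation, and the learning rate are illustrative assumptions, not the repository's code):

import torch

# Toy stand-in for loss_bound(model, 2)[2]: per-case bounds in (0, 1),
# with a NaN marking an invalid case, differentiable in `params`.
params = torch.nn.Parameter(torch.full((8,), -2.0))

def toy_bound(p: torch.Tensor) -> torch.Tensor:
    b = torch.sigmoid(p).clone()
    b[0] = float("nan")  # one invalid entry, masked out below
    return b

optimiser = torch.optim.AdamW([params], lr=1e-2, betas=(0.9, 0.999), weight_decay=1.0)

counter = 0  # the diff assumes this was initialised in an earlier cell
a = toy_bound(params)
loss = 1 - a[~torch.isnan(a)].mean()  # maximise the mean valid bound
while loss > 0.5:
    loss.backward()
    optimiser.step()
    optimiser.zero_grad()
    a = toy_bound(params)
    loss = 1 - a[~torch.isnan(a)].mean()
    counter += 1
print(counter)  # number of ascent steps taken

The real loop differs only in optimising `model_1`'s parameters and printing the min/mean/max of the valid bounds at each step.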
@@ -670,8 +687,6 @@ def total_bound(b, i_1, i_2, dic, n=None):
     loss = 1 - a[~torch.isnan(a)].min()
     counter += 1
     print(counter)
-
-
 # %%
 valid = (
     ein.array(
@@ -681,6 +696,7 @@ def total_bound(b, i_1, i_2, dic, n=None):
     .bool()
     .to(device)
 )
+# %%
 optimiser = torch.optim.AdamW(
     model_1.parameters(), lr=1, betas=(0.9, 0.999), weight_decay=0
 )
@@ -714,6 +730,9 @@ def total_bound(b, i_1, i_2, dic, n=None):
 print(r[valid].max())
 
 # %%
+ModelMatrixLoggingOptions.all(
+    use_subplots=True, add_mean={-1: None, 0: "tok_to_pos", 1: None}
+).plot_matrices_from_model(model)
 '''
 def least_attention_2(a, b, i_1, i_2, j):
 
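(The final added cell plots the model's weight matrices, one subplot per matrix, with the given per-dimension mean options; note it is applied to `model` rather than `model_1`, so it presumably relies on a `model` binding from an earlier cell.)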