@@ -696,7 +696,7 @@ def good_loss_bound(model):
696
696
a = second_layer_attention (matrices , attn_1 )
697
697
loss = 1 - (torch .nansum (a * weights_1 ) / (weights_1 .sum ()))
698
698
print (a [~ torch .isnan (a )].min ())
699
- print (a [ ~ torch .isnan ( a )]. mean ( ))
699
+ print (torch .nansum ( a * weights_1 ) / ( weights_1 . sum () ))
700
700
print (a [~ torch .isnan (a )].max ())
701
701
while loss > 0.5 :
702
702
# torch.autograd.set_detect_anomaly(True)
@@ -724,38 +724,37 @@ def good_loss_bound(model):
724
724
.bool ()
725
725
.to (device )
726
726
)
727
+ weights_2 = ein .array (
728
+ lambda i , j , k : where (k > 0 , where (j > k , where (j < 7 , 1 , 0 ), 0 ), 0 )
729
+ * ((d_voc - 1 ) * ((d_voc - 1 ) ** (j - 2 ))),
730
+ sizes = [d_voc , n_ctx , n_ctx ],
731
+ ).to (device )
727
732
# %%
728
733
optimiser = torch .optim .AdamW (
729
- model_1 .parameters (), lr = 1 , betas = (0.9 , 0.999 ), weight_decay = 0
734
+ model_1 .parameters (), lr = 1e-2 , betas = (0.9 , 0.999 ), weight_decay = 1. 0
730
735
)
731
736
# %%
732
- optimiser = torch .optim .SGD (model_1 .parameters (), lr = 100 )
737
+ # optimiser = torch.optim.SGD(model_1.parameters(), lr=100)
733
738
# %%
734
- a = loss_bound (model_1 , 3 )[4 ]
735
- loss = 1 - a [valid ].mean ()
736
- print (a [valid ].min ())
737
- print (a [valid ].mean ())
738
- print (a [valid ].max ())
739
- for i in range (1 ):
740
- print (i + 1 )
741
-
739
+ bound = loss_bound (model_1 )[1 ]
740
+ loss = 1 - (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
741
+ print (bound [valid ].min ())
742
+ print (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
743
+ print (bound [valid ].max ())
744
+ while loss > 0.5 :
745
+ # torch.autograd.set_detect_anomaly(True)
742
746
loss .backward ()
747
+ # torch.nn.utils.clip_grad_norm_(model_1.parameters(), max_norm=1.0)
743
748
optimiser .step ()
744
- for param in model_1 .parameters ():
745
- if param .requires_grad :
746
- print (param .grad .norm ()) # Check gradient norms
747
-
748
749
optimiser .zero_grad ()
749
- a = loss_bound (model_1 , 3 )[4 ]
750
- loss = 1 - a [valid ].mean ()
751
- print (a [valid ].min ())
752
- print (a [valid ].mean ())
753
- print (a [valid ].max ())
754
- if i % 10 == 1 :
755
- r = loss_bound (model_1 , 4 )[5 ]
756
- print (r [valid ].min ())
757
- print (r [valid ].mean ())
758
- print (r [valid ].max ())
750
+ bound = loss_bound (model_1 )[1 ]
751
+ loss = 1 - (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
752
+ counter += 1
753
+ print (counter )
754
+ print (bound [valid ].min ())
755
+ print (torch .nansum (bound * weights_2 ) / (weights_2 .sum ()))
756
+ print (bound [valid ].max ())
757
+
759
758
760
759
# %%
761
760
ModelMatrixLoggingOptions .all (
0 commit comments