@@ -628,72 +628,54 @@ def total_bound(b, i_1, i_2, dic, n=None):
 
 counter = 0
 # %%
-loss = loss_bound(model_1, 0, 5)
+loss = loss_bound(model_1, 0)
 for i in range(100):
     print(loss)
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
-    loss = loss_bound(model_1, 0, 5)
+    loss = loss_bound(model_1, 0)
     counter += 1
     print(counter)
 
 
 # %%
-loss = 1 - loss_bound(model_1, 1, 5).min()
+loss = 1 - loss_bound(model_1, 1).min()
 while loss > 0.02:
     print(1 - loss)
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
-    loss = 1 - loss_bound(model_1, 1, 5).min()
+    loss = 1 - loss_bound(model_1, 1).min()
     counter += 1
     print(counter)
 # %%
 
-a = loss_bound(model_1, 2, 6)[2]
+a = loss_bound(model_1, 2)[2]
 loss = 1 - a[~torch.isnan(a)].mean()
 while loss > 0.1:
     print(1 - loss)
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
-    a = loss_bound(model_1, 2, 6)[2]
+    a = loss_bound(model_1, 2)[2]
     loss = 1 - a[~torch.isnan(a)].mean()
     counter += 1
     print(counter)
 # %%
-a = loss_bound(model_1, 2, 8)[2]
+a = loss_bound(model_1, 2)[2]
 loss = 1 - a[~torch.isnan(a)].min()
 while loss > 0.5:
     print(1 - loss)
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
-    a = loss_bound(model_1, 2, 8)[2]
+    a = loss_bound(model_1, 2)[2]
     loss = 1 - a[~torch.isnan(a)].min()
     counter += 1
     print(counter)
 
 
-# %%
-counter = 0
-optimiser = torch.optim.AdamW(
-    model_1.parameters(), lr=5e-1, betas=(0.9, 0.999), weight_decay=1.0
-)
-
-a = loss_bound(model_1, 3, 8)[4]
-loss = 1 - a[a != 0].mean()
-for i in range(1):
-    print(a[a != 0].mean())
-    loss.backward()
-    optimiser.step()
-    optimiser.zero_grad()
-    a = loss_bound(model_1, 3, 8)[4][5]
-    loss = 1 - a[a != 0].mean()
-    counter += 1
-    print(counter)
-
 # %%
 valid = (
     ein.array(
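Every cell in this hunk repeats the same train-until-threshold loop, only varying the bound index, the reduction, and the tolerance. A minimal sketch of that shared pattern, assuming (as the cells above do) that loss_bound(model, stage) returns tensors of bound values in [0, 1]; optimise_until and nanmean_bound are hypothetical helpers, not part of the notebook:

import torch

def optimise_until(model, optimiser, get_bound, tol):
    # Gradient-descend on loss = 1 - bound until loss <= tol.
    counter = 0
    loss = 1 - get_bound(model)
    while loss > tol:
        print(1 - loss)
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()
        loss = 1 - get_bound(model)
        counter += 1
        print(counter)
    return counter

# The third cell above, expressed as one call (NaN-masked mean of output index 2):
def nanmean_bound(m):
    a = loss_bound(m, 2)[2]
    return a[~torch.isnan(a)].mean()

optimise_until(model_1, optimiser, nanmean_bound, tol=0.1)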
@@ -704,11 +686,13 @@ def total_bound(b, i_1, i_2, dic, n=None):
     .to(device)
 )
 optimiser = torch.optim.AdamW(
-    model_1.parameters(), lr=0.5, betas=(0.9, 0.999), weight_decay=0
+    model_1.parameters(), lr=1, betas=(0.9, 0.999), weight_decay=0
 )
 # %%
+optimiser = torch.optim.SGD(model_1.parameters(), lr=100)
+# %%
 a = loss_bound(model_1, 3)[4]
-loss = 1 - a[valid].min()
+loss = 1 - a[valid].mean()
 print(a[valid].min())
 print(a[valid].mean())
 print(a[valid].max())
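Two things change together in this hunk: the reduction moves from a[valid].min() to a[valid].mean(), and the learning rate goes up (plus a throwaway SGD lr=100 cell for experimentation). The reduction change matters for gradient flow: min() backpropagates through only the single worst entry, while mean() spreads gradient over every valid entry, scaling each by 1/N. A self-contained sketch of the difference (stand-in tensors, not the notebook's):

import torch

a = torch.rand(4, 4, requires_grad=True)
valid = torch.rand(4, 4) > 0.5      # boolean mask of entries to optimise

masked = a[valid]                   # 1-D tensor of the valid entries only
(1 - masked.min()).backward(retain_graph=True)
print((a.grad != 0).sum())          # 1: only the worst entry gets gradient

a.grad = None
(1 - masked.mean()).backward()
print((a.grad != 0).sum())          # == valid.sum(): every valid entry gets gradient

The 1/N scaling from mean() plausibly motivates the larger learning rate in the same commit.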
@@ -717,9 +701,13 @@ def total_bound(b, i_1, i_2, dic, n=None):
 
     loss.backward()
     optimiser.step()
+    for param in model_1.parameters():
+        if param.requires_grad:
+            print(param.grad.norm())  # Check gradient norms
+
     optimiser.zero_grad()
     a = loss_bound(model_1, 3)[4]
-    loss = 1 - a[valid].min()
+    loss = 1 - a[valid].mean()
     print(a[valid].min())
     print(a[valid].mean())
     print(a[valid].max())
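One caveat on the gradient-norm inspection this hunk adds: param.grad is None for any parameter that never received a gradient, so param.grad.norm() can raise AttributeError. A guarded variant of the same check (a sketch, not the commit's code):

for name, param in model_1.named_parameters():
    if param.requires_grad and param.grad is not None:
        print(name, param.grad.norm().item())  # per-parameter gradient norm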