
Commit aa906de

stuff
1 parent 6177dd4 commit aa906de

1 file changed: +17 -29 lines changed


gbmi/exp_indhead/finetunebound.py: 17 additions & 29 deletions
@@ -628,72 +628,54 @@ def total_bound(b, i_1, i_2, dic, n=None):

 counter = 0
 # %%
-loss = loss_bound(model_1, 0, 5)
+loss = loss_bound(model_1, 0)
 for i in range(100):
     print(loss)
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
-    loss = loss_bound(model_1, 0, 5)
+    loss = loss_bound(model_1, 0)
     counter += 1
     print(counter)


 # %%
-loss = 1 - loss_bound(model_1, 1, 5).min()
+loss = 1 - loss_bound(model_1, 1).min()
 while loss > 0.02:
     print(1 - loss)
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
-    loss = 1 - loss_bound(model_1, 1, 5).min()
+    loss = 1 - loss_bound(model_1, 1).min()
     counter += 1
     print(counter)
 # %%

-a = loss_bound(model_1, 2, 6)[2]
+a = loss_bound(model_1, 2)[2]
 loss = 1 - a[~torch.isnan(a)].mean()
 while loss > 0.1:
     print(1 - loss)
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
-    a = loss_bound(model_1, 2, 6)[2]
+    a = loss_bound(model_1, 2)[2]
     loss = 1 - a[~torch.isnan(a)].mean()
     counter += 1
     print(counter)
 # %%
-a = loss_bound(model_1, 2, 8)[2]
+a = loss_bound(model_1, 2)[2]
 loss = 1 - a[~torch.isnan(a)].min()
 while loss > 0.5:
     print(1 - loss)
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
-    a = loss_bound(model_1, 2, 8)[2]
+    a = loss_bound(model_1, 2)[2]
     loss = 1 - a[~torch.isnan(a)].min()
     counter += 1
     print(counter)


-# %%
-counter = 0
-optimiser = torch.optim.AdamW(
-    model_1.parameters(), lr=5e-1, betas=(0.9, 0.999), weight_decay=1.0
-)
-
-a = loss_bound(model_1, 3, 8)[4]
-loss = 1 - a[a != 0].mean()
-for i in range(1):
-    print(a[a != 0].mean())
-    loss.backward()
-    optimiser.step()
-    optimiser.zero_grad()
-    a = loss_bound(model_1, 3, 8)[4][5]
-    loss = 1 - a[a != 0].mean()
-    counter += 1
-    print(counter)
-
 # %%
 valid = (
     ein.array(
@@ -704,11 +686,13 @@ def total_bound(b, i_1, i_2, dic, n=None):
     .to(device)
 )
 optimiser = torch.optim.AdamW(
-    model_1.parameters(), lr=0.5, betas=(0.9, 0.999), weight_decay=0
+    model_1.parameters(), lr=1, betas=(0.9, 0.999), weight_decay=0
 )
 # %%
+optimiser = torch.optim.SGD(model_1.parameters(), lr=100)
+# %%
 a = loss_bound(model_1, 3)[4]
-loss = 1 - a[valid].min()
+loss = 1 - a[valid].mean()
 print(a[valid].min())
 print(a[valid].mean())
 print(a[valid].max())
@@ -717,9 +701,13 @@ def total_bound(b, i_1, i_2, dic, n=None):

     loss.backward()
     optimiser.step()
+    for param in model_1.parameters():
+        if param.requires_grad:
+            print(param.grad.norm())  # Check gradient norms
+
     optimiser.zero_grad()
     a = loss_bound(model_1, 3)[4]
-    loss = 1 - a[valid].min()
+    loss = 1 - a[valid].mean()
     print(a[valid].min())
     print(a[valid].mean())
     print(a[valid].max())
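
Every # %% cell in this diff follows the same fine-tuning pattern: evaluate a differentiable correctness bound, minimise 1 minus that bound (or the min/mean of its valid entries), and stop once the bound clears a threshold. Below is a minimal runnable sketch of that pattern; toy_bound is a hypothetical stand-in for this file's loss_bound, which takes the model plus a phase index and whose return shape varies by phase.

import torch

# Hypothetical stand-in for this file's loss_bound: any differentiable
# scalar that approaches 1 as the parameters improve.
param = torch.nn.Parameter(torch.tensor(0.0))


def toy_bound():
    # Peaks at param == 3, where the bound reaches 1.
    return 1 - (param - 3.0) ** 2 / 10


optimiser = torch.optim.SGD([param], lr=1)

counter = 0
loss = 1 - toy_bound()
while loss > 0.02:  # stopping rule taken from the diff's while-loops
    print(1 - loss)  # current value of the bound
    loss.backward()
    optimiser.step()
    if param.grad is not None:
        print(param.grad.norm())  # gradient-norm check, as added by the diff
    optimiser.zero_grad()
    loss = 1 - toy_bound()  # re-evaluate the bound after the update
    counter += 1
    print(counter)

Note the switch from .min() to .mean() in the phase-3 objective above: the mean propagates gradient through every valid entry rather than only the single worst one, which presumably gives a smoother training signal than the worst-case target.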
