Skip to content

Commit 28f5508

Browse files
f
1 parent 78eed30 commit 28f5508

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

gbmi/exp_indhead/finetunebound.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -653,22 +653,16 @@ def total_bound(b, i_1, i_2, dic, n=None):
653653
print(counter)
654654
# %%
655655

656-
a = loss_bound(model_1, 2)[2]
657-
loss = 1 - a[~torch.isnan(a)].mean()
658-
while loss > 0.1:
659-
print(1 - loss)
660-
loss.backward()
661-
optimiser.step()
662-
optimiser.zero_grad()
663-
a = loss_bound(model_1, 2)[2]
664-
loss = 1 - a[~torch.isnan(a)].mean()
665-
counter += 1
666-
print(counter)
667-
# %%
656+
optimiser = torch.optim.AdamW(
657+
model_1.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1.0
658+
)
659+
668660
a = loss_bound(model_1, 2)[2]
669661
loss = 1 - a[~torch.isnan(a)].min()
670-
while loss > 0.5:
671-
print(1 - loss)
662+
while loss > 0.1:
663+
print(a[~torch.isnan(a)].min())
664+
print(a[~torch.isnan(a)].mean())
665+
print(a[~torch.isnan(a)].max())
672666
loss.backward()
673667
optimiser.step()
674668
optimiser.zero_grad()

0 commit comments

Comments
 (0)