Skip to content

Commit aad8f09

Browse files
saving
1 parent 4a2c7ab commit aad8f09

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

gbmi/exp_indhead/finetunebound.py

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -623,6 +623,8 @@ def total_bound(b, i_1, i_2, dic, n=None):
623623
return (attn_1, bound, bound_2, out, out_2, out_3)
624624

625625

626+
# %%
627+
runtime_model_1, model_1 = train_or_load_model(ABCAB8_1H, force="load")
626628
# %%
627629
optimiser = torch.optim.AdamW(
628630
model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
@@ -656,10 +658,25 @@ def total_bound(b, i_1, i_2, dic, n=None):
656658
optimiser = torch.optim.AdamW(
657659
model_1.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1.0
658660
)
661+
# %%
662+
a = loss_bound(model_1, 2)[2]
663+
loss = 1 - a[~torch.isnan(a)].mean()
664+
while loss > 0.5:
665+
print(a[~torch.isnan(a)].min())
666+
print(a[~torch.isnan(a)].mean())
667+
print(a[~torch.isnan(a)].max())
668+
loss.backward()
669+
optimiser.step()
670+
optimiser.zero_grad()
671+
a = loss_bound(model_1, 2)[2]
672+
loss = 1 - a[~torch.isnan(a)].mean()
673+
counter += 1
674+
print(counter)
659675

676+
# %%
660677
a = loss_bound(model_1, 2)[2]
661678
loss = 1 - a[~torch.isnan(a)].min()
662-
while loss > 0.1:
679+
while loss > 0.5:
663680
print(a[~torch.isnan(a)].min())
664681
print(a[~torch.isnan(a)].mean())
665682
print(a[~torch.isnan(a)].max())
@@ -670,8 +687,6 @@ def total_bound(b, i_1, i_2, dic, n=None):
670687
loss = 1 - a[~torch.isnan(a)].min()
671688
counter += 1
672689
print(counter)
673-
674-
675690
# %%
676691
valid = (
677692
ein.array(
@@ -681,6 +696,7 @@ def total_bound(b, i_1, i_2, dic, n=None):
681696
.bool()
682697
.to(device)
683698
)
699+
# %%
684700
optimiser = torch.optim.AdamW(
685701
model_1.parameters(), lr=1, betas=(0.9, 0.999), weight_decay=0
686702
)
@@ -714,6 +730,9 @@ def total_bound(b, i_1, i_2, dic, n=None):
714730
print(r[valid].max())
715731

716732
# %%
733+
ModelMatrixLoggingOptions.all(
734+
use_subplots=True, add_mean={-1: None, 0: "tok_to_pos", 1: None}
735+
).plot_matrices_from_model(model)
717736
'''
718737
def least_attention_2(a, b, i_1, i_2, j):
719738

0 commit comments

Comments
 (0)