
Commit a70a3be ("stuff")
1 parent: 5336fc5

File tree: 1 file changed (+24, -25 lines)


gbmi/exp_indhead/finetune_ind.py

Lines changed: 24 additions & 25 deletions
@@ -696,7 +696,7 @@ def good_loss_bound(model):
 a = second_layer_attention(matrices, attn_1)
 loss = 1 - (torch.nansum(a * weights_1) / (weights_1.sum()))
 print(a[~torch.isnan(a)].min())
-print(a[~torch.isnan(a)].mean())
+print(torch.nansum(a * weights_1) / (weights_1.sum()))
 print(a[~torch.isnan(a)].max())
 while loss > 0.5:
     # torch.autograd.set_detect_anomaly(True)
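
This hunk swaps the logged statistic from the plain nan-mean of `a` to the same weighted quantity the loss is built from, so the printed value now tracks `loss` directly. A minimal standalone sketch (with made-up tensors, not the repo's data) of how the two statistics differ:

import torch

# torch.nansum treats NaN entries of the product as zero, so the weighted
# statistic skips NaNs in the numerator; note that the weights at NaN
# positions still count toward the denominator.
a = torch.tensor([0.2, float("nan"), 0.8])
weights_1 = torch.tensor([1.0, 1.0, 2.0])

print(torch.nansum(a * weights_1) / weights_1.sum())  # weighted: 0.45
print(a[~torch.isnan(a)].mean())                      # unweighted nan-mean: 0.50
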
@@ -724,38 +724,37 @@ def good_loss_bound(model):
     .bool()
     .to(device)
 )
+weights_2 = ein.array(
+    lambda i, j, k: where(k > 0, where(j > k, where(j < 7, 1, 0), 0), 0)
+    * ((d_voc - 1) * ((d_voc - 1) ** (j - 2))),
+    sizes=[d_voc, n_ctx, n_ctx],
+).to(device)
 # %%
 optimiser = torch.optim.AdamW(
-    model_1.parameters(), lr=1, betas=(0.9, 0.999), weight_decay=0
+    model_1.parameters(), lr=1e-2, betas=(0.9, 0.999), weight_decay=1.0
 )
 # %%
-optimiser = torch.optim.SGD(model_1.parameters(), lr=100)
+# optimiser = torch.optim.SGD(model_1.parameters(), lr=100)
 # %%
-a = loss_bound(model_1, 3)[4]
-loss = 1 - a[valid].mean()
-print(a[valid].min())
-print(a[valid].mean())
-print(a[valid].max())
-for i in range(1):
-    print(i + 1)
-
+bound = loss_bound(model_1)[1]
+loss = 1 - (torch.nansum(bound * weights_2) / (weights_2.sum()))
+print(bound[valid].min())
+print(torch.nansum(bound * weights_2) / (weights_2.sum()))
+print(bound[valid].max())
+while loss > 0.5:
+    # torch.autograd.set_detect_anomaly(True)
     loss.backward()
+    # torch.nn.utils.clip_grad_norm_(model_1.parameters(), max_norm=1.0)
     optimiser.step()
-    for param in model_1.parameters():
-        if param.requires_grad:
-            print(param.grad.norm()) # Check gradient norms
-
     optimiser.zero_grad()
-    a = loss_bound(model_1, 3)[4]
-    loss = 1 - a[valid].mean()
-    print(a[valid].min())
-    print(a[valid].mean())
-    print(a[valid].max())
-    if i % 10 == 1:
-        r = loss_bound(model_1, 4)[5]
-        print(r[valid].min())
-        print(r[valid].mean())
-        print(r[valid].max())
+    bound = loss_bound(model_1)[1]
+    loss = 1 - (torch.nansum(bound * weights_2) / (weights_2.sum()))
+    counter += 1
+    print(counter)
+    print(bound[valid].min())
+    print(torch.nansum(bound * weights_2) / (weights_2.sum()))
+    print(bound[valid].max())
+

 # %%
 ModelMatrixLoggingOptions.all(
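
For readers unfamiliar with `ein.array`, the `weights_2` tensor added above can be read as a pointwise construction over indices (i, j, k). Below is a hedged plain-torch sketch of an equivalent construction, assuming `ein.array` evaluates the lambda elementwise over the given sizes (the value does not depend on i; `make_weights_2` is a hypothetical helper, not part of the repo):

import torch

def make_weights_2(d_voc: int, n_ctx: int) -> torch.Tensor:
    # Nonzero only where k > 0 and k < j < 7; since k >= 1 forces j >= 2,
    # the exponent j - 2 is never negative. The weight
    # (d_voc - 1) * (d_voc - 1) ** (j - 2) simplifies to (d_voc - 1) ** (j - 1).
    w = torch.zeros(d_voc, n_ctx, n_ctx)
    for j in range(n_ctx):
        for k in range(1, n_ctx):
            if k < j < 7:
                w[:, j, k] = (d_voc - 1) * (d_voc - 1) ** (j - 2)
    return w

With these weights, 1 - torch.nansum(bound * weights_2) / weights_2.sum() is one minus a weighted mean of the bound, so the new while loop runs until that weighted bound reaches 0.5. Note the hunk does not show `counter` being initialised; presumably it is set to 0 earlier in the script.
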
