Commit 06254cf

checking

1 parent a48c11f commit 06254cf
File tree

5 files changed: +113 -23

gbmi/exp_indhead/finetune_ind.py

Lines changed: 80 additions & 6 deletions
@@ -206,6 +206,8 @@ def diff_3(a, i_1, i_2, j, dic, matrices, attn_1):
     c = torch.max(c, term_3[i_2, dic[i_2], i, dic[i]].max())
     c = torch.max(c, term_3[i_2, dic[i_2], j, dic[j]].max())
     t_3 += (1 - attn_1[dic[j], j - 1].min()) * c
+    print(a, i_1, i_2, j)
+    print(c)

     # print(t_3)
     if j == 1:
@@ -256,6 +258,8 @@ def diff_2_4(a, i_1, i_2, j, dic, matrices, attn_1):
     c = torch.max(c, term_4[k, dic[k], i][..., dic[i]].max())
     c = torch.max(c, term_4[k, dic[k], j][..., dic[j]].max())
     d = d + (1 - attn_1[dic[j], j - 1].min()) * c
+    if k == 0:
+        print(c)

     if j == 0:

@@ -353,6 +357,8 @@ def diff_2_3_4(a, i_1, i_2, j, dic, matrices, attn_1):
         + term_3[i_2, dic[i_2], j, dic[j]].max(),
     )
     d = d + (1 - attn_1[dic[j], j - 1].min()) * c
+    if k == 0:
+        print(c)

     if j == 0:

@@ -445,13 +451,16 @@ def diff_2_3_4(a, i_1, i_2, j, dic, matrices, attn_1):


 def least_attention(a, i_1, i_2, j, dic, matrices, attn_1):
-    e = diff_2_4(a, i_1, i_2, j, dic, matrices, attn_1)

-    return (
-        diff_1(a, i_1, i_2, j, dic, matrices)
-        + e
-        + diff_3(a, i_1, i_2, j, dic, matrices, attn_1)
-    )
+    g = diff_3(a, i_1, i_2, j, dic, matrices, attn_1)
+    f = diff_2_4(a, i_1, i_2, j, dic, matrices, attn_1)
+    e = diff_2_3_4(a, i_1, i_2, j, dic, matrices, attn_1)
+
+    # print(a, i_1, i_2, j)
+    # print(e)
+    # print(f+g)
+    # print(e-f-g)
+    return diff_1(a, i_1, i_2, j, dic, matrices) + f + g


 def second_layer_attention(matrices, attn_1):
@@ -1027,4 +1036,69 @@ def good_loss_bound(model):
     # Show the plot
     plt.show()

+# %%
+import torch as t
+
+
+def sample(a, b, i, d_voc):
+    # i goes from 1 to n_ctx - 3
+    # randomly fill the sequence with tokens not equal to a
+    seq = t.randint(low=0, high=d_voc - 1, size=(i + 3,))
+    seq = seq + (seq >= a).int()
+
+    # fill the last position with a
+    seq[-1] = a
+
+    # pick where the induction pattern starts
+    m = t.randint(low=0, high=i, size=(1,)).item()
+
+    # place a at position m + 1, immediately followed by b
+    seq[m + 1] = a
+    seq[m + 2] = b
+    return seq
+
+
+def sample_acc_and_loss(model, batch_size=15000):
+    d_vocab = model.W_E.shape[0]
+    n_ctx = model.W_pos.shape[0]
+
+    acc = 0
+    loss = 0
+
+    loss_CE = t.nn.CrossEntropyLoss()
+
+    # compute the probability of each sequence length
+    sample_seq_length = t.arange(1, n_ctx - 3)
+    prob_sample_seq_len = t.tensor([i * (d_vocab - 1) ** i for i in sample_seq_length])
+    prob_sample_seq_len = prob_sample_seq_len / prob_sample_seq_len.sum()
+
+    # sample the sequence lengths
+    sampled = sample_seq_length[
+        torch.multinomial(prob_sample_seq_len, num_samples=batch_size, replacement=True)
+    ]
+
+    # sample a for every batch element
+    sample_a = t.randint(0, d_vocab, (batch_size,))
+
+    with t.no_grad():
+        for i in range(batch_size):
+            a = sample_a[i].item()
+
+            # sample b unequal to a
+            b = t.randint(0, d_vocab - 1, (1,)).item()
+            b = b + (b >= a)
+            length = sampled[i]
+
+            # sample a sequence
+            seq = sample(a, b, length, d_vocab)
+
+            # measure accuracy and loss at the final position
+            logit = model(seq).squeeze()[-1]
+            acc += logit.argmax() == b
+            loss += loss_CE(logit.unsqueeze(0), t.tensor([b]))
+
+    return acc / batch_size, loss / batch_size
+
 # %%
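
The new sample / sample_acc_and_loss helpers estimate induction accuracy by Monte Carlo: each sampled sequence ends in a token a and contains exactly one earlier a immediately followed by a token b, so a model with a working induction head should predict b at the final position. A minimal usage sketch (the model name and batch size below are illustrative, assuming a model whose W_E has shape [d_vocab, d_model] and whose W_pos has shape [n_ctx, d_model]):

    acc, loss = sample_acc_and_loss(model, batch_size=1000)
    print(f"accuracy: {acc.item():.4f}, loss: {loss.item():.4f}")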

gbmi/exp_indhead/finetuned_model.pth

120 KB
Binary file not shown.

gbmi/exp_indhead/induction_head_results.py

Lines changed: 32 additions & 16 deletions
@@ -791,6 +791,7 @@ def sample_acc_and_loss(model, batch_size=15000):

 W_U = ein.array(lambda i, j: i == j, sizes=[d_model, d_voc]).float().to(device)

+raw_terms = [W_U, W_K_1, W_K_0, W_Q_0, W_Q_1, W_V_0, W_V_1, W_E, W_O_0, W_O_1, W_pos]

 index_0_d = (
     ein.array(
@@ -1013,22 +1014,24 @@ def plot_diag(data, i):


 # %%
-def put_in_model(model):
-    model.W_U.data = W_U
-    model.blocks[0].attn.W_K.data[0] = W_K_0
-    model.blocks[1].attn.W_K.data[0] = W_K_1
-    model.blocks[0].attn.W_Q.data[0] = W_Q_0
-    model.blocks[1].attn.W_Q.data[0] = W_Q_1
-    model.blocks[0].attn.W_V.data[0] = W_V_0
-    model.blocks[1].attn.W_V.data[0] = W_V_1
-
-    model.W_E.data = W_E
-    model.blocks[0].attn.W_O.data[0] = W_O_0
-    model.blocks[1].attn.W_O.data[0] = W_O_1
-    model.W_pos.data = W_pos
-
-
-put_in_model(model_2)
+
+
+def put_in_model(model, raw):
+    # raw follows the order of raw_terms:
+    # [W_U, W_K_1, W_K_0, W_Q_0, W_Q_1, W_V_0, W_V_1, W_E, W_O_0, W_O_1, W_pos]
+    model.W_U.data = raw[0]
+    model.blocks[0].attn.W_K.data[0] = raw[2]
+    model.blocks[1].attn.W_K.data[0] = raw[1]
+    model.blocks[0].attn.W_Q.data[0] = raw[3]
+    model.blocks[1].attn.W_Q.data[0] = raw[4]
+    model.blocks[0].attn.W_V.data[0] = raw[5]
+    model.blocks[1].attn.W_V.data[0] = raw[6]
+
+    model.W_E.data = raw[7]
+    model.blocks[0].attn.W_O.data[0] = raw[8]
+    model.blocks[1].attn.W_O.data[0] = raw[9]
+    model.W_pos.data = raw[10]
+
+
+put_in_model(model_2, raw_terms)
 correct_terms = terms(model_2)
 correct_terms = tuple(term.clone().detach() for term in correct_terms)
@@ -1123,4 +1126,17 @@ def get_graphs(fun, model):
     return term_dic


+# %%
+def noise(M, v):
+    return M + torch.randn_like(M) * v
+
+
+def add_noise(model, v):
+    new_raw_terms = []
+    for i in range(len(raw_terms)):
+        new_raw_terms.append(noise(raw_terms[i].detach().clone(), v))
+        new_raw_terms[i].requires_grad = True
+    put_in_model(model, new_raw_terms)
+
+
 # %%
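
With the refactor, the weight surgery is reusable: put_in_model now takes the list of raw weight tensors explicitly, and add_noise writes a freshly perturbed, trainable copy of raw_terms into the model. A sketch of a noise sweep built from these helpers, assuming the model_2 and sample_acc_and_loss already defined in this file (the noise scales are illustrative):

    for v in (1e-3, 1e-2, 1e-1):
        add_noise(model_2, v)                     # overwrite weights with noisy copies
        acc, loss = sample_acc_and_loss(model_2)  # re-measure accuracy and loss
        print(v, acc.item(), loss.item())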

gbmi/exp_indhead/noise_bound.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def armin(

 # %%
 def noise(M, v):
-    return M + (torch.rand_like(M) - 0.5) * 2 * v
+    return M + torch.randn_like(M) * v


 W_E = ein.array(lambda i, j: i == j, sizes=[d_voc, d_model]).float().to(device)
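
Note the replacement is not a like-for-like change of scale: the old uniform perturbation and the new Gaussian one differ in spread at the same v. A side-by-side comparison (illustrative, for any tensor M and scale v):

    old = M + (torch.rand_like(M) - 0.5) * 2 * v  # U[-v, v], std = v / sqrt(3) ≈ 0.58 v
    new = M + torch.randn_like(M) * v             # N(0, v^2), std = v, unbounded support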

gbmi/exp_indhead/term.pt

20.3 KB
Binary file not shown.
