Skip to content

Commit 6177dd4

Browse files
stuff
1 parent 5b6a68e commit 6177dd4

File tree

1 file changed

+51
-64
lines changed

1 file changed

+51
-64
lines changed

gbmi/exp_indhead/noise_bound.py

Lines changed: 51 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -189,71 +189,30 @@ def noise(M, v):
189189
)
190190

191191

192-
def hand_bound(s, v):
192+
def hand_bound(
193+
W_E, W_pos, W_V_0, W_V_1, W_O_0, W_O_1, W_Q_0, W_Q_1, W_K_0, W_K_1, W_U, s, v
194+
):
193195

194-
W_E = ein.array(lambda i, j: i == j, sizes=[d_voc, d_model]).float().to(device)
195196
W_E = noise(W_E, v)
196-
W_pos = (
197-
ein.array(lambda i, j: ((i + d_voc) == j) * 1.0, sizes=[n_ctx, d_model])
198-
.float()
199-
.to(device)
200-
)
197+
201198
W_pos = noise(W_pos, v)
202-
W_O_0 = (
203-
ein.array(lambda i, j: ((i + n_ctx + d_voc) == j) * 1.0, sizes=[d_voc, d_model])
204-
.float()
205-
.to(device)
206-
)
199+
207200
W_O_0 = noise(W_O_0, v)
208-
W_V_0 = (
209-
ein.array(lambda i, j: (i == j) * 1.0, sizes=[d_model, d_voc])
210-
.float()
211-
.to(device)
212-
)
201+
213202
W_V_0 = noise(W_V_0, v)
214-
W_V_1 = (
215-
ein.array(lambda i, j: (i == j) * 1.0, sizes=[d_model, d_voc])
216-
.float()
217-
.to(device)
218-
)
203+
219204
W_V_1 = noise(W_V_1, v)
220-
W_O_1 = (
221-
ein.array(lambda i, j: (i == j) * 100, sizes=[d_voc, d_model])
222-
.float()
223-
.to(device)
224-
)
205+
225206
W_O_1 = noise(W_O_1, v)
226-
W_Q_0 = (
227-
ein.array(
228-
lambda i, j: where((i + d_voc + 1) == j, c, 0), sizes=[n_ctx, d_model]
229-
)
230-
.float()
231-
.to(device)
232-
.T
233-
)
207+
234208
W_Q_0 = noise(W_Q_0, v)
235-
W_Q_1 = (
236-
ein.array(lambda i, j: where(i == j, d, 0), sizes=[d_voc, d_model])
237-
.float()
238-
.T.to(device)
239-
)
209+
240210
W_Q_1 = noise(W_Q_1, v)
241-
W_K_0 = (
242-
ein.array(lambda i, j: where((i + d_voc) == j, c, 0), sizes=[n_ctx, d_model])
243-
.float()
244-
.T
245-
).to(device)
211+
246212
W_K_0 = noise(W_K_0, v)
247-
W_K_1 = (
248-
ein.array(
249-
lambda i, j: where((i + n_ctx + d_voc) == j, d, 0),
250-
sizes=[d_voc, d_model],
251-
)
252-
.float()
253-
.T
254-
).to(device)
213+
255214
W_K_1 = noise(W_K_1, v)
256-
W_U = ein.array(lambda i, j: i == j, sizes=[d_model, d_voc]).float().to(device)
215+
257216
W_U = noise(W_U, v)
258217

259218
e_p = W_E.unsqueeze(dim=0) + W_pos.unsqueeze(dim=1)
@@ -335,6 +294,9 @@ def hand_bound(s, v):
335294
"q_pos q_val k, k l, l m, m n, n p, p q -> q_pos q_val q",
336295
)
337296

297+
if s == -1:
298+
return (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8)
299+
338300
table = torch.zeros((d_voc, d_voc, n_ctx - 2, d_voc)) + float(
339301
"nan"
340302
) # p Represents the position of 'b' at index + 1
@@ -706,24 +668,49 @@ def total_bound(b, i_1, i_2, dic):
706668
+ loss_diff_4(b, i_1, i_2, dic)
707669
)
708670

709-
out = torch.zeros((d_voc, n_ctx, n_ctx)) + torch.inf
671+
if s == 3:
672+
673+
out = torch.zeros((d_voc, n_ctx, n_ctx)) + torch.inf
674+
# b i_2 i_1
675+
676+
for b in range(e_p.shape[1]):
677+
678+
for i_2 in range(e_p.shape[0] - 1):
679+
for i_1 in range(1, i_2):
680+
681+
if (i_1 < i_2) & (i_1 > 0):
682+
dic = {i_1: b}
683+
for i in range(8):
684+
dic.setdefault(i, torch.arange(26))
685+
686+
out[b, i_2, i_1] = total_bound(b, i_1, i_2, dic)
687+
688+
out_2 = 1 / (1 + ((d_voc - 1) * torch.exp(out)))
689+
690+
return (attn_1, bound, bound_2, out, out_2)
691+
692+
out = torch.zeros((d_voc, n_ctx, n_ctx, d_voc)) + torch.inf
710693
# b i_2 i_1
711694

712695
for b in range(e_p.shape[1]):
696+
for n in range(e_p.shape[1]):
697+
for i_2 in range(e_p.shape[0] - 1):
698+
for i_1 in range(1, i_2):
713699

714-
for i_2 in range(e_p.shape[0] - 1):
715-
for i_1 in range(1, i_2):
700+
if (i_1 < i_2) & (i_1 > 0):
701+
dic = {i_1: b}
702+
for i in range(8):
703+
dic.setdefault(i, torch.arange(26))
716704

717-
if (i_1 < i_2) & (i_1 > 0):
718-
dic = {i_1: b}
719-
for i in range(8):
720-
dic.setdefault(i, torch.arange(26))
705+
out[b, i_2, i_1, n] = total_bound(b, i_1, i_2, dic, n)
721706

722-
out[b, i_2, i_1] = total_bound(b, i_1, i_2, dic)
707+
out_2 = einops.einsum(out.softmax(dim=-1), "b i_2 i_1 b -> b i_2 i_1")
723708

724-
out_2 = 1 / (1 + ((d_voc - 1) * torch.exp(out)))
709+
out_3 = einops.einsum(
710+
out - out.max(dim=-1).values.unsqueeze(dim=-1), "b i_2 i_1 b -> b i_2 i_1"
711+
)
725712

726-
return (attn_1, bound, bound_2, out, out_2)
713+
return (attn_1, bound, bound_2, out, out_2, out_3)
727714

728715

729716
# %%

0 commit comments

Comments
 (0)