
Commit 78eed30

stage by stage
1 parent aa906de commit 78eed30

File tree: 1 file changed (+47, -45 lines)


gbmi/exp_indhead/finetunebound.py

Lines changed: 47 additions & 45 deletions
@@ -63,52 +63,7 @@ def loss_bound(model, s):
         )
         / attn_scale_0
     )
-    table = torch.zeros((d_voc, d_voc, n_ctx - 2, d_voc)) + float(
-        "nan"
-    )  # p represents the position of 'b' at index + 1
-
-    for p in range(2, n_ctx):
-        tmp = torch.zeros((p, d_voc))
-        for t_q in range(d_voc):
-            tmp[-1, :] = term_0[p - 1, t_q, p - 1, t_q]
-
-            for t_k in range(d_voc):
-                tmp[-2, :] = term_0[p - 1, t_q, p - 2, t_k]
-                tmp[:-2, :] = term_0[p - 1, t_q, : p - 2, :]
-                tmp_sm = tmp.softmax(dim=0)
-                table[t_q, t_k, p - 2, :] = tmp_sm[-2, :]
-    # Table represents the post-softmax attention paid to t_k, if the final entry is spammed everywhere and t_q is used as the first entry, at the p-th position
-
-    # term_0 looks like EQKE; table looks like you're indexing by query, key, position (of the key?), and the other token in the sequence.
-    # Then you're computing the softmax of p - 2 copies of the other token, one copy of t_k at p - 2, and the query at p - 1.
-    # Then you store the post-softmax attention paid to t_k.
-    #
-    #
-    #
-    ## xEQKE^tx^t
-    #
-    ##
-    # t_q vocab paying attention to t_k (another letter), if the other one gets spammed
-    #
-    ##
-    #
-    #
-    #
-    ##
-    #
-    #
-    #
-    #
-    #
-    #
-    #
-    #
-    attn_1 = table.min(dim=1).values.min(dim=2).values

-    if s == 1:
-        return attn_1
-
-    # attn_1 = torch.ones(attn_1.shape)
     term_1 = (
         einops.einsum(
             e_p,
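
The block removed above (and re-inserted later in this same commit) turns the pre-softmax score tensor term_0 into post-softmax attention values. As a minimal sketch of what a single entry of table means, assuming term_0 is indexed as [query_pos, query_tok, key_pos, key_tok] with shape (n_ctx, d_voc, n_ctx, d_voc) (shapes inferred from the indexing in the diff, not confirmed against the rest of the file; the helper name attn_to_key is hypothetical):

import torch

# Illustrative sizes only; the real d_voc and n_ctx come from the model.
n_ctx, d_voc = 8, 5
term_0 = torch.randn(n_ctx, d_voc, n_ctx, d_voc)  # stand-in for the EQKE-style scores


def attn_to_key(term_0, p, t_q, t_k):
    # One column per candidate "spam" token: for column j, positions 0..p-3 hold
    # the query's score on token j, position p-2 holds the score on t_k, and
    # position p-1 holds the query's score on itself.
    tmp = torch.zeros(p, term_0.shape[1])
    tmp[-1, :] = term_0[p - 1, t_q, p - 1, t_q]
    tmp[-2, :] = term_0[p - 1, t_q, p - 2, t_k]
    tmp[:-2, :] = term_0[p - 1, t_q, : p - 2, :]
    # Softmax over positions; row -2 is the attention mass that lands on t_k.
    return tmp.softmax(dim=0)[-2, :]


row = attn_to_key(term_0, p=4, t_q=1, t_k=2)  # corresponds to table[1, 2, 2, :] above
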
@@ -206,6 +161,53 @@ def loss_bound(model, s):
         ]
     ).max()

+    table = torch.zeros((d_voc, d_voc, n_ctx - 2, d_voc)) + float(
+        "nan"
+    )  # p represents the position of 'b' at index + 1
+
+    for p in range(2, n_ctx):
+        tmp = torch.zeros((p, d_voc))
+        for t_q in range(d_voc):
+            tmp[-1, :] = term_0[p - 1, t_q, p - 1, t_q]
+
+            for t_k in range(d_voc):
+                tmp[-2, :] = term_0[p - 1, t_q, p - 2, t_k]
+                tmp[:-2, :] = term_0[p - 1, t_q, : p - 2, :]
+                tmp_sm = tmp.softmax(dim=0)
+                table[t_q, t_k, p - 2, :] = tmp_sm[-2, :]
+    # Table represents the post-softmax attention paid to t_k, if the final entry is spammed everywhere and t_q is used as the first entry, at the p-th position
+
+    # term_0 looks like EQKE; table looks like you're indexing by query, key, position (of the key?), and the other token in the sequence.
+    # Then you're computing the softmax of p - 2 copies of the other token, one copy of t_k at p - 2, and the query at p - 1.
+    # Then you store the post-softmax attention paid to t_k.
+    #
+    #
+    #
+    ## xEQKE^tx^t
+    #
+    ##
+    # t_q vocab paying attention to t_k (another letter), if the other one gets spammed
+    #
+    ##
+    #
+    #
+    #
+    ##
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    attn_1 = table.min(dim=1).values.min(dim=2).values
+
+    if s == 1:
+        return attn_1
+
+    # attn_1 = torch.ones(attn_1.shape)
+
     def diff_1(a, i_1, i_2, j, dic):

         if j == i_1:
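
The relocated block then collapses table into a single worst-case quantity before the early return. A sketch of that reduction under the same assumptions, reusing the hypothetical attn_to_key helper above (the real code fills table with the nested loops shown in the diff; this only mirrors it):

# dim 1 of table is the key token t_k and the last dim is the spammed token, so
# attn_1[t_q, p - 2] is the smallest post-softmax attention found for query
# token t_q at that position, over all key tokens and all spam tokens.
table = torch.full((d_voc, d_voc, n_ctx - 2, d_voc), float("nan"))
for p in range(2, n_ctx):
    for t_q in range(d_voc):
        for t_k in range(d_voc):
            table[t_q, t_k, p - 2, :] = attn_to_key(term_0, p, t_q, t_k)

attn_1 = table.min(dim=1).values.min(dim=2).values
assert attn_1.shape == (d_voc, n_ctx - 2)

When s == 1 the function returns attn_1 directly; since the removed and re-added lines are identical, the commit only changes where in loss_bound this block runs (after the .max() term above), not what it computes.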
