@@ -63,52 +63,7 @@ def loss_bound(model, s):
63
63
)
64
64
/ attn_scale_0
65
65
)
66
- table = torch .zeros ((d_voc , d_voc , n_ctx - 2 , d_voc )) + float (
67
- "nan"
68
- ) # p Represents the position of 'b' at index + 1
69
-
70
- for p in range (2 , n_ctx ): #
71
- tmp = torch .zeros ((p , d_voc ))
72
- for t_q in range (d_voc ):
73
- tmp [- 1 , :] = term_0 [p - 1 , t_q , p - 1 , t_q ]
74
-
75
- for t_k in range (d_voc ):
76
- tmp [- 2 , :] = term_0 [p - 1 , t_q , p - 2 , t_k ]
77
- tmp [:- 2 , :] = term_0 [p - 1 , t_q , : p - 2 , :]
78
- tmp_sm = tmp .softmax (dim = 0 )
79
- table [t_q , t_k , p - 2 , :] = tmp_sm [- 2 , :]
80
- # Table represents post softmax attention paid to t_k, if the final entry is spammed everywhere, and t_q is used as the first entry, at pth poisition
81
-
82
- # term_0 looks like EQKE, table looks like you're indexing by query, key, position (of key?), and other token in the sequence.
83
- # They you're computing softmax of d_voc - 2 copies of the other token, one copy of t_k in p-2, and the query in p-1.
84
- # Then you store the post-softmax attention paid to t_k.
85
- #
86
- #
87
- #
88
- ## xEQKE^tx^t
89
- #
90
- ##
91
- # t_q vocab paying attention to t_k another letter, if other one gets spammed
92
- #
93
- ##
94
- #
95
- #
96
- #
97
- ##
98
- #
99
- #
100
- #
101
- #
102
- #
103
- #
104
- #
105
- #
106
- attn_1 = table .min (dim = 1 ).values .min (dim = 2 ).values
107
66
108
- if s == 1 :
109
- return attn_1
110
-
111
- # attn_1=torch.ones(attn_1.shape)
112
67
term_1 = (
113
68
einops .einsum (
114
69
e_p ,
@@ -206,6 +161,53 @@ def loss_bound(model, s):
206
161
]
207
162
).max ()
208
163
164
+ table = torch .zeros ((d_voc , d_voc , n_ctx - 2 , d_voc )) + float (
165
+ "nan"
166
+ ) # p Represents the position of 'b' at index + 1
167
+
168
+ for p in range (2 , n_ctx ): #
169
+ tmp = torch .zeros ((p , d_voc ))
170
+ for t_q in range (d_voc ):
171
+ tmp [- 1 , :] = term_0 [p - 1 , t_q , p - 1 , t_q ]
172
+
173
+ for t_k in range (d_voc ):
174
+ tmp [- 2 , :] = term_0 [p - 1 , t_q , p - 2 , t_k ]
175
+ tmp [:- 2 , :] = term_0 [p - 1 , t_q , : p - 2 , :]
176
+ tmp_sm = tmp .softmax (dim = 0 )
177
+ table [t_q , t_k , p - 2 , :] = tmp_sm [- 2 , :]
178
+ # Table represents post softmax attention paid to t_k, if the final entry is spammed everywhere, and t_q is used as the first entry, at pth position
179
+
180
+ # term_0 looks like EQKE, table looks like you're indexing by query, key, position (of key?), and other token in the sequence.
181
+ # Then you're computing softmax of d_voc - 2 copies of the other token, one copy of t_k in p-2, and the query in p-1.
182
+ # Then you store the post-softmax attention paid to t_k.
183
+ #
184
+ #
185
+ #
186
+ ## xEQKE^tx^t
187
+ #
188
+ ##
189
+ # t_q vocab paying attention to t_k another letter, if other one gets spammed
190
+ #
191
+ ##
192
+ #
193
+ #
194
+ #
195
+ ##
196
+ #
197
+ #
198
+ #
199
+ #
200
+ #
201
+ #
202
+ #
203
+ #
204
+ attn_1 = table .min (dim = 1 ).values .min (dim = 2 ).values
205
+
206
+ if s == 1 :
207
+ return attn_1
208
+
209
+ # attn_1=torch.ones(attn_1.shape)
210
+
209
211
def diff_1 (a , i_1 , i_2 , j , dic ):
210
212
211
213
if j == i_1 :
0 commit comments