@@ -189,71 +189,30 @@ def noise(M, v):
189
189
)
190
190
191
191
192
- def hand_bound (s , v ):
192
+ def hand_bound (
193
+ W_E , W_pos , W_V_0 , W_V_1 , W_O_0 , W_O_1 , W_Q_0 , W_Q_1 , W_K_0 , W_K_1 , W_U , s , v
194
+ ):
193
195
194
- W_E = ein .array (lambda i , j : i == j , sizes = [d_voc , d_model ]).float ().to (device )
195
196
W_E = noise (W_E , v )
196
- W_pos = (
197
- ein .array (lambda i , j : ((i + d_voc ) == j ) * 1.0 , sizes = [n_ctx , d_model ])
198
- .float ()
199
- .to (device )
200
- )
197
+
201
198
W_pos = noise (W_pos , v )
202
- W_O_0 = (
203
- ein .array (lambda i , j : ((i + n_ctx + d_voc ) == j ) * 1.0 , sizes = [d_voc , d_model ])
204
- .float ()
205
- .to (device )
206
- )
199
+
207
200
W_O_0 = noise (W_O_0 , v )
208
- W_V_0 = (
209
- ein .array (lambda i , j : (i == j ) * 1.0 , sizes = [d_model , d_voc ])
210
- .float ()
211
- .to (device )
212
- )
201
+
213
202
W_V_0 = noise (W_V_0 , v )
214
- W_V_1 = (
215
- ein .array (lambda i , j : (i == j ) * 1.0 , sizes = [d_model , d_voc ])
216
- .float ()
217
- .to (device )
218
- )
203
+
219
204
W_V_1 = noise (W_V_1 , v )
220
- W_O_1 = (
221
- ein .array (lambda i , j : (i == j ) * 100 , sizes = [d_voc , d_model ])
222
- .float ()
223
- .to (device )
224
- )
205
+
225
206
W_O_1 = noise (W_O_1 , v )
226
- W_Q_0 = (
227
- ein .array (
228
- lambda i , j : where ((i + d_voc + 1 ) == j , c , 0 ), sizes = [n_ctx , d_model ]
229
- )
230
- .float ()
231
- .to (device )
232
- .T
233
- )
207
+
234
208
W_Q_0 = noise (W_Q_0 , v )
235
- W_Q_1 = (
236
- ein .array (lambda i , j : where (i == j , d , 0 ), sizes = [d_voc , d_model ])
237
- .float ()
238
- .T .to (device )
239
- )
209
+
240
210
W_Q_1 = noise (W_Q_1 , v )
241
- W_K_0 = (
242
- ein .array (lambda i , j : where ((i + d_voc ) == j , c , 0 ), sizes = [n_ctx , d_model ])
243
- .float ()
244
- .T
245
- ).to (device )
211
+
246
212
W_K_0 = noise (W_K_0 , v )
247
- W_K_1 = (
248
- ein .array (
249
- lambda i , j : where ((i + n_ctx + d_voc ) == j , d , 0 ),
250
- sizes = [d_voc , d_model ],
251
- )
252
- .float ()
253
- .T
254
- ).to (device )
213
+
255
214
W_K_1 = noise (W_K_1 , v )
256
- W_U = ein . array ( lambda i , j : i == j , sizes = [ d_model , d_voc ]). float (). to ( device )
215
+
257
216
W_U = noise (W_U , v )
258
217
259
218
e_p = W_E .unsqueeze (dim = 0 ) + W_pos .unsqueeze (dim = 1 )
@@ -335,6 +294,9 @@ def hand_bound(s, v):
335
294
"q_pos q_val k, k l, l m, m n, n p, p q -> q_pos q_val q" ,
336
295
)
337
296
297
+ if s == - 1 :
298
+ return (term_0 , term_1 , term_2 , term_3 , term_4 , term_5 , term_6 , term_7 , term_8 )
299
+
338
300
table = torch .zeros ((d_voc , d_voc , n_ctx - 2 , d_voc )) + float (
339
301
"nan"
340
302
) # p Represents the position of 'b' at index + 1
@@ -706,24 +668,49 @@ def total_bound(b, i_1, i_2, dic):
706
668
+ loss_diff_4 (b , i_1 , i_2 , dic )
707
669
)
708
670
709
- out = torch .zeros ((d_voc , n_ctx , n_ctx )) + torch .inf
671
+ if s == 3 :
672
+
673
+ out = torch .zeros ((d_voc , n_ctx , n_ctx )) + torch .inf
674
+ # b i_2 i_1
675
+
676
+ for b in range (e_p .shape [1 ]):
677
+
678
+ for i_2 in range (e_p .shape [0 ] - 1 ):
679
+ for i_1 in range (1 , i_2 ):
680
+
681
+ if (i_1 < i_2 ) & (i_1 > 0 ):
682
+ dic = {i_1 : b }
683
+ for i in range (8 ):
684
+ dic .setdefault (i , torch .arange (26 ))
685
+
686
+ out [b , i_2 , i_1 ] = total_bound (b , i_1 , i_2 , dic )
687
+
688
+ out_2 = 1 / (1 + ((d_voc - 1 ) * torch .exp (out )))
689
+
690
+ return (attn_1 , bound , bound_2 , out , out_2 )
691
+
692
+ out = torch .zeros ((d_voc , n_ctx , n_ctx , d_voc )) + torch .inf
710
693
# b i_2 i_1
711
694
712
695
for b in range (e_p .shape [1 ]):
696
+ for n in range (e_p .shape [1 ]):
697
+ for i_2 in range (e_p .shape [0 ] - 1 ):
698
+ for i_1 in range (1 , i_2 ):
713
699
714
- for i_2 in range (e_p .shape [0 ] - 1 ):
715
- for i_1 in range (1 , i_2 ):
700
+ if (i_1 < i_2 ) & (i_1 > 0 ):
701
+ dic = {i_1 : b }
702
+ for i in range (8 ):
703
+ dic .setdefault (i , torch .arange (26 ))
716
704
717
- if (i_1 < i_2 ) & (i_1 > 0 ):
718
- dic = {i_1 : b }
719
- for i in range (8 ):
720
- dic .setdefault (i , torch .arange (26 ))
705
+ out [b , i_2 , i_1 , n ] = total_bound (b , i_1 , i_2 , dic , n )
721
706
722
- out [ b , i_2 , i_1 ] = total_bound ( b , i_1 , i_2 , dic )
707
+ out_2 = einops . einsum ( out . softmax ( dim = - 1 ), "b i_2 i_1 b -> b i_2 i_1" )
723
708
724
- out_2 = 1 / (1 + ((d_voc - 1 ) * torch .exp (out )))
709
+ out_3 = einops .einsum (
710
+ out - out .max (dim = - 1 ).values .unsqueeze (dim = - 1 ), "b i_2 i_1 b -> b i_2 i_1"
711
+ )
725
712
726
- return (attn_1 , bound , bound_2 , out , out_2 )
713
+ return (attn_1 , bound , bound_2 , out , out_2 , out_3 )
727
714
728
715
729
716
# %%
0 commit comments