
Commit 4a2c7ab

added stuff

committed
1 parent 28f5508 commit 4a2c7ab


52 files changed: 493140 additions & 80 deletions

gbmi/exp_indhead/handcoded.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def show(matrix):
 d = 10
 W_pos = model.W_pos
 W_E = model.W_E
-epsilon = 0.5
+epsilon = 0.3
 n_ctx = W_pos.shape[0]
 d_voc = W_E.shape[0]
 d_model = W_E.shape[1]

gbmi/exp_indhead/probabilistic.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
# %%
from gbmi.exp_indhead.train import ABCAB8_1H
from gbmi.model import train_or_load_model
import torch
import plotly.express as px


def show(matrix):
    # Heatmap helper: promote a 1-D tensor to a single-row matrix, then plot.
    if len(matrix.shape) == 1:
        matrix = matrix.unsqueeze(0)
    px.imshow(matrix.detach().cpu()).show()


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
runtime_model_1, model = train_or_load_model(ABCAB8_1H, force="load")
model.to(device)

W_pos = model.W_pos  # positional embeddings, shape (n_ctx, d_model)
W_E = model.W_E  # token embeddings, shape (d_voc, d_model)
n_ctx = W_pos.shape[0]
d_voc = W_E.shape[0]
d_model = W_E.shape[1]


# %%
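# Per-head weights below. Indexing convention assumed here (from the
# [layer, head] pattern, not stated in the commit): model.W_Q has shape
# (n_layers, n_heads, d_model, d_head), so W_Q[1, 0] is the layer-1,
# head-0 query matrix; K, V, and O follow the same convention.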
attn_scale_0 = model.blocks[0].attn.attn_scale
attn_scale_1 = model.blocks[1].attn.attn_scale
W_Q_0 = model.W_Q[0, 0]
W_K_0 = model.W_K[0, 0]
W_V_0 = model.W_V[0, 0]
W_O_0 = model.W_O[0, 0]
W_Q_1 = model.W_Q[1, 0]
W_K_1 = model.W_K[1, 0]
W_V_1 = model.W_V[1, 0]
W_O_1 = model.W_O[1, 0]
W_U = model.W_U
# Short aliases for the layer-0 OV circuit and the layer-1 head matrices.
o = W_O_0
v = W_V_0
q_1 = W_Q_1
k_1 = W_K_1
v_1 = W_V_1
o_1 = W_O_1
# %%
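# Naming convention for the score maps below (inferred from the code, not
# documented in the commit): E = token embedding, P = positional embedding,
# QK = pre-softmax query-key scores, VO = value-output circuit. For example,
# EQKP[t, p] is the layer-0 attention score query token t assigns to key
# position p.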
EQKP = (W_E @ W_Q_0 @ W_K_0.T @ W_pos.T) / attn_scale_0
PQKP = (W_pos @ W_Q_0 @ W_K_0.T @ W_pos.T) / attn_scale_0
PQKE = (W_pos @ W_Q_0 @ W_K_0.T @ W_E.T) / attn_scale_0
EQKE = (W_E @ W_Q_0 @ W_K_0.T @ W_E.T) / attn_scale_0
# %%
# Layer-0 attention patterns from the positional terms: for the query at
# position index - 1, softmax over key positions 0..index-1, broadcast over
# every possible query token (the rows of EQKP).
pos_pattern_pres = []
for index in range(1, 9):
    pos_pattern_pres.append(
        torch.softmax(PQKP[index - 1, :index] + EQKP[:, :index], dim=1)
    )

# Token-dependent factor of the exponentiated scores; note that `index` has
# leaked from the loop and is 8 here, so this uses PQKE[-8].
other_parts = torch.exp(PQKE[-index] + EQKE)


# %%
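# pvo[i] below appears to be an average-case estimate of the layer-0 output
# at position i: the positional embedding plus the OV-transformed positional
# embeddings of the earlier positions, weighted by the mean attention
# pattern computed above (averaged over query tokens).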
pvo = torch.zeros(8, 64)  # hardcoded (n_ctx, d_model) = (8, 64) for this model
for index in range(1, 9):
    pvo[index - 1] = W_pos[index - 1] + (
        (W_pos[:index] @ v @ o)
        * pos_pattern_pres[index - 1].mean(dim=0).unsqueeze(1)
    ).sum(dim=0)


# %%
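# Layer-1 pre-softmax score maps between the residual components computed so
# far: pvo on the query and/or key side here, then W_E and W_E @ v @ o below.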
pvoqkpvo = (pvo @ q_1 @ k_1.T @ pvo.T) / attn_scale_1
eqkpvo = (W_E @ q_1 @ k_1.T @ pvo.T) / attn_scale_1
evoqkpvo = (W_E @ v @ o @ q_1 @ k_1.T @ pvo.T) / attn_scale_1
# %%
index = 6
pvo_pattern = torch.softmax(
    eqkpvo[:, :index] + evoqkpvo[:, :index].mean() + pvoqkpvo[index - 1, :index],
    dim=1,
)
show(pvo_pattern)
# %%
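# Next, score maps with token embeddings on the key side. The a/b/c notes
# below presumably name token roles in the ABCAB induction task.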
pvoqke = (pvo @ q_1 @ k_1.T @ W_E.T) / attn_scale_1
eqke = (W_E @ q_1 @ k_1.T @ W_E.T) / attn_scale_1
evoqke = (W_E @ v @ o @ q_1 @ k_1.T @ W_E.T) / attn_scale_1
eqkevo = (W_E @ q_1 @ k_1.T @ (W_E @ v @ o).T) / attn_scale_1  # used by show() below
pvoqkevo = (W_pos @ v @ o @ q_1 @ k_1.T @ (W_E @ v @ o).T) / attn_scale_1
evoqkevo = (W_E @ v @ o @ q_1 @ k_1.T @ (W_E @ v @ o).T) / attn_scale_1
# %%
# e in itself
show(pvoqkevo)
show(evoqkevo)
show(eqkevo)
show(pvoqke)
show(eqke)  # a -> b
# a -> a
show(evoqke)  # c -> a
# c -> a
# %%
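# Same three maps recomputed with pvo (the averaged position-plus-OV
# residual from above) in place of the raw W_pos rows.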
pvoqkevo = (pvo @ q_1 @ k_1.T @ (W_E @ v @ o).T) / attn_scale_1
eqkevo = (W_E @ q_1 @ k_1.T @ (W_E @ v @ o).T) / attn_scale_1
evoqkevo = (W_E @ v @ o @ q_1 @ k_1.T @ (W_E @ v @ o).T) / attn_scale_1
show(torch.exp(evoqkevo))
show(eqkevo)
show(torch.exp(pvoqkevo[1:-1]))
# %%
