Skip to content

Commit d421777

Browse files
committed
Refactor attention.application to preload q
1 parent f4d0fba commit d421777

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

attention.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ def arrange_k_or_v(input):
3434

3535

3636
def application(q, k, v, o):
37+
q_loaded = (q * 1.44269504089).to(ntl.float16)
38+
3739
acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
3840
l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32)
3941
m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32)
4042

4143
for i in range(k.shape[0]):
42-
qk = ntl.dot((q * 1.44269504089).to(ntl.float16), ntl.trans(k[i]))
44+
qk = ntl.dot(q_loaded, ntl.trans(k[i]))
4345

4446
m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
4547
p = ntl.exp2(qk - m_ij[:, None])

0 commit comments

Comments
 (0)