Commit 3ddf4f8

Upgrade attention to a version capable of handling 4 dimensions and add a comparison with Triton
1 parent 7407864 commit 3ddf4f8


attention.py

Lines changed: 203 additions & 33 deletions
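For context, this commit makes the kernel operate directly on 4-D (batch_size, num_heads, seq_len, emb_dim) tensors; previously callers had to flatten away the batch and head dimensions. Below is a minimal usage sketch, not part of the commit: the sizes are illustrative and it assumes the `attention` function defined in the diff that follows is in scope.

import torch
import torch.nn.functional as F

# Before this commit, only 2-D (seq_len, emb_dim) inputs were supported, so callers wrote:
#     attention(q.view(q.shape[-2:]), k.view(k.shape[-2:]), v.view(v.shape[-2:]))
# After this commit, 4-D (batch_size, num_heads, seq_len, emb_dim) tensors work directly:
shape = (2, 4, 1024, 64)  # illustrative sizes, matching the new __main__ block below
q, k, v = (torch.randn(shape, dtype=torch.float16, device="cuda") for _ in range(3))

out = attention(q, k, v)                                # NineToothed kernel from this file
ref = F.scaled_dot_product_attention(q, k, v, scale=1)  # PyTorch reference
assert torch.allclose(out, ref, atol=0.01)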
@@ -3,42 +3,59 @@
 import torch
 import torch.nn.functional as F
 import triton
+import triton.language as tl
 from ninetoothed import Symbol, Tensor

-BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
-BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)

-q = Tensor(2, constexpr_shape=True)
-k = Tensor(2, constexpr_shape=True)
-v = Tensor(2, constexpr_shape=True)
-o = Tensor(2, constexpr_shape=True)
+def arrangement(q, k, v, o):
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)

-q_tiled = q.tile((BLOCK_SIZE_M, -1))
-k_tiled = k.tile((BLOCK_SIZE_N, -1)).tile((-1, -1)).expand((q_tiled.shape[0], -1))
-v_tiled = v.tile((BLOCK_SIZE_N, -1)).tile((-1, -1)).expand((q_tiled.shape[0], -1))
-o_tiled = o.tile((BLOCK_SIZE_M, -1))
+    def arrange_q_or_o(input):
+        arranged = input.tile((1, 1, BLOCK_SIZE_M, -1))
+        arranged.dtype = arranged.dtype.squeeze((0, 1))

+        return arranged

-@ninetoothed.jit
-def attention_kernel(q: q_tiled, k: k_tiled, v: v_tiled, o: o_tiled):
+    def arrange_k_or_v(input):
+        arranged = (
+            input.tile((1, 1, BLOCK_SIZE_N, -1))
+            .tile((1, 1, -1, -1))
+            .expand((-1, -1, q_arranged.shape[-2], -1))
+        )
+        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
+        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+        return arranged
+
+    q_arranged = arrange_q_or_o(q)
+
+    return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), arrange_q_or_o(o)
+
+
+def application(q, k, v, o):
     acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
     l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32)
     m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32)

     for i in range(k.shape[0]):
-        qk = ntl.dot((q * 1.44269504089).to(ntl.float16), ntl.trans(k[i, 0]))
+        qk = ntl.dot((q * 1.44269504089).to(ntl.float16), ntl.trans(k[i]))

         m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
         p = ntl.exp2(qk - m_ij[:, None])
         l_ij = ntl.sum(p, 1)

         alpha = ntl.exp2(m_i - m_ij)
-        acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i, 0])
+        acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i])
         m_i = m_ij
         l_i = l_i * alpha + l_ij

     acc /= l_i[:, None]
-    o = acc.to(ntl.float32)  # noqa: F841
+    o = acc  # noqa: F841
+
+
+q, k, v, o = (Tensor(4, constexpr_shape=True) for _ in range(4))
+attention_kernel = ninetoothed.make(arrangement, application, (q, k, v, o))

 def attention(q, k, v):
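The new `application` body above is an online-softmax (flash-attention style) recurrence: it keeps a running row maximum `m_i`, a running normalizer `l_i`, and an unnormalized accumulator `acc`, rescaling the old state by `alpha = exp2(m_i - m_ij)` whenever a new key/value block raises the maximum (the kernels pre-scale by log2(e) and use exp2, which is mathematically the same as exp). As a sanity reference, not part of the commit and with illustrative names and a made-up `block_n` default, here is a plain-PyTorch sketch of the same recurrence; it should match softmax(q @ k.transpose(-2, -1)) @ v with scale=1.

import torch

def online_softmax_attention(q, k, v, block_n=64):
    # Hypothetical reference, not the kernel: stream over key/value blocks while
    # keeping a running max (m), running normalizer (l), and unnormalized output (acc).
    seq_len = k.shape[-2]
    m = torch.full(q.shape[:-1], float("-inf"), dtype=torch.float32, device=q.device)
    l = torch.zeros(q.shape[:-1], dtype=torch.float32, device=q.device)
    acc = torch.zeros(*q.shape[:-1], v.shape[-1], dtype=torch.float32, device=q.device)

    for start in range(0, seq_len, block_n):
        k_blk = k[..., start:start + block_n, :].float()
        v_blk = v[..., start:start + block_n, :].float()
        qk = q.float() @ k_blk.transpose(-2, -1)   # attention scores for this block
        m_new = torch.maximum(m, qk.amax(dim=-1))  # updated running row max
        p = torch.exp(qk - m_new[..., None])       # probabilities w.r.t. the new max
        alpha = torch.exp(m - m_new)               # rescale factor for the old state
        acc = acc * alpha[..., None] + p @ v_blk
        l = l * alpha + p.sum(dim=-1)
        m = m_new

    return (acc / l[..., None]).to(q.dtype)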
@@ -49,59 +66,212 @@ def attention(q, k, v):
     return o


+@triton.autotune(
+    configs=[
+        triton.Config(
+            {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8
+        ),
+    ],
+    key=["EMB_DIM"],
+)
+@triton.jit
+def triton_attention_kernel(
+    q_ptr,
+    k_ptr,
+    v_ptr,
+    o_ptr,
+    q_stride_z,
+    q_stride_h,
+    q_stride_m,
+    q_stride_k,
+    k_stride_z,
+    k_stride_h,
+    k_stride_n,
+    k_stride_k,
+    v_stride_z,
+    v_stride_h,
+    v_stride_k,
+    v_stride_n,
+    o_stride_z,
+    o_stride_h,
+    o_stride_m,
+    o_stride_n,
+    SEQ_LEN: tl.constexpr,
+    EMB_DIM: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+):
+    off_m = tl.program_id(0)
+    off_h = tl.program_id(1)
+    off_z = tl.program_id(2)
+
+    offs_m_start = off_m * BLOCK_SIZE_M
+
+    q_off = off_z * q_stride_z + off_h * q_stride_h
+    q_block_ptr = tl.make_block_ptr(
+        base=q_ptr + q_off,
+        shape=(SEQ_LEN, EMB_DIM),
+        strides=(q_stride_m, q_stride_k),
+        offsets=(offs_m_start, 0),
+        block_shape=(BLOCK_SIZE_M, EMB_DIM),
+        order=(1, 0),
+    )
+    k_off = off_z * k_stride_z + off_h * k_stride_h
+    k_block_ptr = tl.make_block_ptr(
+        base=k_ptr + k_off,
+        shape=(EMB_DIM, SEQ_LEN),
+        strides=(k_stride_k, k_stride_n),
+        offsets=(0, 0),
+        block_shape=(EMB_DIM, BLOCK_SIZE_N),
+        order=(0, 1),
+    )
+    v_off = off_z * v_stride_z + off_h * v_stride_h
+    v_block_ptr = tl.make_block_ptr(
+        base=v_ptr + v_off,
+        shape=(SEQ_LEN, EMB_DIM),
+        strides=(v_stride_k, v_stride_n),
+        offsets=(0, 0),
+        block_shape=(BLOCK_SIZE_N, EMB_DIM),
+        order=(1, 0),
+    )
+    o_off = off_z * o_stride_z + off_h * o_stride_h
+    o_block_ptr = tl.make_block_ptr(
+        base=o_ptr + o_off,
+        shape=(SEQ_LEN, EMB_DIM),
+        strides=(o_stride_m, o_stride_n),
+        offsets=(offs_m_start, 0),
+        block_shape=(BLOCK_SIZE_M, EMB_DIM),
+        order=(1, 0),
+    )
+
+    q = (tl.load(q_block_ptr) * 1.44269504089).to(q_block_ptr.type.element_ty)
+
+    acc = tl.zeros((BLOCK_SIZE_M, EMB_DIM), dtype=tl.float32)
+    l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32)
+    m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32)
+
+    for _ in range(0, tl.cdiv(SEQ_LEN, BLOCK_SIZE_N)):
+        k = tl.load(k_block_ptr)
+
+        qk = tl.dot(q, k)
+
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
+        qk -= m_ij[:, None]
+        p = tl.exp2(qk)
+        l_ij = tl.sum(p, 1)
+
+        v = tl.load(v_block_ptr)
+        alpha = tl.exp2(m_i - m_ij)
+        acc = acc * alpha[:, None] + tl.dot(p.to(v_block_ptr.type.element_ty), v)
+        m_i = m_ij
+        l_i = l_i * alpha + l_ij
+
+        v_block_ptr = tl.advance(v_block_ptr, (BLOCK_SIZE_N, 0))
+        k_block_ptr = tl.advance(k_block_ptr, (0, BLOCK_SIZE_N))
+
+    acc /= l_i[:, None]
+
+    tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty))
+
+
+def triton_attention(q, k, v):
+    o = torch.empty_like(q)
+
+    batch_size, num_heads, seq_len, emb_dim = q.shape
+
+    def grid(meta):
+        return (
+            triton.cdiv(seq_len, meta["BLOCK_SIZE_M"]),
+            num_heads,
+            batch_size,
+        )
+
+    triton_attention_kernel[grid](
+        q,
+        k,
+        v,
+        o,
+        *q.stride(),
+        *k.stride(),
+        *v.stride(),
+        *o.stride(),
+        SEQ_LEN=seq_len,
+        EMB_DIM=emb_dim,
+    )
+
+    return o
+
+
 if __name__ == "__main__":
     torch.manual_seed(0)
-    shape = (1, 1, 1024, 64)
+    shape = (2, 4, 1024, 64)
     dtype = torch.float16
     q = torch.randn(shape, dtype=dtype, device="cuda")
     k = torch.randn(shape, dtype=dtype, device="cuda")
     v = torch.randn(shape, dtype=dtype, device="cuda")

-    ninetoothed_output = attention(
-        q.view(q.shape[-2:]), k.view(k.shape[-2:]), v.view(v.shape[-2:])
-    )
+    ninetoothed_output = attention(q, k, v)
     torch_output = F.scaled_dot_product_attention(q, k, v, scale=1)
+    triton_output = triton_attention(q, k, v)
     print(ninetoothed_output)
     print(torch_output)
-    if torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01):
+    print(triton_output)
+    if torch.allclose(ninetoothed_output, torch_output, atol=0.01):
         print("✅ NineToothed and PyTorch match.")
     else:
         print("❌ NineToothed and PyTorch differ.")
+    if torch.allclose(ninetoothed_output, triton_output, atol=0.01):
+        print("✅ NineToothed and Triton match.")
+    else:
+        print("❌ NineToothed and Triton differ.")

     @triton.testing.perf_report(
         triton.testing.Benchmark(
-            x_names=["n"],
+            x_names=["seq_len"],
             x_vals=[2**i for i in range(10, 15)],
             line_arg="provider",
-            line_vals=["ninetoothed", "torch"],
-            line_names=["NineToothed", "PyTorch"],
-            styles=[("blue", "-"), ("green", "-")],
+            line_vals=["ninetoothed", "torch", "triton"],
+            line_names=["NineToothed", "PyTorch", "Triton"],
+            styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
             ylabel="TFLOPS",
             plot_name="attention-performance",
             args={},
         )
     )
-    def benchmark(n, provider):
-        d = 64
-        shape = (n, d)
+    def benchmark(seq_len, provider):
+        batch_size, num_heads, emb_dim = 4, 32, 64
+        shape = (batch_size, num_heads, seq_len, emb_dim)
         dtype = torch.float16
         q = torch.randn(shape, dtype=dtype, device="cuda")
         k = torch.randn(shape, dtype=dtype, device="cuda")
         v = torch.randn(shape, dtype=dtype, device="cuda")

         if provider == "ninetoothed":
-            ms = triton.testing.do_bench(
-                lambda: attention(
-                    q.view(q.shape[-2:]), k.view(k.shape[-2:]), v.view(v.shape[-2:])
-                )
-            )
+            ms = triton.testing.do_bench(lambda: attention(q, k, v))
         elif provider == "torch":
             ms = triton.testing.do_bench(
                 lambda: F.scaled_dot_product_attention(q, k, v, scale=1)
             )
+        elif provider == "triton":
+            ms = triton.testing.do_bench(lambda: triton_attention(q, k, v))

         def perf(ms):
-            flops_per_matmul = 2 * n * n * d
+            flops_per_matmul = 2 * batch_size * num_heads * seq_len * seq_len * emb_dim
             total_flops = 2 * flops_per_matmul

             return total_flops * 1e-12 / (ms * 1e-3)
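For reference, the TFLOPS figure reported by `perf` counts the two matmuls per attention (q @ kᵀ and p @ v), each costing 2 * batch_size * num_heads * seq_len * seq_len * emb_dim FLOPs. A quick sketch with the benchmark's default sizes at seq_len = 4096; the 10 ms runtime is made up purely for illustration.

batch_size, num_heads, seq_len, emb_dim = 4, 32, 4096, 64   # benchmark defaults, seq_len = 2**12
flops_per_matmul = 2 * batch_size * num_heads * seq_len * seq_len * emb_dim
total_flops = 2 * flops_per_matmul                          # ≈ 5.5e11 FLOPs
ms = 10.0                                                   # hypothetical runtime in milliseconds
tflops = total_flops * 1e-12 / (ms * 1e-3)                  # ≈ 55 TFLOPS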
