
Commit 4660cd5

flash attention
1 parent f6bff8f commit 4660cd5

5 files changed: +549 -90 lines changed

examples/flash_attention.py

Lines changed: 362 additions & 0 deletions
@@ -0,0 +1,362 @@
import mlir.extras.types as T
import numpy as np
from hip import hip
from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr

from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContextModule
from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm

# noinspection PyUnresolvedReferences
from mlir.extras.dialects.ext.gpu import (
    block_idx,
    thread_idx,
    grid_dim,
    func as gpu_func,
    set_container_module,
    module,
    get_compile_object_bytes,
)
from mlir.extras.runtime.passes import run_pipeline, Pipeline
from mlir.extras.util import find_ops

# noinspection PyUnresolvedReferences
from util import hip_check, launch_kernel, hip_synchronize

# just so it doesn't get DCE'd by black/reformat
# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
_ = memref

ctx = RAIIMLIRContextModule()
set_container_module(ctx.module)

props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
arch = props.gcnArchName.decode()


# just a default attr - the actual target is set below
@module("kernels", [f'#rocdl.target<abi = "500">'])
def gpu_module():
    pass


ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0])
ip.__enter__()

batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64

Bc = 32
Br = 32

B = batch_size
nh = n_head
N = seq_len
d = head_embd

import math

Tc = math.ceil(N / Bc)
Tr = math.ceil(N / Br)
softmax_scale = 1.0 / math.sqrt(d)
tile_size = Bc * d  # size of Qi, Kj, Vj
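# With these sizes: Tc = Tr = ceil(64 / 32) = 2 tiles per sequence,
# tile_size = 32 * 64 = 2048 floats per Qi/Kj/Vj/S buffer, so the kernel's
# dynamic shared memory request works out to 4 * 2048 * 4 bytes = 32 KiB
# (see sram_size below).

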
def softmax(x, axis=None):
    x_max = np.amax(x, axis=axis, keepdims=True)
    exp_x_shifted = np.exp(x - x_max)
    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)


def manual_attn(q, k, v):
    # the mlir.dialects import below shadows the global `math` module,
    # so re-import the stdlib one locally
    import math

    q = q.reshape(batch_size, n_head, seq_len, head_embd)
    k = k.reshape(batch_size, n_head, seq_len, head_embd)
    v = v.reshape(batch_size, n_head, seq_len, head_embd)

    att = q @ k.transpose(0, 1, -2, -1) * (1.0 / math.sqrt(k.shape[-1]))
    att = softmax(att, axis=-1)
    y = att @ v
    return y.flatten()


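# Host-side sketch (not used below) of the online-softmax recurrence the kernel
# implements, for reference; assumes q, k, v are (N, d) float32 arrays for a
# single (batch, head) slice, and the names here are illustrative only.
def online_attn_reference(q, k, v):
    O = np.zeros_like(q)
    l_acc = np.zeros(q.shape[0], dtype=q.dtype)
    m_acc = np.full(q.shape[0], np.finfo(q.dtype).min, dtype=q.dtype)
    for j in range(Tc):
        kj, vj = k[j * Bc : (j + 1) * Bc], v[j * Bc : (j + 1) * Bc]
        for i in range(Tr):
            rows = slice(i * Br, (i + 1) * Br)
            s = softmax_scale * (q[rows] @ kj.T)  # S = QK^T for this tile pair
            row_m = s.max(axis=-1)  # rowmax(S)
            p = np.exp(s - row_m[:, None])  # P = exp(S - rowmax(S))
            row_l = p.sum(axis=-1)  # rowsum(P)
            m_new = np.maximum(m_acc[rows], row_m)
            l_new = (
                np.exp(m_acc[rows] - m_new) * l_acc[rows]
                + np.exp(row_m - m_new) * row_l
            )
            # rescale the running output and fold in this tile's contribution
            O[rows] = (
                l_acc[rows][:, None] * np.exp(m_acc[rows] - m_new)[:, None] * O[rows]
                + np.exp(row_m - m_new)[:, None] * (p @ vj)
            ) / l_new[:, None]
            m_acc[rows], l_acc[rows] = m_new, l_new
    return O

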
from mlir.dialects import math


# https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
@gpu_func(emit=True)
@canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
def flash_attention(
    Q: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
    K: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
    V: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
    l: T.memref(B * nh * N, T.f32()),
    m: T.memref(B * nh * N, T.f32()),
    O: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
):
    tx = thread_idx.x
    bx = block_idx.x
    by = block_idx.y  # batch and head index

    # Offset into Q,K,V,O,l,m - different for each batch and head
    qkv_offset = bx * grid_dim.y * N * d + by * N * d  # gridDim.y = nh
    lm_offset = bx * grid_dim.y * N + by * N  # offset for l and m

    # Define SRAM for Q,K,V,S
    sram = gpu.dynamic_shared_memory()
    Qi = memref.view(sram, (tile_size,), dtype=T.f32())
    Kj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 1)
    Vj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 2)
    S = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 3)
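
    # SRAM layout: [Qi | Kj | Vj | S], one tile_size = Bc * d slot each. Note
    # that S only ever holds Br * Bc = 1024 scores (indexed as S[Bc * tx + y]),
    # so the fourth slot (and sram_size below) is slightly larger than needed.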
    for j in scf.range_(0, Tc):
        # Load Kj, Vj to SRAM
        for x in scf.range_(0, d):
            Kj[tx * d + x] = K[qkv_offset + tile_size * j + tx * d + x]
            Vj[tx * d + x] = V[qkv_offset + tile_size * j + tx * d + x]

        gpu.barrier()  # such that the inner loop can use the correct Kj, Vj

        for i in scf.range_(0, Tr):
            # Load Qi to SRAM, l and m to registers
            for x in scf.range_(0, d):
                ii = qkv_offset + tile_size * i + tx * d + x
                Qi[tx * d + x] = Q[ii]

            row_m_prev = m[lm_offset + Br * i + tx]
            row_l_prev = l[lm_offset + Br * i + tx]

            # S = QK^T, row_m = rowmax(S)
            row_m: T.f32() = float(np.finfo(np.float32).min)
            for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]):
                sum: T.f32() = 0.0
                for x, sum, _ in scf.range_(0, d, iter_args=[sum]):
                    sum += Qi[tx * d + x] * Kj[y * d + x]
                    sum = yield sum

                sum *= softmax_scale
                S[Bc * tx + y] = sum

                if sum > row_m:
                    row_m_ = yield sum
                else:
                    row_m_ = yield row_m

                row_m = yield row_m_

            # P = exp(S - row_m), row_l = rowsum(P)
            row_l: T.f32() = 0.0
            for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]):
                S[Bc * tx + y] = math.exp(S[Bc * tx + y] - row_m)
                row_l += S[Bc * tx + y]
                row_l = yield row_l

            # Compute new m and l
            row_m_new = arith.maximumf(row_m_prev, row_m)
            row_l_new = (
                math.exp(row_m_prev - row_m_new) * row_l_prev
                + math.exp(row_m - row_m_new) * row_l
            )
            div = 1.0 / row_l_new
            c = row_l_prev * math.exp(row_m_prev - row_m_new)

            # Write O, l, m to HBM
            for x in scf.range_(0, d):
                pv: T.f32() = 0.0  # Pij * Vj
                for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
                    pv += S[Bc * tx + y] * Vj[y * d + x]
                    pv = yield pv

                ii = qkv_offset + tile_size * i + tx * d + x
                O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv)

            gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop

            m[lm_offset + Br * i + tx] = row_m_new
            l[lm_offset + Br * i + tx] = row_l_new

        # gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop
        # gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop


ip.__exit__(None, None, None)

sram_size = 4 * tile_size * np.float32().itemsize

launch_params = {
    flash_attention.__name__: (
        (B, nh, 1),
        (Bc, 1, 1),
        sram_size,
    )
}
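
# Optional sanity check; assumes this hip-python build exposes
# sharedMemPerBlock on hipDeviceProp_t: the 32 KiB of dynamic LDS requested
# above has to fit in the device's per-block shared memory.
assert sram_size <= props.sharedMemPerBlock, (
    f"kernel needs {sram_size} bytes of dynamic shared memory, "
    f"device allows {props.sharedMemPerBlock}"
)
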
simplified_module = run_pipeline(
    ctx.module,
    Pipeline()
    .canonicalize()
    .cse()
    .loop_invariant_code_motion()
    .loop_invariant_subset_hoisting()
    .rocdl_attach_target(chip=arch, O=3, abi="500"),
)

lowered_module = run_pipeline(
    simplified_module,
    Pipeline()
    .Gpu(
        Pipeline().convert_gpu_to_rocdl(
            use_bare_ptr_memref_call_conv=True,
            runtime="HIP",
        )
    )
    .gpu_to_llvm()
    .lower_to_llvm(),
    # .Nested("llvm.func", Pipeline().sroa()),
)

# print(lowered_module)
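# Mark the generated GEPs as `inbounds` so LLVM can assume the pointer
# arithmetic stays within each allocation (the lowering above leaves the
# flag unset).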
gep = find_ops(lowered_module.operation, lambda o: isinstance(o.opview, llvm.GEPOp))
for g in gep:
    g.attributes["inbounds"] = UnitAttr.get()

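# Cap the flat workgroup size at the launch's threads-per-block (Bc * 1 * 1 = 32)
# so the backend can budget registers for small blocks; launching with a larger
# block than declared here would be invalid.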
kernel_funcs = find_ops(
    lowered_module.operation, lambda o: isinstance(o.opview, llvm.LLVMFuncOp)
)
for k in kernel_funcs:
    if k.sym_name.value != flash_attention.__name__:
        continue
    _, thread_dims, _ = launch_params[k.sym_name.value]
    k.attributes["rocdl.max_flat_work_group_size"] = IntegerAttr.get(
        T.index(), np.prod(thread_dims)
    )

lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary())
hsaco = get_compile_object_bytes(lowered_module)

hip_module = hip_check(hip.hipModuleLoadData(hsaco))

q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
    dtype=np.float32
)
k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
    dtype=np.float32
)
v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
    dtype=np.float32
)
l_h = np.zeros((B * nh * N), dtype=np.float32)
m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
O_h = np.zeros_like(q_h, dtype=np.float32)
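# l starts at 0 and m at the most negative float32, matching the kernel's
# initial running-sum / running-max values; O accumulates in place on device.
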
q_num_bytes = q_h.size * q_h.itemsize
k_num_bytes = k_h.size * k_h.itemsize
v_num_bytes = v_h.size * v_h.itemsize
l_num_bytes = l_h.size * l_h.itemsize
m_num_bytes = m_h.size * m_h.itemsize
O_num_bytes = O_h.size * O_h.itemsize

q_d = hip_check(hip.hipMalloc(q_num_bytes))
k_d = hip_check(hip.hipMalloc(k_num_bytes))
v_d = hip_check(hip.hipMalloc(v_num_bytes))
l_d = hip_check(hip.hipMalloc(l_num_bytes))
m_d = hip_check(hip.hipMalloc(m_num_bytes))
O_d = hip_check(hip.hipMalloc(O_num_bytes))

stream = 0

times = {
    flash_attention: 0,
}
# random.shuffle(kernels)
runs = 16
for kernel in times:
    for i in range(runs):
        function = hip_check(
            hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
        )
        hip_check(hip.hipDeviceSynchronize())

        for d, h, num_bytes in zip(
            [q_d, k_d, v_d, l_d, m_d, O_d],
            [q_h, k_h, v_h, l_h, m_h, O_h],
            [
                q_num_bytes,
                k_num_bytes,
                v_num_bytes,
                l_num_bytes,
                m_num_bytes,
                O_num_bytes,
            ],
        ):
            hip_check(
                hip.hipMemcpy(d, h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
            )

        (
            (
                blocks_per_grid_x,
                blocks_per_grid_y,
                blocks_per_grid_z,
            ),
            (
                threads_per_block_x,
                threads_per_block_y,
                threads_per_block_z,
            ),
            shared_memory,
        ) = launch_params[kernel.__name__]

        time_compute = launch_kernel(
            function.as_c_void_p(),
            blocks_per_grid_x,
            blocks_per_grid_y,
            blocks_per_grid_z,
            threads_per_block_x,
            threads_per_block_y,
            threads_per_block_z,
            stream,
            shared_memory,
            q_d,
            k_d,
            v_d,
            l_d,
            m_d,
            O_d,
        )

        hip_check(
            hip.hipMemcpy(
                l_h, l_d, l_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
            )
        )
        hip_check(
            hip.hipMemcpy(
                m_h, m_d, m_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
            )
        )
        hip_check(
            hip.hipMemcpy(
                O_h, O_d, O_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
            )
        )
        correct = manual_attn(q_h, k_h, v_h)
        if not np.allclose(correct, O_h):
            print("correct", correct)
            print("l_h", l_h)
            print("m_h", m_h)
            print("output", O_h)
            print(f"{kernel.__name__} failed")

        times[kernel] += time_compute

for k in times:
    times[k] /= runs

for k, v in times.items():
    print(f"{k.__name__}: {v:.03f}ms")
