Commit 986f810

committed
works
1 parent 5a0d72f commit 986f810

File tree: 1 file changed, +102 -112 lines changed


examples/flash_attention.py

Lines changed: 102 additions & 112 deletions
@@ -1,8 +1,9 @@
+from pathlib import Path
+
 import mlir.extras.types as T
 import numpy as np
 from hip import hip
 from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr
-
 from mlir.extras.ast.canonicalize import canonicalize
 from mlir.extras.context import RAIIMLIRContextModule
 from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
@@ -23,6 +24,40 @@
 # noinspection PyUnresolvedReferences
 from util import hip_check, launch_kernel, hip_synchronize

+
+def init_copy_host_device():
+    q_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
+    k_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
+    v_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
+    l_h = np.zeros((B * nh * N), dtype=np.float32)
+    m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
+    O_h = np.zeros_like(q_h, dtype=np.float32)
+
+    host = [q_h, k_h, v_h, l_h, m_h, O_h]
+    device = [hip_check(hip.hipMalloc(h.size * h.itemsize)) for h in host]
+
+    for dev, h in zip(device, host):
+        hip_check(
+            hip.hipMemcpy(
+                dev, h, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyHostToDevice
+            )
+        )
+
+    return host, device
+
+
+def copy_device_host(host, device):
+    for d, h in zip(device, host):
+        hip_check(
+            hip.hipMemcpy(
+                h, d, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyDeviceToHost
+            )
+        )
+        hip_check(hip.hipFree(d))
+
+    return host
+
+
 # just so it doesn't get DCE'd by black/reformat
 # TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
 _ = memref
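
Note (illustrative, not from the commit): the two helpers above are meant to be used as a matched pair around a launch, the way the benchmark loop further down uses them; they read the module-level sizes B, nh, N, d as globals. A minimal usage sketch:

    # Allocate device buffers and copy q, k, v, l, m, O host -> device.
    host, device = init_copy_host_device()
    # ... launch the kernel with *device as its trailing arguments ...
    # Copy everything back device -> host and free the device buffers.
    q_h, k_h, v_h, l_h, m_h, O_h = copy_device_host(host, device)
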
@@ -44,25 +79,19 @@ def gpu_module():
 ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0])
 ip.__enter__()

-batch_size = 16
-n_head = 12
-seq_len = 64
-head_embd = 64
-
 Bc = 32
 Br = 32

-B = batch_size
-nh = n_head
-N = seq_len
-d = head_embd
+B = 16
+nh = 12
+N = 128
+d = 128

 import math

 Tc = math.ceil(N / Bc)
 Tr = math.ceil(N / Br)
 softmax_scale = 1.0 / math.sqrt(d)
-tile_size = Bc * d  # size of Qi, Kj, Vj


 def softmax(x, axis=None):
@@ -75,11 +104,11 @@ def manual_attn(q, k, v):
     # the kernel below overwrites the global math.........
     import math

-    q = q.reshape(batch_size, n_head, seq_len, head_embd)
-    k = k.reshape(batch_size, n_head, seq_len, head_embd)
-    v = v.reshape(batch_size, n_head, seq_len, head_embd)
+    q = q.reshape(B, nh, N, d)
+    k = k.reshape(B, nh, N, d)
+    v = v.reshape(B, nh, N, d)

-    att = q @ k.transpose(0, 1, -2, -1) * (1.0 / math.sqrt(k.shape[-1]))
+    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))
     att = softmax(att, axis=-1)
     y = att @ v
     return y.flatten()
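
Note (illustrative, not from the commit): the transpose change in manual_attn is a behavioral fix, not a style change. NumPy normalizes negative axes, so on a 4-D array transpose(0, 1, -2, -1) resolves to (0, 1, 2, 3), i.e. the identity permutation, whereas transpose(0, 1, 3, 2) actually swaps the last two axes to form K^T as the attention reference needs. A quick standalone check:

    import numpy as np

    x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
    assert np.array_equal(x.transpose(0, 1, -2, -1), x)   # old form: no-op
    assert x.transpose(0, 1, 3, 2).shape == (2, 3, 5, 4)  # new form: last two axes swapped
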
@@ -92,40 +121,46 @@ def manual_attn(q, k, v):
 @gpu_func(emit=True)
 @canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
 def flash_attention(
-    Q: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
-    K: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
-    V: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
+    Q: T.memref(B * nh * N * d, T.f32()),
+    K: T.memref(B * nh * N * d, T.f32()),
+    V: T.memref(B * nh * N * d, T.f32()),
     l: T.memref(B * nh * N, T.f32()),
     m: T.memref(B * nh * N, T.f32()),
-    O: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
+    O: T.memref(B * nh * N * d, T.f32()),
 ):
     tx = thread_idx.x
-    bx = block_idx.x
-    by = block_idx.y  # batch and head index
+    # batch idx, head_idx
+    bx, by = block_idx.x, block_idx.y
+    # gpu.printf("bx %ld, by %ld\n", bx, by)

     # Offset into Q,K,V,O,l,m - different for each batch and head
-    qkv_offset = bx * grid_dim.y * N * d + by * N * d  # gridDim.y = nh
-    lm_offset = bx * grid_dim.y * N + by * N  # offset for l and m
+    qkv_offset = bx * nh * N * d + by * N * d
+    lm_offset = bx * nh * N + by * N  # offset for l and m

     # Define SRAM for Q,K,V,S
     sram = gpu.dynamic_shared_memory()
-    Qi = memref.view(sram, (tile_size,), dtype=T.f32())
-    Kj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 1)
-    Vj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 2)
-    S = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 3)
+    Qi = memref.view(sram, (Br * d,), dtype=T.f32())
+    Kj = memref.view(sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements)
+    Vj = memref.view(
+        sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements
+    )
+    S = memref.view(
+        sram,
+        (Br * Bc,),
+        dtype=T.f32(),
+        shift=Qi.n_elements + Kj.n_elements + Vj.n_elements,
+    )

     for j in scf.range_(0, Tc):
         # Load Kj, Vj to SRAM
         for x in scf.range_(0, d):
-            Kj[tx * d + x] = K[qkv_offset + tile_size * j + tx * d + x]
-            Vj[tx * d + x] = V[qkv_offset + tile_size * j + tx * d + x]
-
-        gpu.barrier()  # such that the inner loop can use the correct Kj, Vj
+            Kj[tx * d + x] = K[qkv_offset + Bc * d * j + tx * d + x]
+            Vj[tx * d + x] = V[qkv_offset + Bc * d * j + tx * d + x]

         for i in scf.range_(0, Tr):
             # Load Qi to SRAM, l and m to registers
             for x in scf.range_(0, d):
-                ii = qkv_offset + tile_size * i + tx * d + x
+                ii = qkv_offset + Bc * d * i + tx * d + x
                 Qi[tx * d + x] = Q[ii]

             row_m_prev = m[lm_offset + Br * i + tx]
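
Note (worked out for reference, names are illustrative): with Br = Bc = 32 and d = 128, the shift= values above partition the dynamic shared memory as follows:

    # Shared-memory layout implied by the views above (f32 elements; bytes = 4x)
    Qi_elems = 32 * 128  # 4096, shift 0
    Kj_elems = 32 * 128  # 4096, shift 4096
    Vj_elems = 32 * 128  # 4096, shift 8192
    S_elems = 32 * 32    # 1024, shift 12288
    total_bytes = (Qi_elems + Kj_elems + Vj_elems + S_elems) * 4  # 53248
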
@@ -172,21 +207,18 @@ def flash_attention(
                     pv += S[Bc * tx + y] * Vj[y * d + x]
                     pv = yield pv

-                ii = qkv_offset + tile_size * i + tx * d + x
+                ii = qkv_offset + Bc * d * i + tx * d + x
                 O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv)

-            gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop
-
             m[lm_offset + Br * i + tx] = row_m_new
             l[lm_offset + Br * i + tx] = row_l_new

-        # gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop
-        # gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop
+        gpu.barrier()


 ip.__exit__(None, None, None)

-sram_size = 4 * tile_size * np.float32().itemsize
+sram_size = 4 * Bc * d * np.float32().itemsize

 launch_params = {
     flash_attention.__name__: (
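
Note (simple check, for reference): the sram_size request above stays at 4 * Bc * d floats even though the views are no longer four equal tiles; the allocation still covers them:

    sram_size = 4 * 32 * 128 * 4                # 65536 bytes (64 KiB) requested
    views_bytes = (3 * 32 * 128 + 32 * 32) * 4  # 53248 bytes actually viewed
    assert views_bytes <= sram_size
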
@@ -206,6 +238,9 @@ def flash_attention(
     .rocdl_attach_target(chip=arch, O=3, abi="500"),
 )

+# print(simplified_module)
+# exit()
+
 lowered_module = run_pipeline(
     simplified_module,
     Pipeline()
@@ -216,7 +251,8 @@ def flash_attention(
         )
     )
     .gpu_to_llvm()
-    .lower_to_llvm(),
+    .lower_to_llvm()
+    .ensure_debug_info_scope_on_llvm_func(emission_kind="Full"),
     # .Nested("llvm.func", Pipeline().sroa()),
 )

@@ -236,68 +272,34 @@ def flash_attention(
         T.index(), np.prod(thread_dims)
     )

-lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary())
+output_format = "bin"
+# output_format = "llvm"
+# output_format = "isa"
+
+lowered_module = run_pipeline(
+    lowered_module, Pipeline().gpu_module_to_binary(format=output_format)
+)
 hsaco = get_compile_object_bytes(lowered_module)
+if output_format in {"isa", "llvm", "offloading"}:
+    with open(Path(__file__).parent / "flashattention.amdgcn", "wb") as f:
+        f.write(hsaco)
+    exit()

 hip_module = hip_check(hip.hipModuleLoadData(hsaco))

-q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
-    dtype=np.float32
-)
-k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
-    dtype=np.float32
-)
-v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
-    dtype=np.float32
-)
-l_h = np.zeros((B * nh * N), dtype=np.float32)
-m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
-O_h = np.zeros_like(q_h, dtype=np.float32)
-
-q_num_bytes = q_h.size * q_h.itemsize
-k_num_bytes = k_h.size * k_h.itemsize
-v_num_bytes = v_h.size * v_h.itemsize
-l_num_bytes = l_h.size * l_h.itemsize
-m_num_bytes = m_h.size * m_h.itemsize
-O_num_bytes = O_h.size * O_h.itemsize
-
-q_d = hip_check(hip.hipMalloc(q_num_bytes))
-k_d = hip_check(hip.hipMalloc(k_num_bytes))
-v_d = hip_check(hip.hipMalloc(v_num_bytes))
-l_d = hip_check(hip.hipMalloc(l_num_bytes))
-m_d = hip_check(hip.hipMalloc(m_num_bytes))
-O_d = hip_check(hip.hipMalloc(O_num_bytes))
-
 stream = 0

 times = {
     flash_attention: 0,
 }
-# random.shuffle(kernels)
-runs = 16
+runs = 32
 for kernel in times:
     for i in range(runs):
         function = hip_check(
             hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
         )
         hip_check(hip.hipDeviceSynchronize())

-        for d, h, num_bytes in zip(
-            [q_d, k_d, v_d, l_d, m_d, O_d],
-            [q_h, k_h, v_h, l_h, m_h, O_h],
-            [
-                q_num_bytes,
-                k_num_bytes,
-                v_num_bytes,
-                l_num_bytes,
-                m_num_bytes,
-                O_num_bytes,
-            ],
-        ):
-            hip_check(
-                hip.hipMemcpy(d, h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
-            )
-
         (
             (
                 blocks_per_grid_x,
@@ -312,6 +314,10 @@ def flash_attention(
             shared_memory,
         ) = launch_params[kernel.__name__]

+        host, device = init_copy_host_device()
+        q_h, k_h, v_h, *_ = host
+        correct = manual_attn(q_h, k_h, v_h)
+
         time_compute = launch_kernel(
             function.as_c_void_p(),
             blocks_per_grid_x,
@@ -322,36 +328,20 @@ def flash_attention(
             threads_per_block_z,
             stream,
             shared_memory,
-            q_d,
-            k_d,
-            v_d,
-            l_d,
-            m_d,
-            O_d,
+            *device,
         )

-        hip_check(
-            hip.hipMemcpy(
-                l_h, l_d, l_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
-            )
-        )
-        hip_check(
-            hip.hipMemcpy(
-                m_h, m_d, m_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
-            )
-        )
-        hip_check(
-            hip.hipMemcpy(
-                O_h, O_d, O_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
-            )
-        )
-        correct = manual_attn(q_h, k_h, v_h)
+        *_, O_h = copy_device_host(host, device)
         if not np.allclose(correct, O_h):
-            print("correct", correct)
-            print("l_h", l_h)
-            print("m_h", m_h)
-            print("output", O_h)
-            print(f"{kernel.__name__} failed")
+            with np.printoptions(threshold=np.inf, linewidth=np.inf):
+                print(
+                    "correct - output:\n",
+                    correct.round().reshape(B, nh, N, d)
+                    - O_h.round().reshape(B, nh, N, d),
+                )
+            print(f"{kernel.__name__} failed\n")
+        else:
+            print(f"{kernel.__name__}: {time_compute:.03f}ms")

         times[kernel] += time_compute
