
[NO MERGE] flash attention noslicing #144

Open · wants to merge 2 commits into main

136 changes: 64 additions & 72 deletions examples/flash_attention.py
@@ -2,12 +2,11 @@

import mlir.extras.types as T
import numpy as np
from hip import hip
from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr

from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContextModule
from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
from mlir.dialects import math

# noinspection PyUnresolvedReferences
from mlir.extras.dialects.ext.gpu import (
@@ -23,17 +22,15 @@
from mlir.extras.util import find_ops

# noinspection PyUnresolvedReferences
from util import hip_check, launch_kernel, hip_synchronize, hip_bindings_not_installed, get_hip_arch
from util import hip_check, launch_kernel, hip_synchronize


def init_copy_host_device(B, nh, N, d):
from hip import hip

q_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
k_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
v_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
l_h = np.zeros((B, nh, N), dtype=np.float32)
m_h = np.full((B, nh, N), float(np.finfo(np.float32).min), dtype=np.float32)
def init_copy_host_device():
q_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
k_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
v_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
l_h = np.zeros((B * nh * N), dtype=np.float32)
m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
O_h = np.zeros_like(q_h, dtype=np.float32)

host = [q_h, k_h, v_h, l_h, m_h, O_h]
@@ -50,8 +47,6 @@ def init_copy_host_device(B, nh, N, d):


def copy_device_host(host, device):
from hip import hip

for d, h in zip(device, host):
hip_check(
hip.hipMemcpy(
@@ -70,6 +65,10 @@ def copy_device_host(host, device):
ctx = RAIIMLIRContextModule()
set_container_module(ctx.module)

props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
arch = props.gcnArchName.decode()
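# gcnArchName is the device's full target string (e.g. "gfx90a:sramecc+:xnack-");
# it is passed to rocdl_attach_target below so the kernel is compiled for this GPU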


# just a default attr - actual target is set below
@module("kernels", [f'#rocdl.target<abi = "500">'])
@@ -88,7 +87,11 @@ def gpu_module():
N = 128
d = 128

softmax_scale = 1.0 / float(np.sqrt(d))
import math

Tc = math.ceil(N / Bc)
Tr = math.ceil(N / Br)
softmax_scale = 1.0 / math.sqrt(d)
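# Tc/Tr: number of Bc-wide K/V tiles and Br-wide Q tiles covering the sequence length N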


def softmax(x, axis=None):
@@ -98,80 +101,81 @@


def manual_attn(q, k, v):
att = q @ k.transpose(0, 1, 3, 2) * (1.0 / float(np.sqrt(k.shape[-1])))
# the kernel below shadows the global math module (via `from mlir.dialects import math`),
# so re-import Python's math locally
import math

q = q.reshape(B, nh, N, d)
k = k.reshape(B, nh, N, d)
v = v.reshape(B, nh, N, d)

att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))
att = softmax(att, axis=-1)
y = att @ v
return y
return y.flatten()
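# manual_attn: host-side O(N^2) reference attention; output flattened to compare
# directly against the kernel's 1D output buffer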


rank_reduce = memref.rank_reduce
from mlir.dialects import math


# https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
@gpu_func(emit=True)
@canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
def flash_attention(
Q: T.memref(B, nh, N, d, T.f32()),
K: T.memref(B, nh, N, d, T.f32()),
V: T.memref(B, nh, N, d, T.f32()),
l: T.memref(B, nh, N, T.f32()),
m: T.memref(B, nh, N, T.f32()),
O: T.memref(B, nh, N, d, T.f32()),
Q: T.memref(B * nh * N * d, T.f32()),
K: T.memref(B * nh * N * d, T.f32()),
V: T.memref(B * nh * N * d, T.f32()),
l: T.memref(B * nh * N, T.f32()),
m: T.memref(B * nh * N, T.f32()),
O: T.memref(B * nh * N * d, T.f32()),
):
tx = thread_idx.x
# batch idx, head_idx
bx, by = block_idx.x, block_idx.y
# gpu.printf("bx %ld, by %ld\n", bx, by)

# Offset into Q,K,V,O,l,m - different for each batch and head
K = K[bx, by, :, :, rank_reduce]
V = V[bx, by, :, :, rank_reduce]
Q = Q[bx, by, :, :, rank_reduce]
O = O[bx, by, :, :, rank_reduce]
l = l[bx, by, :, rank_reduce]
m = m[bx, by, :, rank_reduce]
qkv_offset = bx * nh * N * d + by * N * d
lm_offset = bx * nh * N + by * N # offset for l and m
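# flat layout: element (b, h, n, x) of Q/K/V/O lives at b*nh*N*d + h*N*d + n*d + x,
# so qkv_offset/lm_offset select this block's (batch, head) slab without any memref slicing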

# Define SRAM for Q,K,V,S
sram = gpu.dynamic_shared_memory()
Qi = memref.view(sram, (Br, d), dtype=T.f32())
Kj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements)
Vj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements)
Qi = memref.view(sram, (Br * d,), dtype=T.f32())
Kj = memref.view(sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements)
Vj = memref.view(
sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements
)
S = memref.view(
sram,
(Br, Bc),
(Br * Bc,),
dtype=T.f32(),
shift=Qi.n_elements + Kj.n_elements + Vj.n_elements,
)
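# SRAM layout: [Qi | Kj | Vj | S], carved out of one dynamic shared-memory
# allocation via shifted views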

for bc in scf.range_(0, N, Bc):
for j in scf.range_(0, Tc):
# Load Kj, Vj to SRAM
K_ = K[bc : bc + 1, :]
V_ = V[bc : bc + 1, :]
for x in scf.range_(0, d):
Kj[tx, x] = K_[tx, x]
Vj[tx, x] = V_[tx, x]
Kj[tx * d + x] = K[qkv_offset + Bc * d * j + tx * d + x]
Vj[tx * d + x] = V[qkv_offset + Bc * d * j + tx * d + x]
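# thread tx stages row tx of the j-th Bc x d K/V tile into shared memory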

for br in scf.range_(0, N, Br):
for i in scf.range_(0, Tr):
# Load Qi to SRAM, l and m to registers
Q_ = Q[br : br + 1, :]
for x in scf.range_(0, d):
Qi[tx, x] = Q_[tx, x]
ii = qkv_offset + Bc * d * i + tx * d + x
Qi[tx * d + x] = Q[ii]

l_ = l[br : br + 1]
m_ = m[br : br + 1]
row_l_prev = l_[tx]
row_m_prev = m_[tx]
row_m_prev = m[lm_offset + Br * i + tx]
row_l_prev = l[lm_offset + Br * i + tx]

# S = QK^T, row_m = rowmax(S)
row_m: T.f32() = float(np.finfo(np.float32).min)
for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]):
sum: T.f32() = 0.0
for x, sum, _ in scf.range_(0, d, iter_args=[sum]):
sum += Qi[tx, x] * Kj[y, x]
sum += Qi[tx * d + x] * Kj[y * d + x]
sum = yield sum

sum *= softmax_scale
S[tx, y] = sum
S[Bc * tx + y] = sum

if sum > row_m:
row_m_ = yield sum
@@ -183,8 +187,8 @@ def flash_attention(
# P = exp(S - row_m), row_l = rowsum(P)
row_l: T.f32() = 0.0
for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]):
S[tx, y] = math.exp(S[tx, y] - row_m)
row_l += S[tx, y]
S[Bc * tx + y] = math.exp(S[Bc * tx + y] - row_m)
row_l += S[Bc * tx + y]
row_l = yield row_l

# Compute new m and l
@@ -194,30 +198,26 @@ def flash_attention(
+ math.exp(row_m - row_m_new) * row_l
)
div = 1.0 / row_l_new
f1 = row_l_prev * math.exp(row_m_prev - row_m_new)
f2 = math.exp(row_m - row_m_new)
c = row_l_prev * math.exp(row_m_prev - row_m_new)
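# online-softmax update: row_m_new = max(row_m_prev, row_m) and
#   row_l_new = exp(row_m_prev - row_m_new) * row_l_prev + exp(row_m - row_m_new) * row_l;
# c rescales the previously accumulated (already normalized) output row, while
# exp(row_m - row_m_new) rescales this tile's contribution, before dividing by row_l_new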

# Write O, l, m to HBM
O_ = O[br : br + 1, :]
for x in scf.range_(0, d):
pv: T.f32() = 0.0 # Pij * Vj
for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
pv += S[tx, y] * Vj[y, x]
pv += S[Bc * tx + y] * Vj[y * d + x]
pv = yield pv

O_[tx, x] = div * (f1 * O_[tx, x] + f2 * pv)
ii = qkv_offset + Bc * d * i + tx * d + x
O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv)

l_[tx] = row_l_new
m_[tx] = row_m_new
m[lm_offset + Br * i + tx] = row_m_new
l[lm_offset + Br * i + tx] = row_l_new

gpu.barrier()


ip.__exit__(None, None, None)

assert gpu_module.operation.verify()
# print(gpu_module)

sram_size = 4 * Bc * d * np.float32().itemsize
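# four Bc*d-float slots: the Qi, Kj, Vj tiles plus room for S (Br*Bc floats,
# which fits in one Bc*d slot as long as Br <= d)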

launch_params = {
@@ -235,11 +235,9 @@
.cse()
.loop_invariant_code_motion()
.loop_invariant_subset_hoisting()
.rocdl_attach_target(chip=get_hip_arch(), O=3, abi="500"),
.rocdl_attach_target(chip=arch, O=3, abi="500"),
)

assert simplified_module.operation.verify()

# print(simplified_module)
# exit()

@@ -258,8 +256,6 @@ def flash_attention(
# .Nested("llvm.func", Pipeline().sroa()),
)

assert lowered_module.operation.verify()

# print(lowered_module)
gep = find_ops(lowered_module.operation, lambda o: isinstance(o.opview, llvm.GEPOp))
for g in gep:
@@ -276,10 +272,6 @@ def flash_attention(
T.index(), np.prod(thread_dims)
)

if hip_bindings_not_installed():
exit()
from hip import hip

output_format = "bin"
# output_format = "llvm"
# output_format = "isa"
@@ -289,11 +281,10 @@
)
hsaco = get_compile_object_bytes(lowered_module)
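# hsaco holds the serialized compile artifact: a loadable code object for "bin",
# or ISA/LLVM text for the other output formats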
if output_format in {"isa", "llvm", "offloading"}:
with open(Path(__file__).parent / f"flashattention.{output_format}", "wb") as f:
with open(Path(__file__).parent / "flashattention.amdgcn", "wb") as f:
f.write(hsaco)
exit()


hip_module = hip_check(hip.hipModuleLoadData(hsaco))

stream = 0
@@ -323,7 +314,7 @@ def flash_attention(
shared_memory,
) = launch_params[kernel.__name__]

host, device = init_copy_host_device(B, nh, N, d)
host, device = init_copy_host_device()
q_h, k_h, v_h, *_ = host
correct = manual_attn(q_h, k_h, v_h)

@@ -345,7 +336,8 @@ def flash_attention(
with np.printoptions(threshold=np.inf, linewidth=np.inf):
print(
"correct - output:\n",
correct.round() - O_h.round(),
correct.round().reshape(B, nh, N, d)
- O_h.round().reshape(B, nh, N, d),
)
print(f"{kernel.__name__} failed\n")
else:
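
---

For reference, a minimal host-side NumPy sketch (illustrative only, not part of this diff; the function and argument names are invented) of the online-softmax accumulation the kernel performs for one output row:

import numpy as np

def online_softmax_attention_row(q_row, K, V, Bc, scale):
    # q_row: (d,); K, V: (N, d) for one (batch, head); K/V processed in Bc-row tiles
    m = np.float32(np.finfo(np.float32).min)  # running row max (kernel's row_m_prev)
    l = np.float32(0.0)                       # running row sum of exp (row_l_prev)
    o = np.zeros_like(q_row)                  # running normalized output row
    for j in range(0, K.shape[0], Bc):
        s = (K[j:j + Bc] @ q_row) * scale     # one Bc-wide tile of scores
        m_new = max(m, s.max())
        p = np.exp(s - m_new)                 # exp(S - new running max)
        l_new = np.exp(m - m_new) * l + p.sum()
        # rescale the old contribution, fold in the new tile, renormalize
        o = (np.exp(m - m_new) * l * o + p @ V[j:j + Bc]) / l_new
        m, l = m_new, l_new
    return o

For any row, this matches softmax(q_row @ K.T * scale) @ V computed directly, up to floating-point error, without ever materializing the full N-wide score row.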