diff --git a/examples/flash_attention.py b/examples/flash_attention.py
index 9614739..e3f157d 100644
--- a/examples/flash_attention.py
+++ b/examples/flash_attention.py
@@ -2,12 +2,11 @@
 
 import mlir.extras.types as T
 import numpy as np
+from hip import hip
 from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr
-
 from mlir.extras.ast.canonicalize import canonicalize
 from mlir.extras.context import RAIIMLIRContextModule
 from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
-from mlir.dialects import math
 
 # noinspection PyUnresolvedReferences
 from mlir.extras.dialects.ext.gpu import (
@@ -23,17 +22,15 @@
 from mlir.extras.util import find_ops
 
 # noinspection PyUnresolvedReferences
-from util import hip_check, launch_kernel, hip_synchronize, hip_bindings_not_installed, get_hip_arch
+from util import hip_check, launch_kernel, hip_synchronize
 
 
-def init_copy_host_device(B, nh, N, d):
-    from hip import hip
-
-    q_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
-    k_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
-    v_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
-    l_h = np.zeros((B, nh, N), dtype=np.float32)
-    m_h = np.full((B, nh, N), float(np.finfo(np.float32).min), dtype=np.float32)
+def init_copy_host_device():
+    q_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
+    k_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
+    v_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
+    l_h = np.zeros((B * nh * N), dtype=np.float32)
+    m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
     O_h = np.zeros_like(q_h, dtype=np.float32)
 
     host = [q_h, k_h, v_h, l_h, m_h, O_h]
@@ -50,8 +47,6 @@ def init_copy_host_device(B, nh, N, d):
 
 
 def copy_device_host(host, device):
-    from hip import hip
-
     for d, h in zip(device, host):
         hip_check(
             hip.hipMemcpy(
@@ -70,6 +65,10 @@ def copy_device_host(host, device):
 ctx = RAIIMLIRContextModule()
 set_container_module(ctx.module)
 
+props = hip.hipDeviceProp_t()
+hip_check(hip.hipGetDeviceProperties(props, 0))
+arch = props.gcnArchName.decode()
+
 
 # just a default attr - actual target is set blow
 @module("kernels", [f'#rocdl.target'])
@@ -88,7 +87,11 @@ def gpu_module():
 N = 128
 d = 128
 
-softmax_scale = 1.0 / float(np.sqrt(d))
+import math
+
+Tc = math.ceil(N / Bc)
+Tr = math.ceil(N / Br)
+softmax_scale = 1.0 / math.sqrt(d)
 
 
 def softmax(x, axis=None):
@@ -98,25 +101,32 @@
 
 
 def manual_attn(q, k, v):
-    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / float(np.sqrt(k.shape[-1])))
+    # the mlir.dialects math import below shadows the stdlib math module
+    import math
+
+    q = q.reshape(B, nh, N, d)
+    k = k.reshape(B, nh, N, d)
+    v = v.reshape(B, nh, N, d)
+
+    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))
     att = softmax(att, axis=-1)
     y = att @ v
-    return y
+    return y.flatten()
 
 
-rank_reduce = memref.rank_reduce
+from mlir.dialects import math
 
 
 # https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
 @gpu_func(emit=True)
 @canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
 def flash_attention(
-    Q: T.memref(B, nh, N, d, T.f32()),
-    K: T.memref(B, nh, N, d, T.f32()),
-    V: T.memref(B, nh, N, d, T.f32()),
-    l: T.memref(B, nh, N, T.f32()),
-    m: T.memref(B, nh, N, T.f32()),
-    O: T.memref(B, nh, N, d, T.f32()),
+    Q: T.memref(B * nh * N * d, T.f32()),
+    K: T.memref(B * nh * N * d, T.f32()),
+    V: T.memref(B * nh * N * d, T.f32()),
+    l: T.memref(B * nh * N, T.f32()),
+    m: T.memref(B * nh * N, T.f32()),
+    O: T.memref(B * nh * N * d, T.f32()),
 ):
     tx = thread_idx.x
     # batch idx, head_idx
@@ -124,54 +134,48 @@ def flash_attention(
     # gpu.printf("bx %ld, by %ld\n", bx, by)
 
     # Offset into Q,K,V,O,l,m - different for each batch and head
-    K = K[bx, by, :, :, rank_reduce]
-    V = V[bx, by, :, :, rank_reduce]
-    Q = Q[bx, by, :, :, rank_reduce]
-    O = O[bx, by, :, :, rank_reduce]
-    l = l[bx, by, :, rank_reduce]
-    m = m[bx, by, :, rank_reduce]
+    qkv_offset = bx * nh * N * d + by * N * d
+    lm_offset = bx * nh * N + by * N  # offset for l and m
 
     # Define SRAM for Q,K,V,S
     sram = gpu.dynamic_shared_memory()
-    Qi = memref.view(sram, (Br, d), dtype=T.f32())
-    Kj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements)
-    Vj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements)
+    Qi = memref.view(sram, (Br * d,), dtype=T.f32())
+    Kj = memref.view(sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements)
+    Vj = memref.view(
+        sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements
+    )
     S = memref.view(
         sram,
-        (Br, Bc),
+        (Br * Bc,),
         dtype=T.f32(),
         shift=Qi.n_elements + Kj.n_elements + Vj.n_elements,
     )
 
-    for bc in scf.range_(0, N, Bc):
+    for j in scf.range_(0, Tc):
 
         # Load Kj, Vj to SRAM
-        K_ = K[bc : bc + 1, :]
-        V_ = V[bc : bc + 1, :]
         for x in scf.range_(0, d):
-            Kj[tx, x] = K_[tx, x]
-            Vj[tx, x] = V_[tx, x]
+            Kj[tx * d + x] = K[qkv_offset + Bc * d * j + tx * d + x]
+            Vj[tx * d + x] = V[qkv_offset + Bc * d * j + tx * d + x]
 
-        for br in scf.range_(0, N, Br):
+        for i in scf.range_(0, Tr):
             # Load Qi to SRAM, l and m to registers
-            Q_ = Q[br : br + 1, :]
             for x in scf.range_(0, d):
-                Qi[tx, x] = Q_[tx, x]
+                ii = qkv_offset + Bc * d * i + tx * d + x
+                Qi[tx * d + x] = Q[ii]
 
-            l_ = l[br : br + 1]
-            m_ = m[br : br + 1]
-            row_l_prev = l_[tx]
-            row_m_prev = m_[tx]
+            row_m_prev = m[lm_offset + Br * i + tx]
+            row_l_prev = l[lm_offset + Br * i + tx]
 
             # S = QK^T, row_m = rowmax(S)
             row_m: T.f32() = float(np.finfo(np.float32).min)
             for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]):
                 sum: T.f32() = 0.0
                 for x, sum, _ in scf.range_(0, d, iter_args=[sum]):
-                    sum += Qi[tx, x] * Kj[y, x]
+                    sum += Qi[tx * d + x] * Kj[y * d + x]
                     sum = yield sum
                 sum *= softmax_scale
-                S[tx, y] = sum
+                S[Bc * tx + y] = sum
 
                 if sum > row_m:
                     row_m_ = yield sum
@@ -183,8 +187,8 @@
             # P = exp(S - row_m), row_l = rowsum(P)
             row_l: T.f32() = 0.0
             for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]):
-                S[tx, y] = math.exp(S[tx, y] - row_m)
-                row_l += S[tx, y]
+                S[Bc * tx + y] = math.exp(S[Bc * tx + y] - row_m)
+                row_l += S[Bc * tx + y]
                 row_l = yield row_l
 
             # Compute new m and l
@@ -194,30 +198,26 @@ def flash_attention(
                 + math.exp(row_m - row_m_new) * row_l
             )
             div = 1.0 / row_l_new
-            f1 = row_l_prev * math.exp(row_m_prev - row_m_new)
-            f2 = math.exp(row_m - row_m_new)
+            c = row_l_prev * math.exp(row_m_prev - row_m_new)
 
             # Write O, l, m to HBM
-            O_ = O[br : br + 1, :]
             for x in scf.range_(0, d):
                 pv: T.f32() = 0.0
                 # Pij * Vj
                 for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
-                    pv += S[tx, y] * Vj[y, x]
+                    pv += S[Bc * tx + y] * Vj[y * d + x]
                     pv = yield pv
-                O_[tx, x] = div * (f1 * O_[tx, x] + f2 * pv)
+                ii = qkv_offset + Bc * d * i + tx * d + x
+                O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv)
 
-            l_[tx] = row_l_new
-            m_[tx] = row_m_new
+            m[lm_offset + Br * i + tx] = row_m_new
+            l[lm_offset + Br * i + tx] = row_l_new
 
         gpu.barrier()
 
 
 ip.__exit__(None, None, None)
 
-assert gpu_module.operation.verify()
-# print(gpu_module)
-
 sram_size = 4 * Bc * d * np.float32().itemsize
 
 launch_params = {
@@ -235,11 +235,9 @@ def flash_attention(
     .cse()
     .loop_invariant_code_motion()
     .loop_invariant_subset_hoisting()
-    .rocdl_attach_target(chip=get_hip_arch(), O=3, abi="500"),
+    .rocdl_attach_target(chip=arch, O=3, abi="500"),
 )
 
-assert simplified_module.operation.verify()
-
 # print(simplified_module)
 # exit()
 
@@ -258,8 +256,6 @@ def flash_attention(
     # .Nested("llvm.func", Pipeline().sroa()),
 )
 
-assert lowered_module.operation.verify()
-
 # print(lowered_module)
 gep = find_ops(lowered_module.operation, lambda o: isinstance(o.opview, llvm.GEPOp))
 for g in gep:
@@ -276,10 +272,6 @@ def flash_attention(
     T.index(), np.prod(thread_dims)
 )
 
-if hip_bindings_not_installed():
-    exit()
-from hip import hip
-
 output_format = "bin"
 # output_format = "llvm"
 # output_format = "isa"
@@ -289,11 +281,10 @@ def flash_attention(
 )
 hsaco = get_compile_object_bytes(lowered_module)
 if output_format in {"isa", "llvm", "offloading"}:
-    with open(Path(__file__).parent / f"flashattention.{output_format}", "wb") as f:
+    with open(Path(__file__).parent / "flashattention.amdgcn", "wb") as f:
         f.write(hsaco)
     exit()
 
-
 hip_module = hip_check(hip.hipModuleLoadData(hsaco))
 
 stream = 0
@@ -323,7 +314,7 @@ def flash_attention(
         shared_memory,
     ) = launch_params[kernel.__name__]
 
-    host, device = init_copy_host_device(B, nh, N, d)
+    host, device = init_copy_host_device()
     q_h, k_h, v_h, *_ = host
     correct = manual_attn(q_h, k_h, v_h)
 
@@ -345,7 +336,8 @@ def flash_attention(
        with np.printoptions(threshold=np.inf, linewidth=np.inf):
            print(
                "correct - output:\n",
-                correct.round() - O_h.round(),
+                correct.round().reshape(B, nh, N, d)
+                - O_h.round().reshape(B, nh, N, d),
            )
            print(f"{kernel.__name__} failed\n")
    else:
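
Note (not part of the patch): the change replaces 4-D memref indexing with flat 1-D buffers addressed via qkv_offset and lm_offset. The short NumPy sketch below, using arbitrary small sizes rather than the example's real B/nh/N/d, is one way to sanity-check that arithmetic on the host.

# illustrative sketch only -- sizes here are made up, not the example's
import numpy as np

B, nh, N, d = 2, 4, 16, 8
x4 = np.arange(B * nh * N * d, dtype=np.float32).reshape(B, nh, N, d)
x1 = x4.flatten()  # same row-major 1-D layout as the kernel's Q/K/V/O buffers

bx, by, row, col = 1, 3, 5, 7
qkv_offset = bx * nh * N * d + by * N * d  # per-(batch, head) offset into Q/K/V/O
lm_offset = bx * nh * N + by * N  # per-(batch, head) offset into l/m
assert x1[qkv_offset + row * d + col] == x4[bx, by, row, col]

lm = np.arange(B * nh * N, dtype=np.float32)
assert lm[lm_offset + row] == lm.reshape(B, nh, N)[bx, by, row]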