From 5a0d72f5e2ad4792c3e437f2e6aad2130dbc7602 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 2 May 2025 01:20:32 -0400 Subject: [PATCH 1/2] flash attention --- examples/flash_attention.py | 288 ++++++++++++++++++------------------ 1 file changed, 145 insertions(+), 143 deletions(-) diff --git a/examples/flash_attention.py b/examples/flash_attention.py index 9614739..625c4eb 100644 --- a/examples/flash_attention.py +++ b/examples/flash_attention.py @@ -1,13 +1,11 @@ -from pathlib import Path - import mlir.extras.types as T import numpy as np +from hip import hip from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr from mlir.extras.ast.canonicalize import canonicalize from mlir.extras.context import RAIIMLIRContextModule from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm -from mlir.dialects import math # noinspection PyUnresolvedReferences from mlir.extras.dialects.ext.gpu import ( @@ -23,45 +21,7 @@ from mlir.extras.util import find_ops # noinspection PyUnresolvedReferences -from util import hip_check, launch_kernel, hip_synchronize, hip_bindings_not_installed, get_hip_arch - - -def init_copy_host_device(B, nh, N, d): - from hip import hip - - q_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32) - k_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32) - v_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32) - l_h = np.zeros((B, nh, N), dtype=np.float32) - m_h = np.full((B, nh, N), float(np.finfo(np.float32).min), dtype=np.float32) - O_h = np.zeros_like(q_h, dtype=np.float32) - - host = [q_h, k_h, v_h, l_h, m_h, O_h] - device = [hip_check(hip.hipMalloc(h.size * h.itemsize)) for h in host] - - for dev, h in zip(device, host): - hip_check( - hip.hipMemcpy( - dev, h, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyHostToDevice - ) - ) - - return host, device - - -def copy_device_host(host, device): - from hip import hip - - for d, h in zip(device, host): - hip_check( - hip.hipMemcpy( - h, d, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyDeviceToHost - ) - ) - hip_check(hip.hipFree(d)) - - return host - +from util import hip_check, launch_kernel, hip_synchronize # just so it doesn't get DCE'd by black/reformat # TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable @@ -70,6 +30,10 @@ def copy_device_host(host, device): ctx = RAIIMLIRContextModule() set_container_module(ctx.module) +props = hip.hipDeviceProp_t() +hip_check(hip.hipGetDeviceProperties(props, 0)) +arch = props.gcnArchName.decode() + # just a default attr - actual target is set blow @module("kernels", [f'#rocdl.target']) @@ -80,15 +44,25 @@ def gpu_module(): ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0]) ip.__enter__() +batch_size = 16 +n_head = 12 +seq_len = 64 +head_embd = 64 + Bc = 32 Br = 32 -B = 16 -nh = 12 -N = 128 -d = 128 +B = batch_size +nh = n_head +N = seq_len +d = head_embd -softmax_scale = 1.0 / float(np.sqrt(d)) +import math + +Tc = math.ceil(N / Bc) +Tr = math.ceil(N / Br) +softmax_scale = 1.0 / math.sqrt(d) +tile_size = Bc * d # size of Qi, Kj, Vj def softmax(x, axis=None): @@ -98,80 +72,75 @@ def softmax(x, axis=None): def manual_attn(q, k, v): - att = q @ k.transpose(0, 1, 3, 2) * (1.0 / float(np.sqrt(k.shape[-1]))) + # the kernel below overwrites the global math......... 
+ import math + + q = q.reshape(batch_size, n_head, seq_len, head_embd) + k = k.reshape(batch_size, n_head, seq_len, head_embd) + v = v.reshape(batch_size, n_head, seq_len, head_embd) + + att = q @ k.transpose(0, 1, -2, -1) * (1.0 / math.sqrt(k.shape[-1])) att = softmax(att, axis=-1) y = att @ v - return y + return y.flatten() -rank_reduce = memref.rank_reduce +from mlir.dialects import math # https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu @gpu_func(emit=True) @canonicalize(using=[scf.canonicalizer, arith.canonicalizer]) def flash_attention( - Q: T.memref(B, nh, N, d, T.f32()), - K: T.memref(B, nh, N, d, T.f32()), - V: T.memref(B, nh, N, d, T.f32()), - l: T.memref(B, nh, N, T.f32()), - m: T.memref(B, nh, N, T.f32()), - O: T.memref(B, nh, N, d, T.f32()), + Q: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), + K: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), + V: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), + l: T.memref(B * nh * N, T.f32()), + m: T.memref(B * nh * N, T.f32()), + O: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), ): tx = thread_idx.x - # batch idx, head_idx - bx, by = block_idx.x, block_idx.y - # gpu.printf("bx %ld, by %ld\n", bx, by) + bx = block_idx.x + by = block_idx.y # batch and head index # Offset into Q,K,V,O,l,m - different for each batch and head - K = K[bx, by, :, :, rank_reduce] - V = V[bx, by, :, :, rank_reduce] - Q = Q[bx, by, :, :, rank_reduce] - O = O[bx, by, :, :, rank_reduce] - l = l[bx, by, :, rank_reduce] - m = m[bx, by, :, rank_reduce] + qkv_offset = bx * grid_dim.y * N * d + by * N * d # gridDim.y = nh + lm_offset = bx * grid_dim.y * N + by * N # offset for l and m # Define SRAM for Q,K,V,S sram = gpu.dynamic_shared_memory() - Qi = memref.view(sram, (Br, d), dtype=T.f32()) - Kj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements) - Vj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements) - S = memref.view( - sram, - (Br, Bc), - dtype=T.f32(), - shift=Qi.n_elements + Kj.n_elements + Vj.n_elements, - ) + Qi = memref.view(sram, (tile_size,), dtype=T.f32()) + Kj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 1) + Vj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 2) + S = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 3) - for bc in scf.range_(0, N, Bc): + for j in scf.range_(0, Tc): # Load Kj, Vj to SRAM - K_ = K[bc : bc + 1, :] - V_ = V[bc : bc + 1, :] for x in scf.range_(0, d): - Kj[tx, x] = K_[tx, x] - Vj[tx, x] = V_[tx, x] + Kj[tx * d + x] = K[qkv_offset + tile_size * j + tx * d + x] + Vj[tx * d + x] = V[qkv_offset + tile_size * j + tx * d + x] + + gpu.barrier() # such that the inner loop can use the correct Kj, Vj - for br in scf.range_(0, N, Br): + for i in scf.range_(0, Tr): # Load Qi to SRAM, l and m to registers - Q_ = Q[br : br + 1, :] for x in scf.range_(0, d): - Qi[tx, x] = Q_[tx, x] + ii = qkv_offset + tile_size * i + tx * d + x + Qi[tx * d + x] = Q[ii] - l_ = l[br : br + 1] - m_ = m[br : br + 1] - row_l_prev = l_[tx] - row_m_prev = m_[tx] + row_m_prev = m[lm_offset + Br * i + tx] + row_l_prev = l[lm_offset + Br * i + tx] # S = QK^T, row_m = rowmax(S) row_m: T.f32() = float(np.finfo(np.float32).min) for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]): sum: T.f32() = 0.0 for x, sum, _ in scf.range_(0, d, iter_args=[sum]): - sum += Qi[tx, x] * Kj[y, x] + sum += Qi[tx * d + x] * Kj[y * d + x] sum = yield sum sum *= softmax_scale - S[tx, y] = sum + S[Bc * tx + y] = 
sum if sum > row_m: row_m_ = yield sum @@ -183,8 +152,8 @@ def flash_attention( # P = exp(S - row_m), row_l = rowsum(P) row_l: T.f32() = 0.0 for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]): - S[tx, y] = math.exp(S[tx, y] - row_m) - row_l += S[tx, y] + S[Bc * tx + y] = math.exp(S[Bc * tx + y] - row_m) + row_l += S[Bc * tx + y] row_l = yield row_l # Compute new m and l @@ -194,31 +163,30 @@ def flash_attention( + math.exp(row_m - row_m_new) * row_l ) div = 1.0 / row_l_new - f1 = row_l_prev * math.exp(row_m_prev - row_m_new) - f2 = math.exp(row_m - row_m_new) + c = row_l_prev * math.exp(row_m_prev - row_m_new) # Write O, l, m to HBM - O_ = O[br : br + 1, :] for x in scf.range_(0, d): pv: T.f32() = 0.0 # Pij * Vj for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]): - pv += S[tx, y] * Vj[y, x] + pv += S[Bc * tx + y] * Vj[y * d + x] pv = yield pv - O_[tx, x] = div * (f1 * O_[tx, x] + f2 * pv) + ii = qkv_offset + tile_size * i + tx * d + x + O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv) - l_[tx] = row_l_new - m_[tx] = row_m_new + gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop - gpu.barrier() + m[lm_offset + Br * i + tx] = row_m_new + l[lm_offset + Br * i + tx] = row_l_new + # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop + # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop -ip.__exit__(None, None, None) -assert gpu_module.operation.verify() -# print(gpu_module) +ip.__exit__(None, None, None) -sram_size = 4 * Bc * d * np.float32().itemsize +sram_size = 4 * tile_size * np.float32().itemsize launch_params = { flash_attention.__name__: ( @@ -235,14 +203,9 @@ def flash_attention( .cse() .loop_invariant_code_motion() .loop_invariant_subset_hoisting() - .rocdl_attach_target(chip=get_hip_arch(), O=3, abi="500"), + .rocdl_attach_target(chip=arch, O=3, abi="500"), ) -assert simplified_module.operation.verify() - -# print(simplified_module) -# exit() - lowered_module = run_pipeline( simplified_module, Pipeline() @@ -253,13 +216,10 @@ def flash_attention( ) ) .gpu_to_llvm() - .lower_to_llvm() - .ensure_debug_info_scope_on_llvm_func(emission_kind="Full"), + .lower_to_llvm(), # .Nested("llvm.func", Pipeline().sroa()), ) -assert lowered_module.operation.verify() - # print(lowered_module) gep = find_ops(lowered_module.operation, lambda o: isinstance(o.opview, llvm.GEPOp)) for g in gep: @@ -276,32 +236,45 @@ def flash_attention( T.index(), np.prod(thread_dims) ) -if hip_bindings_not_installed(): - exit() -from hip import hip - -output_format = "bin" -# output_format = "llvm" -# output_format = "isa" - -lowered_module = run_pipeline( - lowered_module, Pipeline().gpu_module_to_binary(format=output_format) -) +lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary()) hsaco = get_compile_object_bytes(lowered_module) -if output_format in {"isa", "llvm", "offloading"}: - with open(Path(__file__).parent / f"flashattention.{output_format}", "wb") as f: - f.write(hsaco) - exit() - hip_module = hip_check(hip.hipModuleLoadData(hsaco)) +q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype( + dtype=np.float32 +) +k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype( + dtype=np.float32 +) +v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype( + dtype=np.float32 +) +l_h = np.zeros((B * nh * N), dtype=np.float32) +m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32) +O_h = np.zeros_like(q_h, 
dtype=np.float32) + +q_num_bytes = q_h.size * q_h.itemsize +k_num_bytes = k_h.size * k_h.itemsize +v_num_bytes = v_h.size * v_h.itemsize +l_num_bytes = l_h.size * l_h.itemsize +m_num_bytes = m_h.size * m_h.itemsize +O_num_bytes = O_h.size * O_h.itemsize + +q_d = hip_check(hip.hipMalloc(q_num_bytes)) +k_d = hip_check(hip.hipMalloc(k_num_bytes)) +v_d = hip_check(hip.hipMalloc(v_num_bytes)) +l_d = hip_check(hip.hipMalloc(l_num_bytes)) +m_d = hip_check(hip.hipMalloc(m_num_bytes)) +O_d = hip_check(hip.hipMalloc(O_num_bytes)) + stream = 0 times = { flash_attention: 0, } -runs = 32 +# random.shuffle(kernels) +runs = 16 for kernel in times: for i in range(runs): function = hip_check( @@ -309,6 +282,22 @@ def flash_attention( ) hip_check(hip.hipDeviceSynchronize()) + for d, h, num_bytes in zip( + [q_d, k_d, v_d, l_d, m_d, O_d], + [q_h, k_h, v_h, l_h, m_h, O_h], + [ + q_num_bytes, + k_num_bytes, + v_num_bytes, + l_num_bytes, + m_num_bytes, + O_num_bytes, + ], + ): + hip_check( + hip.hipMemcpy(d, h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice) + ) + ( ( blocks_per_grid_x, @@ -323,10 +312,6 @@ def flash_attention( shared_memory, ) = launch_params[kernel.__name__] - host, device = init_copy_host_device(B, nh, N, d) - q_h, k_h, v_h, *_ = host - correct = manual_attn(q_h, k_h, v_h) - time_compute = launch_kernel( function.as_c_void_p(), blocks_per_grid_x, @@ -337,19 +322,36 @@ def flash_attention( threads_per_block_z, stream, shared_memory, - *device, + q_d, + k_d, + v_d, + l_d, + m_d, + O_d, ) - *_, O_h = copy_device_host(host, device) + hip_check( + hip.hipMemcpy( + l_h, l_d, l_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost + ) + ) + hip_check( + hip.hipMemcpy( + m_h, m_d, m_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost + ) + ) + hip_check( + hip.hipMemcpy( + O_h, O_d, O_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost + ) + ) + correct = manual_attn(q_h, k_h, v_h) if not np.allclose(correct, O_h): - with np.printoptions(threshold=np.inf, linewidth=np.inf): - print( - "correct - output:\n", - correct.round() - O_h.round(), - ) - print(f"{kernel.__name__} failed\n") - else: - print(f"{kernel.__name__}: {time_compute:.03f}ms") + print("correct", correct) + print("l_h", l_h) + print("m_h", m_h) + print("output", O_h) + print(f"{kernel.__name__} failed") times[kernel] += time_compute From 986f8104744043b41731597ca86d180317fd546e Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 2 May 2025 13:59:32 -0400 Subject: [PATCH 2/2] works --- examples/flash_attention.py | 214 +++++++++++++++++------------------- 1 file changed, 102 insertions(+), 112 deletions(-) diff --git a/examples/flash_attention.py b/examples/flash_attention.py index 625c4eb..e3f157d 100644 --- a/examples/flash_attention.py +++ b/examples/flash_attention.py @@ -1,8 +1,9 @@ +from pathlib import Path + import mlir.extras.types as T import numpy as np from hip import hip from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr - from mlir.extras.ast.canonicalize import canonicalize from mlir.extras.context import RAIIMLIRContextModule from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm @@ -23,6 +24,40 @@ # noinspection PyUnresolvedReferences from util import hip_check, launch_kernel, hip_synchronize + +def init_copy_host_device(): + q_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32) + k_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32) + v_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32) + l_h = np.zeros((B * nh * N), dtype=np.float32) + 
m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32) + O_h = np.zeros_like(q_h, dtype=np.float32) + + host = [q_h, k_h, v_h, l_h, m_h, O_h] + device = [hip_check(hip.hipMalloc(h.size * h.itemsize)) for h in host] + + for dev, h in zip(device, host): + hip_check( + hip.hipMemcpy( + dev, h, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyHostToDevice + ) + ) + + return host, device + + +def copy_device_host(host, device): + for d, h in zip(device, host): + hip_check( + hip.hipMemcpy( + h, d, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyDeviceToHost + ) + ) + hip_check(hip.hipFree(d)) + + return host + + # just so it doesn't get DCE'd by black/reformat # TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable _ = memref @@ -44,25 +79,19 @@ def gpu_module(): ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0]) ip.__enter__() -batch_size = 16 -n_head = 12 -seq_len = 64 -head_embd = 64 - Bc = 32 Br = 32 -B = batch_size -nh = n_head -N = seq_len -d = head_embd +B = 16 +nh = 12 +N = 128 +d = 128 import math Tc = math.ceil(N / Bc) Tr = math.ceil(N / Br) softmax_scale = 1.0 / math.sqrt(d) -tile_size = Bc * d # size of Qi, Kj, Vj def softmax(x, axis=None): @@ -75,11 +104,11 @@ def manual_attn(q, k, v): # the kernel below overwrites the global math......... import math - q = q.reshape(batch_size, n_head, seq_len, head_embd) - k = k.reshape(batch_size, n_head, seq_len, head_embd) - v = v.reshape(batch_size, n_head, seq_len, head_embd) + q = q.reshape(B, nh, N, d) + k = k.reshape(B, nh, N, d) + v = v.reshape(B, nh, N, d) - att = q @ k.transpose(0, 1, -2, -1) * (1.0 / math.sqrt(k.shape[-1])) + att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1])) att = softmax(att, axis=-1) y = att @ v return y.flatten() @@ -92,40 +121,46 @@ def manual_attn(q, k, v): @gpu_func(emit=True) @canonicalize(using=[scf.canonicalizer, arith.canonicalizer]) def flash_attention( - Q: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), - K: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), - V: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), + Q: T.memref(B * nh * N * d, T.f32()), + K: T.memref(B * nh * N * d, T.f32()), + V: T.memref(B * nh * N * d, T.f32()), l: T.memref(B * nh * N, T.f32()), m: T.memref(B * nh * N, T.f32()), - O: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), + O: T.memref(B * nh * N * d, T.f32()), ): tx = thread_idx.x - bx = block_idx.x - by = block_idx.y # batch and head index + # batch idx, head_idx + bx, by = block_idx.x, block_idx.y + # gpu.printf("bx %ld, by %ld\n", bx, by) # Offset into Q,K,V,O,l,m - different for each batch and head - qkv_offset = bx * grid_dim.y * N * d + by * N * d # gridDim.y = nh - lm_offset = bx * grid_dim.y * N + by * N # offset for l and m + qkv_offset = bx * nh * N * d + by * N * d + lm_offset = bx * nh * N + by * N # offset for l and m # Define SRAM for Q,K,V,S sram = gpu.dynamic_shared_memory() - Qi = memref.view(sram, (tile_size,), dtype=T.f32()) - Kj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 1) - Vj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 2) - S = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 3) + Qi = memref.view(sram, (Br * d,), dtype=T.f32()) + Kj = memref.view(sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements) + Vj = memref.view( + sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements + ) + S = memref.view( + sram, + (Br * Bc,), + dtype=T.f32(), + 
shift=Qi.n_elements + Kj.n_elements + Vj.n_elements, + ) for j in scf.range_(0, Tc): # Load Kj, Vj to SRAM for x in scf.range_(0, d): - Kj[tx * d + x] = K[qkv_offset + tile_size * j + tx * d + x] - Vj[tx * d + x] = V[qkv_offset + tile_size * j + tx * d + x] - - gpu.barrier() # such that the inner loop can use the correct Kj, Vj + Kj[tx * d + x] = K[qkv_offset + Bc * d * j + tx * d + x] + Vj[tx * d + x] = V[qkv_offset + Bc * d * j + tx * d + x] for i in scf.range_(0, Tr): # Load Qi to SRAM, l and m to registers for x in scf.range_(0, d): - ii = qkv_offset + tile_size * i + tx * d + x + ii = qkv_offset + Bc * d * i + tx * d + x Qi[tx * d + x] = Q[ii] row_m_prev = m[lm_offset + Br * i + tx] @@ -172,21 +207,18 @@ def flash_attention( pv += S[Bc * tx + y] * Vj[y * d + x] pv = yield pv - ii = qkv_offset + tile_size * i + tx * d + x + ii = qkv_offset + Bc * d * i + tx * d + x O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv) - gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop - m[lm_offset + Br * i + tx] = row_m_new l[lm_offset + Br * i + tx] = row_l_new - # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop - # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop + gpu.barrier() ip.__exit__(None, None, None) -sram_size = 4 * tile_size * np.float32().itemsize +sram_size = 4 * Bc * d * np.float32().itemsize launch_params = { flash_attention.__name__: ( @@ -206,6 +238,9 @@ def flash_attention( .rocdl_attach_target(chip=arch, O=3, abi="500"), ) +# print(simplified_module) +# exit() + lowered_module = run_pipeline( simplified_module, Pipeline() @@ -216,7 +251,8 @@ def flash_attention( ) ) .gpu_to_llvm() - .lower_to_llvm(), + .lower_to_llvm() + .ensure_debug_info_scope_on_llvm_func(emission_kind="Full"), # .Nested("llvm.func", Pipeline().sroa()), ) @@ -236,45 +272,27 @@ def flash_attention( T.index(), np.prod(thread_dims) ) -lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary()) +output_format = "bin" +# output_format = "llvm" +# output_format = "isa" + +lowered_module = run_pipeline( + lowered_module, Pipeline().gpu_module_to_binary(format=output_format) +) hsaco = get_compile_object_bytes(lowered_module) +if output_format in {"isa", "llvm", "offloading"}: + with open(Path(__file__).parent / "flashattention.amdgcn", "wb") as f: + f.write(hsaco) + exit() hip_module = hip_check(hip.hipModuleLoadData(hsaco)) -q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype( - dtype=np.float32 -) -k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype( - dtype=np.float32 -) -v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype( - dtype=np.float32 -) -l_h = np.zeros((B * nh * N), dtype=np.float32) -m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32) -O_h = np.zeros_like(q_h, dtype=np.float32) - -q_num_bytes = q_h.size * q_h.itemsize -k_num_bytes = k_h.size * k_h.itemsize -v_num_bytes = v_h.size * v_h.itemsize -l_num_bytes = l_h.size * l_h.itemsize -m_num_bytes = m_h.size * m_h.itemsize -O_num_bytes = O_h.size * O_h.itemsize - -q_d = hip_check(hip.hipMalloc(q_num_bytes)) -k_d = hip_check(hip.hipMalloc(k_num_bytes)) -v_d = hip_check(hip.hipMalloc(v_num_bytes)) -l_d = hip_check(hip.hipMalloc(l_num_bytes)) -m_d = hip_check(hip.hipMalloc(m_num_bytes)) -O_d = hip_check(hip.hipMalloc(O_num_bytes)) - stream = 0 times = { flash_attention: 0, } -# random.shuffle(kernels) -runs = 16 +runs = 32 for kernel in times: 
for i in range(runs): function = hip_check( @@ -282,22 +300,6 @@ def flash_attention( ) hip_check(hip.hipDeviceSynchronize()) - for d, h, num_bytes in zip( - [q_d, k_d, v_d, l_d, m_d, O_d], - [q_h, k_h, v_h, l_h, m_h, O_h], - [ - q_num_bytes, - k_num_bytes, - v_num_bytes, - l_num_bytes, - m_num_bytes, - O_num_bytes, - ], - ): - hip_check( - hip.hipMemcpy(d, h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice) - ) - ( ( blocks_per_grid_x, @@ -312,6 +314,10 @@ def flash_attention( shared_memory, ) = launch_params[kernel.__name__] + host, device = init_copy_host_device() + q_h, k_h, v_h, *_ = host + correct = manual_attn(q_h, k_h, v_h) + time_compute = launch_kernel( function.as_c_void_p(), blocks_per_grid_x, @@ -322,36 +328,20 @@ def flash_attention( threads_per_block_z, stream, shared_memory, - q_d, - k_d, - v_d, - l_d, - m_d, - O_d, + *device, ) - hip_check( - hip.hipMemcpy( - l_h, l_d, l_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost - ) - ) - hip_check( - hip.hipMemcpy( - m_h, m_d, m_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost - ) - ) - hip_check( - hip.hipMemcpy( - O_h, O_d, O_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost - ) - ) - correct = manual_attn(q_h, k_h, v_h) + *_, O_h = copy_device_host(host, device) if not np.allclose(correct, O_h): - print("correct", correct) - print("l_h", l_h) - print("m_h", m_h) - print("output", O_h) - print(f"{kernel.__name__} failed") + with np.printoptions(threshold=np.inf, linewidth=np.inf): + print( + "correct - output:\n", + correct.round().reshape(B, nh, N, d) + - O_h.round().reshape(B, nh, N, d), + ) + print(f"{kernel.__name__} failed\n") + else: + print(f"{kernel.__name__}: {time_compute:.03f}ms") times[kernel] += time_compute
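
Note: the kernel above closely follows flash-attention-minimal's flash.cu (linked in the source comment), so a convenient way to sanity-check the indexing is to run the same tiled, online-softmax update on the host. Below is a minimal NumPy sketch of that update, assuming the flattened (B * nh * N * d) layout and the B=16, nh=12, N=128, d=128, Br=Bc=32 sizes used in the final version of the example, and that N divides evenly by the tile sizes. The helper name flash_attention_reference and its keyword defaults are made up here for illustration and are not part of either patch.

import math

import numpy as np


def flash_attention_reference(q, k, v, B=16, nh=12, N=128, d=128, Br=32, Bc=32):
    # q, k, v arrive flattened, exactly as they are handed to the GPU kernel.
    q = q.reshape(B, nh, N, d).astype(np.float32)
    k = k.reshape(B, nh, N, d).astype(np.float32)
    v = v.reshape(B, nh, N, d).astype(np.float32)

    O = np.zeros_like(q)
    l = np.zeros((B, nh, N), dtype=np.float32)  # running row sums
    m = np.full((B, nh, N), np.finfo(np.float32).min, dtype=np.float32)  # running row maxes
    scale = 1.0 / math.sqrt(d)
    Tc, Tr = math.ceil(N / Bc), math.ceil(N / Br)

    for b in range(B):
        for h in range(nh):
            for j in range(Tc):  # outer loop: K/V tiles
                Kj = k[b, h, j * Bc : (j + 1) * Bc]  # (Bc, d)
                Vj = v[b, h, j * Bc : (j + 1) * Bc]  # (Bc, d)
                for i in range(Tr):  # inner loop: Q tiles
                    rows = slice(i * Br, (i + 1) * Br)
                    Qi = q[b, h, rows]  # (Br, d)

                    S = scale * (Qi @ Kj.T)  # (Br, Bc)
                    row_m = S.max(axis=1)  # rowmax(S)
                    P = np.exp(S - row_m[:, None])  # exp(S - row_m)
                    row_l = P.sum(axis=1)  # rowsum(P)

                    m_prev = m[b, h, rows]
                    l_prev = l[b, h, rows]
                    m_new = np.maximum(m_prev, row_m)
                    l_new = (np.exp(m_prev - m_new) * l_prev
                             + np.exp(row_m - m_new) * row_l)

                    # Rescale the old accumulator and fold in this tile's contribution,
                    # mirroring O[ii] = div * (c * O[ii] + exp(row_m - row_m_new) * pv).
                    O[b, h, rows] = (
                        l_prev[:, None] * np.exp(m_prev - m_new)[:, None] * O[b, h, rows]
                        + np.exp(row_m - m_new)[:, None] * (P @ Vj)
                    ) / l_new[:, None]

                    m[b, h, rows] = m_new
                    l[b, h, rows] = l_new

    return O.flatten()

With the example's sizes this should agree with manual_attn(q, k, v) to within np.allclose tolerance. It also makes the shared-memory budget easy to check: the kernel requests sram_size = 4 * Bc * d * np.float32().itemsize = 65,536 bytes (64 KiB) of dynamic shared memory, while the Qi/Kj/Vj/S views only cover Br*d + 2*Bc*d + Br*Bc = 13,312 floats (52 KiB), so the request is a conservative round-up that still fits within the 64 KiB of LDS available per workgroup on many AMD GPUs.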