From 5ef221e5b17e1cce88fcfa284e595c97fe6b16ce Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Sun, 13 Apr 2025 18:55:06 -0400 Subject: [PATCH 1/4] rocprof --- examples/att.json | 12 +++ examples/att.txt | 5 ++ examples/demo.py | 182 ++++++++++++++++++++++++++++++++++++++ examples/requirements.txt | 2 + examples/rocprof.sh | 14 +++ 5 files changed, 215 insertions(+) create mode 100644 examples/att.json create mode 100644 examples/att.txt create mode 100755 examples/demo.py create mode 100644 examples/requirements.txt create mode 100755 examples/rocprof.sh diff --git a/examples/att.json b/examples/att.json new file mode 100644 index 0000000..bba54f4 --- /dev/null +++ b/examples/att.json @@ -0,0 +1,12 @@ +{ + "jobs": [ + { + "advanced_thread_trace": true, + "att_parse" : "trace", + "att_target_cu" : 0, + "att_shader_engine_mask" : "0xF", + "att_simd_select": "0xF", + "att_buffer_size": "0x60000000" + } + ] +} diff --git a/examples/att.txt b/examples/att.txt new file mode 100644 index 0000000..1d58688 --- /dev/null +++ b/examples/att.txt @@ -0,0 +1,5 @@ +att: TARGET_CU=0 +SIMD_SELECT=0x3 +SE_MASK=0xFFFFFFFF +BUFFER_SIZE=192 +ISA_CAPTURE_MODE=2 \ No newline at end of file diff --git a/examples/demo.py b/examples/demo.py new file mode 100755 index 0000000..668dfbc --- /dev/null +++ b/examples/demo.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python + +import mlir.extras.types as T +import numpy as np +from hip import hip +from mlir.ir import InsertionPoint + +from mlir.extras.ast.canonicalize import canonicalize +from mlir.extras.context import RAIIMLIRContextModule +from mlir.extras.dialects.ext import memref, scf, arith, rocdl + +# noinspection PyUnresolvedReferences +from mlir.extras.dialects.ext.gpu import ( + all_reduce, + wait, + thread_attr as thread, + block_idx, + thread_idx, + block_dim, + GPUModuleMeta, + func as gpu_func, + set_container_module, + launch, + all_reduce_, + module, + get_compile_object_bytes, + lds_space, +) +from mlir.extras.runtime.passes import run_pipeline, Pipeline + +# noinspection PyUnresolvedReferences +from util import hip_check, launch_kernel, hip_synchronize + + +def time_to_gflops(time_ms, N): + return 1e-6 * (N * N * N * 2 + 3 * N * N) // time_ms + + +# just so it doesn't get DCE'd by black/reformat +# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable +_ = memref + +ctx = RAIIMLIRContextModule() +set_container_module(ctx.module) + +props = hip.hipDeviceProp_t() +hip_check(hip.hipGetDeviceProperties(props, 0)) +arch = props.gcnArchName.decode() + + +# just a default attr - actual target is set blow +@module("kernels", [f'#rocdl.target']) +def gpu_module(): + pass + + +ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0]) +ip.__enter__() + +set_container_module(ctx.module) + +v_len = 16 +M, K, N = 1024, 1024, 1024 +v16f16 = T.vector(v_len, T.f16()) + + +@gpu_func +@canonicalize(using=scf.canonicalizer) +def smol_matmul( + a: T.memref(M, K, T.f16()), + b: T.memref(K, N, T.f16()), + c: T.memref(M, N, T.f16()), +): + lIdx = thread_idx.x + # a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b + # a_frag will store one column of the 16x16 matrix A tile + # b_frag will store one row of the 16x16 matrix B tile + a_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16) + b_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16) + c_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16) + + # lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA 3 + lane = lIdx % v_len + for ele in range(v_len): + b_frag[ele] = b[ele, lane] + a_frag[ele] = a[lane, ele] + # a_frag, b_frag = yield a_frag, b_frag + + # call the WMMA intrinsic + false = arith.constant(False, T.bool()) + c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false]) + + for ele in range(v_len // 2): + r = ele * 2 + (lIdx // v_len) + # store results from unpacked c_frag output + c[r, lane] = c_frag[ele * 2] + + +props = hip.hipDeviceProp_t() +hip_check(hip.hipGetDeviceProperties(props, 0)) +arch = props.gcnArchName.decode().split(":")[0] + + +@module("naive", [f'#rocdl.target']) +def gpu_module(): + smol_matmul.emit() + + +ip.__exit__(None, None, None) + +lowered_module = run_pipeline( + gpu_module, + Pipeline() + .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True)) + .rocdl_attach_target(chip=arch, abi="500", O=0) + .gpu_to_llvm() + .lower_to_llvm() + .ensure_debug_info_scope_on_llvm_func(emission_kind="Full") + .gpu_module_to_binary(), +) + +hsaco = get_compile_object_bytes(lowered_module) +hip_module = hip_check(hip.hipModuleLoadData(hsaco)) +function = hip_check( + hip.hipModuleGetFunction(hip_module, smol_matmul.__name__.encode()) +) + +a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np.float16) +b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np.float16) +c_h = -3 * np.ones((M, N), dtype=np.float16) + +a_num_bytes = a_h.size * a_h.itemsize +b_num_bytes = b_h.size * b_h.itemsize +c_num_bytes = c_h.size * c_h.itemsize + +a_d = hip_check(hip.hipMalloc(a_num_bytes)) +b_d = hip_check(hip.hipMalloc(b_num_bytes)) +c_d = hip_check(hip.hipMalloc(c_num_bytes)) + +hip_check(hip.hipMemcpy(a_d, a_h, a_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) +hip_check(hip.hipMemcpy(b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) +hip_check(hip.hipMemcpy(c_d, c_h, c_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) + +gridX = 32 +gridY = 32 +gridZ = 1 +warp_size = 32 +num_warps = 1 +stream = 0 +shared_memory = 0 + +launch_kernel( + function.as_c_void_p(), + gridX, + gridY, + gridZ, + warp_size, + num_warps, + 1, + stream, + shared_memory, + a_d, + b_d, + c_d, +) + +correct = a_h @ b_h +assert np.allclose(c_h, -3.0) +assert not np.allclose(correct, c_h) +hip_check(hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)) + +# if not np.allclose(c_h, correct): +# with np.printoptions(threshold=np.inf, linewidth=200): +# print(correct) +# print(c_h) +# assert False + +hip_check(hip.hipFree(a_d)) +hip_check(hip.hipFree(b_d)) +hip_check(hip.hipFree(c_d)) + +hip_check(hip.hipModuleUnload(hip_module)) diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..3fb21a1 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,2 @@ +websockets +matplotlib diff --git a/examples/rocprof.sh b/examples/rocprof.sh new file mode 100755 index 0000000..c10ae9b --- /dev/null +++ b/examples/rocprof.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#set -eux + +export PATH=/opt/rocm-6.5.0/bin:$PATH +export PYTHONPATH=/home/mlevental/dev_projects/mlir-python-extras +export OUTPUT_PATH=$PWD +export ROCPROF_ATT_LIBRARY_PATH=/opt/rocm-6.5.0/att-decoder-v3-3.0.0-Linux/lib + +rm -rf traces +#rocprofv2 --kernel-trace /home/mlevental/dev_projects/mlir-python-extras/examples/demo.py +#rocprofv2 -i att.txt --kernel-trace --plugin att auto --mode file,csv -d traces/ /home/mlevental/dev_projects/mlir-python-extras/examples/demo.py +/opt/rocm-6.5.0/bin/rocprofv3 -i att.json -d traces -- /home/mlevental/dev_projects/mlir-python-extras/examples/demo.py +../../ROCProfiler-ATT-Viewer-amd-staging/cmake-build-debug/ATTViewer traces/ui* \ No newline at end of file From c8fdf3c60d961856a290191bae0ad09c2f943c36 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 18 Apr 2025 17:29:15 -0400 Subject: [PATCH 2/4] fix rocprof.sh --- examples/att.json | 2 +- examples/demo.py | 143 +++++++++++++++++++++++++------------------- examples/rocprof.sh | 26 ++++++-- 3 files changed, 104 insertions(+), 67 deletions(-) diff --git a/examples/att.json b/examples/att.json index bba54f4..80bfdee 100644 --- a/examples/att.json +++ b/examples/att.json @@ -5,7 +5,7 @@ "att_parse" : "trace", "att_target_cu" : 0, "att_shader_engine_mask" : "0xF", - "att_simd_select": "0xF", + "att_simd_select": "0x0", "att_buffer_size": "0x60000000" } ] diff --git a/examples/demo.py b/examples/demo.py index 668dfbc..464e720 100755 --- a/examples/demo.py +++ b/examples/demo.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +from pathlib import Path import mlir.extras.types as T import numpy as np @@ -7,7 +8,7 @@ from mlir.extras.ast.canonicalize import canonicalize from mlir.extras.context import RAIIMLIRContextModule -from mlir.extras.dialects.ext import memref, scf, arith, rocdl +from mlir.extras.dialects.ext import memref, scf, arith, rocdl, gpu, llvm, vector # noinspection PyUnresolvedReferences from mlir.extras.dialects.ext.gpu import ( @@ -25,6 +26,7 @@ module, get_compile_object_bytes, lds_space, + dynamic_shared_memory, ) from mlir.extras.runtime.passes import run_pipeline, Pipeline @@ -43,10 +45,6 @@ def time_to_gflops(time_ms, N): ctx = RAIIMLIRContextModule() set_container_module(ctx.module) -props = hip.hipDeviceProp_t() -hip_check(hip.hipGetDeviceProperties(props, 0)) -arch = props.gcnArchName.decode() - # just a default attr - actual target is set blow @module("kernels", [f'#rocdl.target']) @@ -60,40 +58,44 @@ def gpu_module(): set_container_module(ctx.module) v_len = 16 -M, K, N = 1024, 1024, 1024 -v16f16 = T.vector(v_len, T.f16()) +M, K, N = 512, 512, 512 +TILE_SIZE = BK = 16 +dtype = T.f16() +np_dtype = np.float16 +v16 = T.vector(v_len, dtype) @gpu_func @canonicalize(using=scf.canonicalizer) -def smol_matmul( - a: T.memref(M, K, T.f16()), - b: T.memref(K, N, T.f16()), - c: T.memref(M, N, T.f16()), +def kernel( + A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype) ): - lIdx = thread_idx.x - # a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b - # a_frag will store one column of the 16x16 matrix A tile - # b_frag will store one row of the 16x16 matrix B tile - a_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16) - b_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16) - c_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16) - - # lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA 3 - lane = lIdx % v_len - for ele in range(v_len): - b_frag[ele] = b[ele, lane] - a_frag[ele] = a[lane, ele] - # a_frag, b_frag = yield a_frag, b_frag - - # call the WMMA intrinsic - false = arith.constant(False, T.bool()) - c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false]) - - for ele in range(v_len // 2): - r = ele * 2 + (lIdx // v_len) - # store results from unpacked c_frag output - c[r, lane] = c_frag[ele * 2] + base = dynamic_shared_memory() + As = memref.view(base, (TILE_SIZE, TILE_SIZE), dtype=dtype) + Bs = memref.view( + base, (TILE_SIZE, TILE_SIZE), dtype=dtype, shift=TILE_SIZE * TILE_SIZE + ) + + row = block_idx.y * TILE_SIZE + thread_idx.y + col = block_idx.x * TILE_SIZE + thread_idx.x + + sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16) + for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]): + Bs[thread_idx.y, thread_idx.x] = B[thread_idx.y + t, col] + As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t] + + gpu.barrier() + + a_frag = As @ vector.load(v16) @ [thread_idx.y, 0] + b_frag = Bs @ vector.load(v16) @ [0, thread_idx.x] + false = arith.constant(False, T.bool()) + sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false]) + + gpu.barrier() + + sum = yield sum + + C[row, col] = sum props = hip.hipDeviceProp_t() @@ -103,31 +105,38 @@ def smol_matmul( @module("naive", [f'#rocdl.target']) def gpu_module(): - smol_matmul.emit() + kernel.emit() ip.__exit__(None, None, None) +O = 3 +output_format = "binary" + lowered_module = run_pipeline( gpu_module, Pipeline() .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True)) - .rocdl_attach_target(chip=arch, abi="500", O=0) + .rocdl_attach_target(chip=arch, abi="500", O=O) .gpu_to_llvm() .lower_to_llvm() .ensure_debug_info_scope_on_llvm_func(emission_kind="Full") - .gpu_module_to_binary(), + .gpu_module_to_binary(format=output_format), ) hsaco = get_compile_object_bytes(lowered_module) +if output_format == "assembly": + with open(Path(__file__).parent / f"hsacoO{O}.txt", "wb") as f: + f.write(hsaco) + exit() hip_module = hip_check(hip.hipModuleLoadData(hsaco)) -function = hip_check( - hip.hipModuleGetFunction(hip_module, smol_matmul.__name__.encode()) -) +function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())) -a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np.float16) -b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np.float16) -c_h = -3 * np.ones((M, N), dtype=np.float16) +# a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype) +# b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype) +a_h = np.ones((M, K)).astype(dtype=np_dtype) +b_h = np.ones((K, N)).astype(dtype=np_dtype) +c_h = -3 * np.ones((M, N), dtype=np_dtype) a_num_bytes = a_h.size * a_h.itemsize b_num_bytes = b_h.size * b_h.itemsize @@ -141,22 +150,34 @@ def gpu_module(): hip_check(hip.hipMemcpy(b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) hip_check(hip.hipMemcpy(c_d, c_h, c_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) -gridX = 32 -gridY = 32 -gridZ = 1 -warp_size = 32 -num_warps = 1 +( + ( + blocks_per_grid_x, + blocks_per_grid_y, + blocks_per_grid_z, + ), + ( + threads_per_block_x, + threads_per_block_y, + threads_per_block_z, + ), + shared_memory, +) = ( + (N // TILE_SIZE, N // TILE_SIZE, 1), + (TILE_SIZE, TILE_SIZE, 1), + 2 * TILE_SIZE * TILE_SIZE * dtype.width // 8, +) + stream = 0 -shared_memory = 0 launch_kernel( function.as_c_void_p(), - gridX, - gridY, - gridZ, - warp_size, - num_warps, - 1, + blocks_per_grid_x, + blocks_per_grid_y, + blocks_per_grid_z, + threads_per_block_x, + threads_per_block_y, + threads_per_block_z, stream, shared_memory, a_d, @@ -169,11 +190,13 @@ def gpu_module(): assert not np.allclose(correct, c_h) hip_check(hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)) -# if not np.allclose(c_h, correct): -# with np.printoptions(threshold=np.inf, linewidth=200): -# print(correct) -# print(c_h) -# assert False + +if not np.allclose(c_h, correct): + with np.printoptions(threshold=np.inf, linewidth=np.inf): + # print("correct", correct) + # print("c_h", c_h) + print("off by atol", np.max(np.abs(correct - c_h))) + print("off by rtol", np.max(np.abs(correct - c_h) / correct)) hip_check(hip.hipFree(a_d)) hip_check(hip.hipFree(b_d)) diff --git a/examples/rocprof.sh b/examples/rocprof.sh index c10ae9b..46f24e4 100755 --- a/examples/rocprof.sh +++ b/examples/rocprof.sh @@ -2,13 +2,27 @@ #set -eux +cd "$(dirname "$0")" +SCRIPT_DIR="$(pwd)" +echo "Script directory: $SCRIPT_DIR" + export PATH=/opt/rocm-6.5.0/bin:$PATH -export PYTHONPATH=/home/mlevental/dev_projects/mlir-python-extras -export OUTPUT_PATH=$PWD +export PYTHONPATH=$SCRIPT_DIR/.. +export OUTPUT_PATH=$SCRIPT_DIR export ROCPROF_ATT_LIBRARY_PATH=/opt/rocm-6.5.0/att-decoder-v3-3.0.0-Linux/lib +export ATT_VIEWER=../../ROCProfiler-ATT-Viewer-amd-staging/cmake-build-debug/ATTViewer + rm -rf traces -#rocprofv2 --kernel-trace /home/mlevental/dev_projects/mlir-python-extras/examples/demo.py -#rocprofv2 -i att.txt --kernel-trace --plugin att auto --mode file,csv -d traces/ /home/mlevental/dev_projects/mlir-python-extras/examples/demo.py -/opt/rocm-6.5.0/bin/rocprofv3 -i att.json -d traces -- /home/mlevental/dev_projects/mlir-python-extras/examples/demo.py -../../ROCProfiler-ATT-Viewer-amd-staging/cmake-build-debug/ATTViewer traces/ui* \ No newline at end of file +/opt/rocm-6.5.0/bin/rocprofv3 -i att.json -d traces -o demo_trace -- $SCRIPT_DIR/demo.py + +for ui in $(ls $SCRIPT_DIR/traces) ; do + if [ -d $SCRIPT_DIR/traces/$ui ]; then + ls $SCRIPT_DIR/traces/$ui | grep se > /dev/null + if [ $? == 0 ]; then + UI_PATH=$SCRIPT_DIR/traces/$ui + fi + fi +done + +$ATT_VIEWER $UI_PATH \ No newline at end of file From 1cded3fdafdba3a34be137ba9218500e9d513bca Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 18 Apr 2025 22:02:45 -0400 Subject: [PATCH 3/4] working demo --- examples/demo.py | 52 +++++++++++++++++++++++++++++++---------------- tests/test_gpu.py | 3 +++ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/examples/demo.py b/examples/demo.py index 464e720..cb429ca 100755 --- a/examples/demo.py +++ b/examples/demo.py @@ -58,7 +58,7 @@ def gpu_module(): set_container_module(ctx.module) v_len = 16 -M, K, N = 512, 512, 512 +M, K, N = 16, 16, 16 TILE_SIZE = BK = 16 dtype = T.f16() np_dtype = np.float16 @@ -78,24 +78,26 @@ def kernel( row = block_idx.y * TILE_SIZE + thread_idx.y col = block_idx.x * TILE_SIZE + thread_idx.x + # gpu.printf("(%ld, %ld)\n", row, col) + # vector.print_(source=row) sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16) for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]): - Bs[thread_idx.y, thread_idx.x] = B[thread_idx.y + t, col] + Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t] As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t] gpu.barrier() - a_frag = As @ vector.load(v16) @ [thread_idx.y, 0] - b_frag = Bs @ vector.load(v16) @ [0, thread_idx.x] + lane = thread_idx.x % v_len + a_frag = As @ vector.load(v16) @ [lane, 0] + b_frag = Bs @ vector.load(v16) @ [lane, 0] + + # call the WMMA intrinsic false = arith.constant(False, T.bool()) sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false]) - - gpu.barrier() - sum = yield sum - C[row, col] = sum + C[row, col] = sum[2 * (row // 2)] props = hip.hipDeviceProp_t() @@ -110,13 +112,21 @@ def gpu_module(): ip.__exit__(None, None, None) +# gpu_module = run_pipeline(gpu_module, Pipeline().cse()) +# print(gpu_module) + O = 3 output_format = "binary" lowered_module = run_pipeline( gpu_module, Pipeline() - .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True)) + .Gpu( + Pipeline().convert_gpu_to_rocdl( + use_bare_ptr_memref_call_conv=True, + runtime="HIP", + ) + ) .rocdl_attach_target(chip=arch, abi="500", O=O) .gpu_to_llvm() .lower_to_llvm() @@ -132,12 +142,20 @@ def gpu_module(): hip_module = hip_check(hip.hipModuleLoadData(hsaco)) function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())) -# a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype) -# b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype) -a_h = np.ones((M, K)).astype(dtype=np_dtype) -b_h = np.ones((K, N)).astype(dtype=np_dtype) -c_h = -3 * np.ones((M, N), dtype=np_dtype) +a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype) +b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype) +# a_h = np.ones((M, K)).astype(dtype=np_dtype) +# b_h = np.ones((K, N)).astype(dtype=np_dtype) +c_h = 0 * np.ones((M, N), dtype=np_dtype) + +for k in range(K): + a = a_h[:, k] + b = b_h[k, :] + c_h += np.outer(a, b) + +assert np.allclose(a_h @ b_h, c_h) +c_h = -3 * np.ones((M, N), dtype=np_dtype) a_num_bytes = a_h.size * a_h.itemsize b_num_bytes = b_h.size * b_h.itemsize c_num_bytes = c_h.size * c_h.itemsize @@ -190,14 +208,14 @@ def gpu_module(): assert not np.allclose(correct, c_h) hip_check(hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)) - if not np.allclose(c_h, correct): with np.printoptions(threshold=np.inf, linewidth=np.inf): - # print("correct", correct) - # print("c_h", c_h) + print("correct\n", correct) + print("c_h\n", c_h) print("off by atol", np.max(np.abs(correct - c_h))) print("off by rtol", np.max(np.abs(correct - c_h) / correct)) + hip_check(hip.hipFree(a_d)) hip_check(hip.hipFree(b_d)) hip_check(hip.hipFree(c_d)) diff --git a/tests/test_gpu.py b/tests/test_gpu.py index 7abccb4..e9695fc 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -1230,6 +1230,9 @@ def smol_matmul( c_frag = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, c_frag) + for i in scf.range_(v_len): + gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i]) + for i in scf.range_(v_len): gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i]) From 400e9ae376ce213fc7b0dc623159cc6665c9cfee Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 18 Apr 2025 23:52:38 -0400 Subject: [PATCH 4/4] working demo --- examples/att.txt | 5 --- examples/requirements.txt | 2 - examples/rocprof.sh | 3 +- examples/{demo.py => schedule_barriers.py} | 51 ++++++++++++++-------- tests/test_gpu.py | 3 -- 5 files changed, 33 insertions(+), 31 deletions(-) delete mode 100644 examples/att.txt delete mode 100644 examples/requirements.txt rename examples/{demo.py => schedule_barriers.py} (79%) diff --git a/examples/att.txt b/examples/att.txt deleted file mode 100644 index 1d58688..0000000 --- a/examples/att.txt +++ /dev/null @@ -1,5 +0,0 @@ -att: TARGET_CU=0 -SIMD_SELECT=0x3 -SE_MASK=0xFFFFFFFF -BUFFER_SIZE=192 -ISA_CAPTURE_MODE=2 \ No newline at end of file diff --git a/examples/requirements.txt b/examples/requirements.txt deleted file mode 100644 index 3fb21a1..0000000 --- a/examples/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -websockets -matplotlib diff --git a/examples/rocprof.sh b/examples/rocprof.sh index 46f24e4..765277e 100755 --- a/examples/rocprof.sh +++ b/examples/rocprof.sh @@ -8,13 +8,12 @@ echo "Script directory: $SCRIPT_DIR" export PATH=/opt/rocm-6.5.0/bin:$PATH export PYTHONPATH=$SCRIPT_DIR/.. -export OUTPUT_PATH=$SCRIPT_DIR export ROCPROF_ATT_LIBRARY_PATH=/opt/rocm-6.5.0/att-decoder-v3-3.0.0-Linux/lib export ATT_VIEWER=../../ROCProfiler-ATT-Viewer-amd-staging/cmake-build-debug/ATTViewer rm -rf traces -/opt/rocm-6.5.0/bin/rocprofv3 -i att.json -d traces -o demo_trace -- $SCRIPT_DIR/demo.py +rocprofv3 -i att.json -d traces -o demo_trace -- $SCRIPT_DIR/schedule_barriers.py for ui in $(ls $SCRIPT_DIR/traces) ; do if [ -d $SCRIPT_DIR/traces/$ui ]; then diff --git a/examples/demo.py b/examples/schedule_barriers.py similarity index 79% rename from examples/demo.py rename to examples/schedule_barriers.py index cb429ca..402920d 100755 --- a/examples/demo.py +++ b/examples/schedule_barriers.py @@ -58,7 +58,7 @@ def gpu_module(): set_container_module(ctx.module) v_len = 16 -M, K, N = 16, 16, 16 +M, K, N = 512, 512, 512 TILE_SIZE = BK = 16 dtype = T.f16() np_dtype = np.float16 @@ -78,23 +78,27 @@ def kernel( row = block_idx.y * TILE_SIZE + thread_idx.y col = block_idx.x * TILE_SIZE + thread_idx.x + lane = thread_idx.x % v_len # gpu.printf("(%ld, %ld)\n", row, col) # vector.print_(source=row) sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16) - for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]): - Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t] - As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t] + Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + 0] + As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + 0] + + for t, sum, _ in scf.range_(BK, N + BK, BK, iter_args=[sum]): gpu.barrier() - lane = thread_idx.x % v_len a_frag = As @ vector.load(v16) @ [lane, 0] b_frag = Bs @ vector.load(v16) @ [lane, 0] - # call the WMMA intrinsic - false = arith.constant(False, T.bool()) - sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false]) + sum = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, sum) + + if arith.index_cast(t, T.i32()) < N: + Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t] + As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t] + sum = yield sum C[row, col] = sum[2 * (row // 2)] @@ -142,18 +146,25 @@ def gpu_module(): hip_module = hip_check(hip.hipModuleLoadData(hsaco)) function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())) -a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype) -b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype) -# a_h = np.ones((M, K)).astype(dtype=np_dtype) -# b_h = np.ones((K, N)).astype(dtype=np_dtype) -c_h = 0 * np.ones((M, N), dtype=np_dtype) +# a_h = np.random.randint(1, 5, (M, K)).astype(dtype=np_dtype) +# b_h = np.random.randint(1, 5, (K, N)).astype(dtype=np_dtype) +# a_h = np.random.rand(M, K).astype(np_dtype) +# b_h = np.random.rand(K, N).astype(np_dtype) + +a_h = 3 * np.ones((M, K)).astype(dtype=np_dtype) +a_h[0 : M // 2, 0 : K // 2] = 0 +a_h[M // 2 : M, K // 2 : K] = 1 +b_h = 2 * np.ones((K, N)).astype(dtype=np_dtype) +b_h[0 : K // 2, 0 : N // 2] = 2 +b_h[K // 2 : K, N // 2 : N] = 3 + +c_h = 0 * np.ones((M, N), dtype=np.float32) for k in range(K): - a = a_h[:, k] - b = b_h[k, :] + a = a_h.astype(np.float32)[:, k] + b = b_h.astype(np.float32)[k, :] c_h += np.outer(a, b) - -assert np.allclose(a_h @ b_h, c_h) +assert np.allclose(a_h.astype(np.float32) @ b_h.astype(np.float32), c_h) c_h = -3 * np.ones((M, N), dtype=np_dtype) a_num_bytes = a_h.size * a_h.itemsize @@ -210,10 +221,12 @@ def gpu_module(): if not np.allclose(c_h, correct): with np.printoptions(threshold=np.inf, linewidth=np.inf): - print("correct\n", correct) - print("c_h\n", c_h) + # print("correct\n", correct) + # print("c_h\n", c_h) print("off by atol", np.max(np.abs(correct - c_h))) print("off by rtol", np.max(np.abs(correct - c_h) / correct)) + print("num incorrect", np.sum(np.abs(correct - c_h) != 0)) + print("fraction incorrect", np.sum(np.abs(correct - c_h) != 0) / (M * N)) hip_check(hip.hipFree(a_d)) diff --git a/tests/test_gpu.py b/tests/test_gpu.py index e9695fc..7abccb4 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -1230,9 +1230,6 @@ def smol_matmul( c_frag = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, c_frag) - for i in scf.range_(v_len): - gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i]) - for i in scf.range_(v_len): gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i])