Skip to content

Commit 9b55540

Browse files
committed
rocprof
1 parent ba5e9bc commit 9b55540

File tree

7 files changed

+221
-6
lines changed

7 files changed

+221
-6
lines changed

examples/att.json

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"jobs": [
3+
{
4+
"advanced_thread_trace": true,
5+
"att_parse" : "trace",
6+
"att_target_cu" : 0,
7+
"att_shader_engine_mask" : "0xF",
8+
"att_simd_select": "0xF",
9+
"att_buffer_size": "0x60000000"
10+
}
11+
]
12+
}

examples/att.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
att: TARGET_CU=0
2+
SIMD_SELECT=0x3
3+
SE_MASK=0xFFFFFFFF
4+
BUFFER_SIZE=192
5+
ISA_CAPTURE_MODE=2

examples/demo.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
#!/usr/bin/env python
2+
3+
import mlir.extras.types as T
4+
import numpy as np
5+
from hip import hip
6+
from mlir.ir import InsertionPoint
7+
8+
from mlir.extras.ast.canonicalize import canonicalize
9+
from mlir.extras.context import RAIIMLIRContextModule
10+
from mlir.extras.dialects.ext import memref, scf, arith, rocdl
11+
12+
# noinspection PyUnresolvedReferences
13+
from mlir.extras.dialects.ext.gpu import (
14+
all_reduce,
15+
wait,
16+
thread_attr as thread,
17+
block_idx,
18+
thread_idx,
19+
block_dim,
20+
GPUModuleMeta,
21+
func as gpu_func,
22+
set_container_module,
23+
launch,
24+
all_reduce_,
25+
module,
26+
get_compile_object_bytes,
27+
lds_space,
28+
)
29+
from mlir.extras.runtime.passes import run_pipeline, Pipeline
30+
31+
# noinspection PyUnresolvedReferences
32+
from util import hip_check, launch_kernel, hip_synchronize
33+
34+
35+
def time_to_gflops(time_ms, N):
36+
return 1e-6 * (N * N * N * 2 + 3 * N * N) // time_ms
37+
38+
39+
# just so it doesn't get DCE'd by black/reformat
40+
# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
41+
_ = memref
42+
43+
ctx = RAIIMLIRContextModule()
44+
set_container_module(ctx.module)
45+
46+
props = hip.hipDeviceProp_t()
47+
hip_check(hip.hipGetDeviceProperties(props, 0))
48+
arch = props.gcnArchName.decode()
49+
50+
51+
# just a default attr - actual target is set blow
52+
@module("kernels", [f'#rocdl.target<abi = "500">'])
53+
def gpu_module():
54+
pass
55+
56+
57+
ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0])
58+
ip.__enter__()
59+
60+
set_container_module(ctx.module)
61+
62+
v_len = 16
63+
M, K, N = 1024, 1024, 1024
64+
v16f16 = T.vector(v_len, T.f16())
65+
66+
67+
@gpu_func
68+
@canonicalize(using=scf.canonicalizer)
69+
def smol_matmul(
70+
a: T.memref(M, K, T.f16()),
71+
b: T.memref(K, N, T.f16()),
72+
c: T.memref(M, N, T.f16()),
73+
):
74+
lIdx = thread_idx.x
75+
# a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b
76+
# a_frag will store one column of the 16x16 matrix A tile
77+
# b_frag will store one row of the 16x16 matrix B tile
78+
a_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
79+
b_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
80+
c_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
81+
82+
# lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA 3
83+
lane = lIdx % v_len
84+
for ele in range(v_len):
85+
b_frag[ele] = b[ele, lane]
86+
a_frag[ele] = a[lane, ele]
87+
# a_frag, b_frag = yield a_frag, b_frag
88+
89+
# call the WMMA intrinsic
90+
false = arith.constant(False, T.bool())
91+
c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])
92+
93+
for ele in range(v_len // 2):
94+
r = ele * 2 + (lIdx // v_len)
95+
# store results from unpacked c_frag output
96+
c[r, lane] = c_frag[ele * 2]
97+
98+
99+
props = hip.hipDeviceProp_t()
100+
hip_check(hip.hipGetDeviceProperties(props, 0))
101+
arch = props.gcnArchName.decode().split(":")[0]
102+
103+
104+
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
105+
def gpu_module():
106+
smol_matmul.emit()
107+
108+
109+
ip.__exit__(None, None, None)
110+
111+
lowered_module = run_pipeline(
112+
gpu_module,
113+
Pipeline()
114+
.Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
115+
.rocdl_attach_target(chip=arch, abi="500", O=0)
116+
.gpu_to_llvm()
117+
.lower_to_llvm()
118+
.ensure_debug_info_scope_on_llvm_func(emission_kind="Full")
119+
.gpu_module_to_binary(),
120+
)
121+
122+
hsaco = get_compile_object_bytes(lowered_module)
123+
hip_module = hip_check(hip.hipModuleLoadData(hsaco))
124+
function = hip_check(
125+
hip.hipModuleGetFunction(hip_module, smol_matmul.__name__.encode())
126+
)
127+
128+
a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np.float16)
129+
b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np.float16)
130+
c_h = -3 * np.ones((M, N), dtype=np.float16)
131+
132+
a_num_bytes = a_h.size * a_h.itemsize
133+
b_num_bytes = b_h.size * b_h.itemsize
134+
c_num_bytes = c_h.size * c_h.itemsize
135+
136+
a_d = hip_check(hip.hipMalloc(a_num_bytes))
137+
b_d = hip_check(hip.hipMalloc(b_num_bytes))
138+
c_d = hip_check(hip.hipMalloc(c_num_bytes))
139+
140+
hip_check(hip.hipMemcpy(a_d, a_h, a_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
141+
hip_check(hip.hipMemcpy(b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
142+
hip_check(hip.hipMemcpy(c_d, c_h, c_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
143+
144+
gridX = 32
145+
gridY = 32
146+
gridZ = 1
147+
warp_size = 32
148+
num_warps = 1
149+
stream = 0
150+
shared_memory = 0
151+
152+
launch_kernel(
153+
function.as_c_void_p(),
154+
gridX,
155+
gridY,
156+
gridZ,
157+
warp_size,
158+
num_warps,
159+
1,
160+
stream,
161+
shared_memory,
162+
a_d,
163+
b_d,
164+
c_d,
165+
)
166+
167+
correct = a_h @ b_h
168+
assert np.allclose(c_h, -3.0)
169+
assert not np.allclose(correct, c_h)
170+
hip_check(hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
171+
172+
# if not np.allclose(c_h, correct):
173+
# with np.printoptions(threshold=np.inf, linewidth=200):
174+
# print(correct)
175+
# print(c_h)
176+
# assert False
177+
178+
hip_check(hip.hipFree(a_d))
179+
hip_check(hip.hipFree(b_d))
180+
hip_check(hip.hipFree(c_d))
181+
182+
hip_check(hip.hipModuleUnload(hip_module))

examples/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
websockets
2+
matplotlib

examples/rocprof.sh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#!/bin/bash
2+
3+
#set -eux
4+
5+
export PATH=/opt/rocm-6.5.0/bin:$PATH
6+
export PYTHONPATH=/home/mlevental/dev_projects/mlir-python-extras
7+
export OUTPUT_PATH=$PWD
8+
export ROCPROF_ATT_LIBRARY_PATH=/opt/rocm-6.5.0/att-decoder-v3-3.0.0-Linux/lib
9+
10+
rm -rf traces
11+
#rocprofv2 --kernel-trace /home/mlevental/dev_projects/mlir-python-extras/examples/demo.py
12+
#rocprofv2 -i att.txt --kernel-trace --plugin att auto --mode file,csv -d traces/ /home/mlevental/dev_projects/mlir-python-extras/examples/demo.py
13+
/opt/rocm-6.5.0/bin/rocprofv3 -i att.json -d traces -- /home/mlevental/dev_projects/mlir-python-extras/examples/demo.py
14+
../../ROCProfiler-ATT-Viewer-amd-staging/cmake-build-debug/ATTViewer traces/ui*

mlir/extras/runtime/passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def run_pipeline(
3131
print_pipeline=False,
3232
verify=True,
3333
):
34-
module = Module.parse(str(module))
34+
module = Module.parse(module.operation.get_asm(enable_debug_info=True))
3535

3636
if isinstance(pipeline, Pipeline):
3737
pipeline = str(pipeline)

tests/test_gpu.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,7 @@ def mat_product_kernel(
758758

759759
props = hip.hipDeviceProp_t()
760760
hip_check(hip.hipGetDeviceProperties(props, 0))
761-
arch = props.gcnArchName.decode()
761+
arch = props.gcnArchName.decode().split(":")[0]
762762

763763
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
764764
def gpu_module():
@@ -869,7 +869,7 @@ def mat_product_kernel(
869869

870870
props = hip.hipDeviceProp_t()
871871
hip_check(hip.hipGetDeviceProperties(props, 0))
872-
arch = props.gcnArchName.decode()
872+
arch = props.gcnArchName.decode().split(":")[0]
873873

874874
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
875875
def gpu_module():
@@ -996,7 +996,7 @@ def smol_matmul(
996996

997997
props = hip.hipDeviceProp_t()
998998
hip_check(hip.hipGetDeviceProperties(props, 0))
999-
arch = props.gcnArchName.decode()
999+
arch = props.gcnArchName.decode().split(":")[0]
10001000

10011001
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
10021002
def gpu_module():
@@ -1104,7 +1104,7 @@ def all_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
11041104

11051105
props = hip.hipDeviceProp_t()
11061106
hip_check(hip.hipGetDeviceProperties(props, 0))
1107-
arch = props.gcnArchName.decode()
1107+
arch = props.gcnArchName.decode().split(":")[0]
11081108

11091109
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
11101110
def gpu_module():
@@ -1239,7 +1239,7 @@ def smol_matmul(
12391239

12401240
props = hip.hipDeviceProp_t()
12411241
hip_check(hip.hipGetDeviceProperties(props, 0))
1242-
arch = props.gcnArchName.decode()
1242+
arch = props.gcnArchName.decode().split(":")[0]
12431243

12441244
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
12451245
def gpu_module():

0 commit comments

Comments
 (0)