
Commit 8274554: working demo

1 parent 46a7793

8 files changed, +89 -51 lines changed

examples/att.txt

Lines changed: 0 additions & 5 deletions
This file was deleted.

examples/requirements.txt

Lines changed: 0 additions & 2 deletions
This file was deleted.

examples/rocprof.sh

Lines changed: 1 addition & 2 deletions
@@ -8,13 +8,12 @@ echo "Script directory: $SCRIPT_DIR"
 
 export PATH=/opt/rocm-6.5.0/bin:$PATH
 export PYTHONPATH=$SCRIPT_DIR/..
-export OUTPUT_PATH=$SCRIPT_DIR
 export ROCPROF_ATT_LIBRARY_PATH=/opt/rocm-6.5.0/att-decoder-v3-3.0.0-Linux/lib
 export ATT_VIEWER=../../ROCProfiler-ATT-Viewer-amd-staging/cmake-build-debug/ATTViewer
 
 
 rm -rf traces
-/opt/rocm-6.5.0/bin/rocprofv3 -i att.json -d traces -o demo_trace -- $SCRIPT_DIR/demo.py
+rocprofv3 -i att.json -d traces -o demo_trace -- $SCRIPT_DIR/schedule_barriers.py
 
 for ui in $(ls $SCRIPT_DIR/traces) ; do
   if [ -d $SCRIPT_DIR/traces/$ui ]; then

examples/demo.py renamed to examples/schedule_barriers.py

Lines changed: 32 additions & 19 deletions
@@ -58,7 +58,7 @@ def gpu_module():
 set_container_module(ctx.module)
 
 v_len = 16
-M, K, N = 16, 16, 16
+M, K, N = 512, 512, 512
 TILE_SIZE = BK = 16
 dtype = T.f16()
 np_dtype = np.float16
@@ -78,23 +78,27 @@ def kernel(
 
     row = block_idx.y * TILE_SIZE + thread_idx.y
     col = block_idx.x * TILE_SIZE + thread_idx.x
+    lane = thread_idx.x % v_len
     # gpu.printf("(%ld, %ld)\n", row, col)
     # vector.print_(source=row)
 
     sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16)
-    for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]):
-        Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t]
-        As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]
 
+    Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + 0]
+    As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + 0]
+
+    for t, sum, _ in scf.range_(BK, N + BK, BK, iter_args=[sum]):
        gpu.barrier()
 
-        lane = thread_idx.x % v_len
        a_frag = As @ vector.load(v16) @ [lane, 0]
        b_frag = Bs @ vector.load(v16) @ [lane, 0]
 
-        # call the WMMA intrinsic
-        false = arith.constant(False, T.bool())
-        sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false])
+        sum = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, sum)
+
+        if arith.index_cast(t, T.i32()) < N:
+            Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t]
+            As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]
+
        sum = yield sum
 
    C[row, col] = sum[2 * (row // 2)]
@@ -142,18 +146,25 @@ def gpu_module():
 hip_module = hip_check(hip.hipModuleLoadData(hsaco))
 function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode()))
 
-a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
-b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
-# a_h = np.ones((M, K)).astype(dtype=np_dtype)
-# b_h = np.ones((K, N)).astype(dtype=np_dtype)
-c_h = 0 * np.ones((M, N), dtype=np_dtype)
+# a_h = np.random.randint(1, 5, (M, K)).astype(dtype=np_dtype)
+# b_h = np.random.randint(1, 5, (K, N)).astype(dtype=np_dtype)
 
+# a_h = np.random.rand(M, K).astype(np_dtype)
+# b_h = np.random.rand(K, N).astype(np_dtype)
+
+a_h = 3 * np.ones((M, K)).astype(dtype=np_dtype)
+a_h[0 : M // 2, 0 : K // 2] = 0
+a_h[M // 2 : M, K // 2 : K] = 1
+b_h = 2 * np.ones((K, N)).astype(dtype=np_dtype)
+b_h[0 : K // 2, 0 : N // 2] = 2
+b_h[K // 2 : K, N // 2 : N] = 3
+
+c_h = 0 * np.ones((M, N), dtype=np.float32)
 for k in range(K):
-    a = a_h[:, k]
-    b = b_h[k, :]
+    a = a_h.astype(np.float32)[:, k]
+    b = b_h.astype(np.float32)[k, :]
     c_h += np.outer(a, b)
-
-assert np.allclose(a_h @ b_h, c_h)
+assert np.allclose(a_h.astype(np.float32) @ b_h.astype(np.float32), c_h)
 
 c_h = -3 * np.ones((M, N), dtype=np_dtype)
 a_num_bytes = a_h.size * a_h.itemsize
@@ -210,10 +221,12 @@ def gpu_module():
 
 if not np.allclose(c_h, correct):
     with np.printoptions(threshold=np.inf, linewidth=np.inf):
-        print("correct\n", correct)
-        print("c_h\n", c_h)
+        # print("correct\n", correct)
+        # print("c_h\n", c_h)
        print("off by atol", np.max(np.abs(correct - c_h)))
        print("off by rtol", np.max(np.abs(correct - c_h) / correct))
+        print("num incorrect", np.sum(np.abs(correct - c_h) != 0))
+        print("fraction incorrect", np.sum(np.abs(correct - c_h) != 0) / (M * N))
 
 
 hip_check(hip.hipFree(a_d))
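
Note: the renamed example restructures the tile loop so the first As/Bs tile is staged before the loop, each iteration runs the WMMA on the already-staged tile behind a single gpu.barrier(), and the next tile is staged only while there is one left. Below is a minimal NumPy sketch of that schedule (hypothetical sizes, plain host code, not the GPU kernel itself):

import numpy as np

M = K = N = 64  # assumed small sizes, for illustration only
TILE = BK = 16

A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

for bi in range(0, M, TILE):
    for bj in range(0, N, TILE):
        acc = np.zeros((TILE, TILE), dtype=np.float32)
        # stage tile 0 into the "shared" buffers before the loop
        As = A[bi : bi + TILE, 0:BK].copy()
        Bs = B[0:BK, bj : bj + TILE].copy()
        for t in range(BK, K + BK, BK):
            # gpu.barrier() sits here in the kernel: the staged tile is complete
            acc += As @ Bs  # compute on the staged tile (the WMMA step in the kernel)
            if t < K:
                # stage the next tile only if one remains
                As = A[bi : bi + TILE, t : t + BK].copy()
                Bs = B[t : t + BK, bj : bj + TILE].copy()
        C[bi : bi + TILE, bj : bj + TILE] = acc

assert np.allclose(C, A @ B)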

mlir/extras/dialects/ext/gpu.py

Lines changed: 34 additions & 17 deletions
@@ -49,43 +49,43 @@ def __get__(self, owner_self, owner_cls):
 class block_idx:
     @classproperty
     def x(cls):
-        return _block_id("x")
+        return _block_id("x", loc=get_user_code_loc())
 
     @classproperty
     def y(cls):
-        return _block_id("y")
+        return _block_id("y", loc=get_user_code_loc())
 
     @classproperty
     def z(cls):
-        return _block_id("z")
+        return _block_id("z", loc=get_user_code_loc())
 
 
 class block_dim:
     @classproperty
     def x(cls):
-        return _block_dim("x")
+        return _block_dim("x", loc=get_user_code_loc())
 
     @classproperty
     def y(cls):
-        return _block_dim("y")
+        return _block_dim("y", loc=get_user_code_loc())
 
     @classproperty
     def z(cls):
-        return _block_dim("z")
+        return _block_dim("z", loc=get_user_code_loc())
 
 
 class thread_idx:
     @classproperty
     def x(cls):
-        return _thread_id("x")
+        return _thread_id("x", loc=get_user_code_loc())
 
     @classproperty
     def y(cls):
-        return _thread_id("y")
+        return _thread_id("y", loc=get_user_code_loc())
 
     @classproperty
     def z(cls):
-        return _thread_id("z")
+        return _thread_id("z", loc=get_user_code_loc())
 
 
 def thread_id():
@@ -222,6 +222,8 @@ def __init__(
         loc=None,
         ip=None,
     ):
+        if loc is None:
+            loc = get_user_code_loc()
        super().__init__(
            function_type=function_type,
            arg_attrs=arg_attrs,
@@ -301,10 +303,10 @@ def launch_(
 ):
     if loc is None:
         loc = get_user_code_loc()
-        for size in [grid_size, block_size]:
-            for i, s in enumerate(size):
-                if isinstance(s, int):
-                    size[i] = constant(s, index=True)
+    for size in [grid_size, block_size]:
+        for i, s in enumerate(size):
+            if isinstance(s, int):
+                size[i] = constant(s, index=True)
    launch_op = LaunchOp(
        grid_size,
        block_size,
@@ -371,13 +373,16 @@ def __call__(
         async_dependencies=None,
         dynamic_shared_memory_size: Optional[Value] = None,
         stream=None,
+        loc=None,
+        ip=None,
     ):
        for size in [grid_size, block_size]:
            for i, s in enumerate(size):
                if isinstance(s, int):
                    size[i] = constant(s, index=True)
 
-        loc = get_user_code_loc()
+        if loc is None:
+            loc = get_user_code_loc()
        return get_op_result_or_op_results(
            LaunchFuncOp(
                (
@@ -469,6 +474,8 @@ def all_reduce__(value: Value, *, op=None, uniform=None, loc=None, ip=None):
 
 
 def all_reduce_(value: Value, *, op=None, uniform=None, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
    return get_op_result_or_op_results(
        all_reduce__(value, op=op, uniform=uniform, loc=loc, ip=ip)
    )
@@ -577,15 +584,18 @@ def get_compile_object_bytes(compiled_module):
 _printf = printf
 
 
-def printf(format, *args):
-    loc = get_user_code_loc()
-    return _printf(format=format, args=args, loc=loc)
+def printf(format, *args, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
+    return _printf(format=format, args=args, loc=loc, ip=ip)
 
 
 _dynamic_shared_memory = dynamic_shared_memory
 
 
 def dynamic_shared_memory(*, int=False, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
    return _dynamic_shared_memory(
        T.memref(
            ShapedType.get_dynamic_size(),
@@ -611,3 +621,10 @@ def memset(dst, value, async_dependencies=None, *, loc=None, ip=None):
     if isinstance(value, (int, float, bool)):
         value = constant(value, type=dst.type.element_type)
     return _memset(async_token, async_dependencies, dst, value, loc=loc, ip=ip)
+
+
+def barrier(*, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
+
+    return BarrierOp(loc=loc, ip=ip)
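
Note: nearly every hunk in this file applies the same convention: wrappers accept loc/ip keywords and default loc to the user's call site via get_user_code_loc(), so diagnostics point at user code rather than at the wrapper. Below is a self-contained mock of that convention; the get_user_code_loc here is a stand-in for the library helper, not its implementation:

import inspect


def get_user_code_loc():
    # stand-in: report the frame two levels up, i.e. the wrapper's caller
    frame = inspect.stack()[2]
    return f"{frame.filename}:{frame.lineno}"


def barrier(*, loc=None, ip=None):
    # same shape as the new gpu.barrier wrapper; the real one returns BarrierOp(loc=loc, ip=ip)
    if loc is None:
        loc = get_user_code_loc()
    return f"gpu.barrier at {loc}"


print(barrier())                     # loc defaults to this call site
print(barrier(loc="explicit-loc"))   # an explicit loc passes through untouched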

mlir/extras/dialects/ext/memref.py

Lines changed: 11 additions & 1 deletion
@@ -281,6 +281,8 @@ def _canonicalize_start_stop(start, stop, step):
     elif isinstance(start, int) and isinstance(stop, int):
         return stop - start
 
+    raise NotImplementedError
+
 
 def _subview(
     mem: MemRef,
@@ -362,6 +364,8 @@ def _copy_to_subview(
 
 
 def dim(source, index, *, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
    if isinstance(index, int):
        index = constant(index, index=True)
    return _dim(source=source, index=index, loc=loc, ip=ip)
@@ -412,7 +416,9 @@ def global_(
 ).opview
 
 
-def view(source, shape, dtype=None, shift=0, memory_space=None):
+def view(source, shape, dtype=None, shift=0, memory_space=None, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
    if dtype is None:
        dtype = source.type.element_type
    byte_width_dtype = dtype.width // 8
@@ -425,6 +431,8 @@ def view(source, shape, dtype=None, shift=0, memory_space=None):
         source,
         byte_shift,
         [],
+        loc=loc,
+        ip=ip,
     )
 
 
@@ -434,6 +442,8 @@
 def get_global(
     name_or_global, *, name=None, global_=None, result=None, loc=None, ip=None
 ):
+    if loc is None:
+        loc = get_user_code_loc()
    if isinstance(name_or_global, GlobalOp):
        global_ = name_or_global
    elif isinstance(name_or_global, str):

mlir/extras/dialects/ext/rocdl.py

Lines changed: 10 additions & 2 deletions
@@ -24,6 +24,8 @@ class WMMA_F16_16X16X16_F16(ir.OpView):
     _ODS_REGIONS = (0, True)
 
     def __init__(self, res, args, *, loc=None, ip=None):
+        if loc is None:
+            loc = get_user_code_loc()
        operands = []
        results = []
        attributes = {}
@@ -56,5 +58,11 @@ def res(self):
         return self.operation.results[0]
 
 
-def wmma_f16_16x16x16_f16(res, args, *, loc=None, ip=None) -> ir.Value:
-    return WMMA_F16_16X16X16_F16(res=res, args=args, loc=loc, ip=ip).result
+def wmma_f16_16x16x16_f16(A, B, C, *, OPSEL=False, loc=None, ip=None) -> ir.Value:
+    if loc is None:
+        loc = get_user_code_loc()
+
+    opsel = arith.constant(OPSEL, ir.IntegerType.get_signless(1))
+    args = [A, B, C, opsel]
+    v16 = ir.VectorType.get((16,), ir.F16Type.get())
+    return WMMA_F16_16X16X16_F16(res=v16, args=args, loc=loc, ip=ip).result
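
Note: this signature change is why the example and test call sites shrink: the vector<16xf16> result type, the OPSEL i1 constant, and the argument packing now live inside the helper. Below is a runnable mock of the call-site difference, with MLIR values replaced by a placeholder dataclass; only the shape of the call mirrors the real wrapper:

from dataclasses import dataclass


@dataclass
class Vec16F16:  # stand-in for an MLIR vector<16xf16> value
    name: str


def wmma_f16_16x16x16_f16(A, B, C, *, OPSEL=False):
    # the real wrapper builds opsel = arith.constant(OPSEL, i1),
    # args = [A, B, C, opsel], and res = vector<16xf16>
    return Vec16F16(f"wmma({A.name}, {B.name}, {C.name}, opsel={OPSEL})")


a_frag, b_frag, c_frag = Vec16F16("a"), Vec16F16("b"), Vec16F16("c")

# before this commit, callers wrote:
#   false = arith.constant(False, T.bool())
#   c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])
# now the same call is just:
c_frag = wmma_f16_16x16x16_f16(a_frag, b_frag, c_frag)
print(c_frag)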

tests/test_gpu.py

Lines changed: 1 addition & 3 deletions
@@ -1228,9 +1228,7 @@ def smol_matmul(
         a_frag[ele] = a[lane, ele]
         a_frag, b_frag = yield a_frag, b_frag
 
-    # call the WMMA intrinsic
-    false = arith.constant(False, T.bool())
-    c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])
+    c_frag = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, c_frag)
 
    for i in scf.range_(v_len):
        gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i])
