
Commit 46a7793: working demo

1 parent fa1ee7d

3 files changed: +48, -21 lines

examples/demo.py

Lines changed: 35 additions & 17 deletions
@@ -58,7 +58,7 @@ def gpu_module():
     set_container_module(ctx.module)

     v_len = 16
-    M, K, N = 512, 512, 512
+    M, K, N = 16, 16, 16
     TILE_SIZE = BK = 16
     dtype = T.f16()
     np_dtype = np.float16
@@ -78,24 +78,26 @@ def kernel(

     row = block_idx.y * TILE_SIZE + thread_idx.y
     col = block_idx.x * TILE_SIZE + thread_idx.x
+    # gpu.printf("(%ld, %ld)\n", row, col)
+    # vector.print_(source=row)

     sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16)
     for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]):
-        Bs[thread_idx.y, thread_idx.x] = B[thread_idx.y + t, col]
+        Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t]
         As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]

         gpu.barrier()

-        a_frag = As @ vector.load(v16) @ [thread_idx.y, 0]
-        b_frag = Bs @ vector.load(v16) @ [0, thread_idx.x]
+        lane = thread_idx.x % v_len
+        a_frag = As @ vector.load(v16) @ [lane, 0]
+        b_frag = Bs @ vector.load(v16) @ [lane, 0]
+
+        # call the WMMA intrinsic
         false = arith.constant(False, T.bool())
         sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false])
-
-        gpu.barrier()
-
         sum = yield sum

-    C[row, col] = sum
+    C[row, col] = sum[2 * (row // 2)]


     props = hip.hipDeviceProp_t()
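
A note on the kernel changes above, inferred from the diff and AMD's published RDNA3 WMMA examples rather than stated in the commit: the wmma_f16_16x16x16_f16 intrinsic expects each lane to supply one 16-element row of the A tile and one 16-element column of the B tile, which is presumably why lane = thread_idx.x % v_len now indexes both fragment loads. Staging the B tile transposed (B[col, thread_idx.y + t]) makes that column contiguous in Bs. A NumPy sketch of the single-tile 16x16x16 case this commit switches to (M = K = N = 16, one block, t = 0):

import numpy as np

# Single-tile case (one block, t = 0): Bs[ty, tx] = B[col, ty + t] stores
# the B tile transposed, so row `lane` of Bs is column `lane` of B, the
# operand each lane feeds to the WMMA intrinsic; both fragments can then
# use the same contiguous [lane, 0] vector load.
K = 16
B = np.arange(K * K, dtype=np.float16).reshape(K, K)
Bs = np.empty_like(B)
for ty in range(K):
    for tx in range(K):
        col = tx  # block_idx.x == 0, so col == thread_idx.x
        Bs[ty, tx] = B[col, ty + 0]
for lane in range(K):
    assert np.array_equal(Bs[lane, :], B[:, lane])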
@@ -110,13 +112,21 @@ def gpu_module():

     ip.__exit__(None, None, None)

+    # gpu_module = run_pipeline(gpu_module, Pipeline().cse())
+    # print(gpu_module)
+
     O = 3
     output_format = "binary"

     lowered_module = run_pipeline(
         gpu_module,
         Pipeline()
-        .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
+        .Gpu(
+            Pipeline().convert_gpu_to_rocdl(
+                use_bare_ptr_memref_call_conv=True,
+                runtime="HIP",
+            )
+        )
         .rocdl_attach_target(chip=arch, abi="500", O=O)
         .gpu_to_llvm()
         .lower_to_llvm()
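
Why runtime="HIP": in upstream MLIR, the convert-gpu-to-rocdl pass takes a runtime option that selects which runtime's printf mechanism gpu.printf is lowered to, so it becomes relevant once kernels start printing, as the gpu.printf debugging added in tests/test_gpu.py below does. Driving mlir-opt directly, the nested pipeline would be spelled roughly gpu.module(convert-gpu-to-rocdl{use-bare-ptr-memref-call-conv=1 runtime=HIP}); the option spelling here is assumed from upstream MLIR, not taken from the commit.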
@@ -132,12 +142,20 @@ def gpu_module():
     hip_module = hip_check(hip.hipModuleLoadData(hsaco))
     function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode()))

-    # a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
-    # b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
-    a_h = np.ones((M, K)).astype(dtype=np_dtype)
-    b_h = np.ones((K, N)).astype(dtype=np_dtype)
-    c_h = -3 * np.ones((M, N), dtype=np_dtype)
+    a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
+    b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
+    # a_h = np.ones((M, K)).astype(dtype=np_dtype)
+    # b_h = np.ones((K, N)).astype(dtype=np_dtype)
+    c_h = 0 * np.ones((M, N), dtype=np_dtype)
+
+    for k in range(K):
+        a = a_h[:, k]
+        b = b_h[k, :]
+        c_h += np.outer(a, b)
+
+    assert np.allclose(a_h @ b_h, c_h)

+    c_h = -3 * np.ones((M, N), dtype=np_dtype)
     a_num_bytes = a_h.size * a_h.itemsize
     b_num_bytes = b_h.size * b_h.itemsize
     c_num_bytes = c_h.size * c_h.itemsize
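
The new host-side loop cross-checks NumPy's matmul against the rank-1-update identity A @ B = sum over k of outer(A[:, k], B[k, :]) before touching the device; that identity is also what the kernel's k-loop accumulates per tile. A standalone version, with the same shapes and dtype as the demo:

import numpy as np

# Matmul as a sum of K rank-1 (outer-product) updates:
# A @ B == sum over k of outer(A[:, k], B[k, :]).
M = K = N = 16
A = np.random.randint(0, 10, (M, K)).astype(np.float16)
B = np.random.randint(0, 10, (K, N)).astype(np.float16)
C = np.zeros((M, N), dtype=np.float16)
for k in range(K):
    C += np.outer(A[:, k], B[k, :])
assert np.allclose(A @ B, C)

(With entries in [0, 10) and K = 16 the products stay well inside float16's exact-integer range, so the check is not tolerance-sensitive.)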
@@ -190,14 +208,14 @@ def gpu_module():
     assert not np.allclose(correct, c_h)
     hip_check(hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))

-
     if not np.allclose(c_h, correct):
         with np.printoptions(threshold=np.inf, linewidth=np.inf):
-            # print("correct", correct)
-            # print("c_h", c_h)
+            print("correct\n", correct)
+            print("c_h\n", c_h)
             print("off by atol", np.max(np.abs(correct - c_h)))
             print("off by rtol", np.max(np.abs(correct - c_h) / correct))

+
     hip_check(hip.hipFree(a_d))
     hip_check(hip.hipFree(b_d))
     hip_check(hip.hipFree(c_d))

mlir/extras/dialects/ext/vector.py

Lines changed: 4 additions & 2 deletions
@@ -282,8 +282,7 @@ def shuffle(v1, v2, mask, *, loc=None, ip=None):
 _load = load


-@Infix
-def load(base, indices, result, *, nontemporal=None, loc=None, ip=None):
+def load_(base, indices, result, *, nontemporal=None, loc=None, ip=None):
     if loc is None:
         loc = get_user_code_loc()
     for j, i in enumerate(indices):
@@ -297,3 +296,6 @@ def load(base, indices, result, *, nontemporal=None, loc=None, ip=None):
         loc=loc,
         ip=ip,
     ).result
+
+
+load = Infix(load_)
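
The point of this change: decorating with @Infix rebound the name load to the wrapper object, hiding the plain function, whereas defining load_ and then load = Infix(load_) keeps both callable. As background, a minimal sketch of how such a wrapper can support the As @ vector.load(v16) @ [lane, 0] spelling used in examples/demo.py; this is hypothetical, and the real Infix in mlir-python-extras may differ:

class Infix:
    # Hypothetical sketch, not the mlir-python-extras implementation.
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        # Pre-bind trailing arguments: load(v16) waits for base and indices.
        return Infix(lambda base, indices: self.func(base, indices, *args, **kwargs))

    def __rmatmul__(self, base):
        # base @ op: capture the left operand.
        return Infix(lambda indices: self.func(base, indices))

    def __matmul__(self, indices):
        # (base @ op) @ indices: perform the call.
        return self.func(indices)


def load_(base, indices, result):
    return f"load {result} from {base}{indices}"


load = Infix(load_)
print("As" @ load("v16") @ [5, 0])  # load v16 from As[5, 0]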

tests/test_gpu.py

Lines changed: 9 additions & 2 deletions
@@ -15,7 +15,7 @@
 from mlir.dialects.memref import cast

 from mlir.extras.ast.canonicalize import canonicalize
-from mlir.extras.dialects.ext import arith, scf, memref, rocdl
+from mlir.extras.dialects.ext import arith, scf, memref, rocdl, gpu
 from mlir.extras.dialects.ext.func import func

 # noinspection PyUnresolvedReferences
@@ -1232,6 +1232,9 @@ def smol_matmul(
     false = arith.constant(False, T.bool())
     c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])

+    for i in scf.range_(v_len):
+        gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i])
+
     for ele in scf.range_(v_len // 2):
         r = ele * 2 + (lIdx // v_len)
         # store results from unpacked c_frag output
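
Background on the indices the new printf (and the existing unpack loop) walks, taken from AMD's RDNA3 WMMA examples rather than from this commit: the f16 c_frag packs eight results per lane into every other element, and lIdx // v_len selects whether a lane of the 32-wide wave covers the even or the odd rows of the 16x16 output tile. The row map in plain Python:

# Row map behind r = ele * 2 + (lIdx // v_len): lanes 0-15 of the wave
# write the even rows of the tile, lanes 16-31 the odd rows.
v_len = 16
for lIdx in (3, 19):  # one lane from each half of the wave
    rows = [ele * 2 + (lIdx // v_len) for ele in range(v_len // 2)]
    print(lIdx, rows)
# 3  [0, 2, 4, 6, 8, 10, 12, 14]
# 19 [1, 3, 5, 7, 9, 11, 13, 15]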
@@ -1250,7 +1253,11 @@ def gpu_module():
     lowered_module = run_pipeline(
         gpu_module,
         Pipeline()
-        .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
+        .Gpu(
+            Pipeline().convert_gpu_to_rocdl(
+                use_bare_ptr_memref_call_conv=True, runtime="HIP"
+            )
+        )
         .rocdl_attach_target(chip=arch, abi="500")
         .gpu_to_llvm()
         .lower_to_llvm()
