Skip to content

Commit 1cded3f

Browse files
committed
working demo
1 parent c8fdf3c commit 1cded3f

File tree

2 files changed

+38
-17
lines changed

2 files changed

+38
-17
lines changed

examples/demo.py

Lines changed: 35 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def gpu_module():
5858
set_container_module(ctx.module)
5959

6060
v_len = 16
61-
M, K, N = 512, 512, 512
61+
M, K, N = 16, 16, 16
6262
TILE_SIZE = BK = 16
6363
dtype = T.f16()
6464
np_dtype = np.float16
@@ -78,24 +78,26 @@ def kernel(
7878

7979
row = block_idx.y * TILE_SIZE + thread_idx.y
8080
col = block_idx.x * TILE_SIZE + thread_idx.x
81+
# gpu.printf("(%ld, %ld)\n", row, col)
82+
# vector.print_(source=row)
8183

8284
sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16)
8385
for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]):
84-
Bs[thread_idx.y, thread_idx.x] = B[thread_idx.y + t, col]
86+
Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t]
8587
As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]
8688

8789
gpu.barrier()
8890

89-
a_frag = As @ vector.load(v16) @ [thread_idx.y, 0]
90-
b_frag = Bs @ vector.load(v16) @ [0, thread_idx.x]
91+
lane = thread_idx.x % v_len
92+
a_frag = As @ vector.load(v16) @ [lane, 0]
93+
b_frag = Bs @ vector.load(v16) @ [lane, 0]
94+
95+
# call the WMMA intrinsic
9196
false = arith.constant(False, T.bool())
9297
sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false])
93-
94-
gpu.barrier()
95-
9698
sum = yield sum
9799

98-
C[row, col] = sum
100+
C[row, col] = sum[2 * (row // 2)]
99101

100102

101103
props = hip.hipDeviceProp_t()
@@ -110,13 +112,21 @@ def gpu_module():
110112

111113
ip.__exit__(None, None, None)
112114

115+
# gpu_module = run_pipeline(gpu_module, Pipeline().cse())
116+
# print(gpu_module)
117+
113118
O = 3
114119
output_format = "binary"
115120

116121
lowered_module = run_pipeline(
117122
gpu_module,
118123
Pipeline()
119-
.Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
124+
.Gpu(
125+
Pipeline().convert_gpu_to_rocdl(
126+
use_bare_ptr_memref_call_conv=True,
127+
runtime="HIP",
128+
)
129+
)
120130
.rocdl_attach_target(chip=arch, abi="500", O=O)
121131
.gpu_to_llvm()
122132
.lower_to_llvm()
@@ -132,12 +142,20 @@ def gpu_module():
132142
hip_module = hip_check(hip.hipModuleLoadData(hsaco))
133143
function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode()))
134144

135-
# a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
136-
# b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
137-
a_h = np.ones((M, K)).astype(dtype=np_dtype)
138-
b_h = np.ones((K, N)).astype(dtype=np_dtype)
139-
c_h = -3 * np.ones((M, N), dtype=np_dtype)
145+
a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
146+
b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
147+
# a_h = np.ones((M, K)).astype(dtype=np_dtype)
148+
# b_h = np.ones((K, N)).astype(dtype=np_dtype)
149+
c_h = 0 * np.ones((M, N), dtype=np_dtype)
150+
151+
for k in range(K):
152+
a = a_h[:, k]
153+
b = b_h[k, :]
154+
c_h += np.outer(a, b)
155+
156+
assert np.allclose(a_h @ b_h, c_h)
140157

158+
c_h = -3 * np.ones((M, N), dtype=np_dtype)
141159
a_num_bytes = a_h.size * a_h.itemsize
142160
b_num_bytes = b_h.size * b_h.itemsize
143161
c_num_bytes = c_h.size * c_h.itemsize
@@ -190,14 +208,14 @@ def gpu_module():
190208
assert not np.allclose(correct, c_h)
191209
hip_check(hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
192210

193-
194211
if not np.allclose(c_h, correct):
195212
with np.printoptions(threshold=np.inf, linewidth=np.inf):
196-
# print("correct", correct)
197-
# print("c_h", c_h)
213+
print("correct\n", correct)
214+
print("c_h\n", c_h)
198215
print("off by atol", np.max(np.abs(correct - c_h)))
199216
print("off by rtol", np.max(np.abs(correct - c_h) / correct))
200217

218+
201219
hip_check(hip.hipFree(a_d))
202220
hip_check(hip.hipFree(b_d))
203221
hip_check(hip.hipFree(c_d))

tests/test_gpu.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1230,6 +1230,9 @@ def smol_matmul(
12301230

12311231
c_frag = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, c_frag)
12321232

1233+
for i in scf.range_(v_len):
1234+
gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i])
1235+
12331236
for i in scf.range_(v_len):
12341237
gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i])
12351238

0 commit comments

Comments
 (0)