@@ -15,7 +15,7 @@
 from mlir.dialects.memref import cast
 
 from mlir.extras.ast.canonicalize import canonicalize
-from mlir.extras.dialects.ext import arith, scf, memref
+from mlir.extras.dialects.ext import arith, scf, memref, rocdl
 from mlir.extras.dialects.ext.func import func
 
 # noinspection PyUnresolvedReferences
@@ -36,7 +36,7 @@
 )
 from mlir.extras.dialects.ext.llvm import llvm_ptr_t
 from mlir.extras.dialects.ext.scf import forall, in_parallel_
-from mlir.extras.dialects.ext.vector import outer, load, shuffle
+from mlir.extras.dialects.ext.vector import outer, load, shuffle, print_
 from mlir.extras.runtime.passes import run_pipeline, Pipeline
 
 # noinspection PyUnresolvedReferences
@@ -1193,3 +1193,135 @@ def gpu_module():
     times[all_bank_conflicts.__name__] /= runs
     for k, v in times.items():
         print(f"{k}: {v:.3e}ms")
+
+
+# https://gpuopen.com/learn/wmma_on_rdna3/
+@pytest.mark.skipif(hip_bindings_not_installed(), reason="hip not installed")
+def test_amdgpu_vector_wmma(ctx: MLIRContext):
+    from hip import hip
+
+    set_container_module(ctx.module)
+
+    v_len = 16
+    M, K, N = v_len, v_len, v_len
+    v16f16 = T.vector(v_len, T.f16())
+
+    @gpu_func
+    @canonicalize(using=scf.canonicalizer)
+    def smol_matmul(
+        a: T.memref(M, K, T.f16()),
+        b: T.memref(K, N, T.f16()),
+        c: T.memref(M, N, T.f16()),
+    ):
+        lIdx = thread_idx.x
+        # a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b
+        # a_frag will store one column of the 16x16 matrix A tile
+        # b_frag will store one row of the 16x16 matrix B tile
+        a_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
+        b_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
+        c_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
+
+        # lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA 3
+        lane = lIdx % v_len
+        for ele, [a_frag, b_frag], _ in scf.range_(v_len, iter_args=[a_frag, b_frag]):
+            b_frag[ele] = b[ele, lane]
+            a_frag[ele] = a[lane, ele]
+            a_frag, b_frag = yield a_frag, b_frag
+
+        # call the WMMA intrinsic
+        false = arith.constant(False, T.bool())
+        c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])
+
+        for ele in scf.range_(v_len // 2):
+            r = ele * 2 + (lIdx // v_len)
+            # store results from unpacked c_frag output
+            c[r, lane] = c_frag[ele * 2]
+
+    props = hip.hipDeviceProp_t()
+    hip_check(hip.hipGetDeviceProperties(props, 0))
+    arch = props.gcnArchName.decode()
+
+    @module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
+    def gpu_module():
+        smol_matmul.emit()
+
+    print(gpu_module)
+
+    lowered_module = run_pipeline(
+        gpu_module,
+        Pipeline()
+        .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
+        .rocdl_attach_target(chip=arch, abi="500")
+        .gpu_to_llvm()
+        .lower_to_llvm()
+        .gpu_module_to_binary(),
+    )
+
+    hsaco = get_compile_object_bytes(lowered_module)
+    hip_module = hip_check(hip.hipModuleLoadData(hsaco))
+    function = hip_check(
+        hip.hipModuleGetFunction(hip_module, smol_matmul.__name__.encode())
+    )
+
+    a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np.float16)
+    b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np.float16)
+    c_h = -3 * np.ones((M, N), dtype=np.float16)
+
+    a_num_bytes = a_h.size * a_h.itemsize
+    b_num_bytes = b_h.size * b_h.itemsize
+    c_num_bytes = c_h.size * c_h.itemsize
+
+    a_d = hip_check(hip.hipMalloc(a_num_bytes))
+    b_d = hip_check(hip.hipMalloc(b_num_bytes))
+    c_d = hip_check(hip.hipMalloc(c_num_bytes))
+
+    hip_check(
+        hip.hipMemcpy(a_d, a_h, a_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
+    )
+    hip_check(
+        hip.hipMemcpy(b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
+    )
+    hip_check(
+        hip.hipMemcpy(c_d, c_h, c_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
+    )
+
+    gridX = 1
+    gridY = 1
+    gridZ = 1
+    warp_size = 32
+    num_warps = 1
+    stream = 0
+    shared_memory = 0
+
+    launch_kernel(
+        function.as_c_void_p(),
+        gridX,
+        gridY,
+        gridZ,
+        warp_size,
+        num_warps,
+        stream,
+        shared_memory,
+        a_d,
+        b_d,
+        c_d,
+    )
+
+    correct = a_h @ b_h
+    assert np.allclose(c_h, -3.0)
+    assert not np.allclose(correct, c_h)
+    hip_check(
+        hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)
+    )
+
+    if not np.allclose(c_h, correct):
+        with np.printoptions(threshold=np.inf, linewidth=200):
+            print(correct)
+            print(c_h)
+        assert False
+
+    hip_check(hip.hipFree(a_d))
+    hip_check(hip.hipFree(b_d))
+    hip_check(hip.hipFree(c_d))
+
+    hip_check(hip.hipModuleUnload(hip_module))
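
Note on the write-back loop at the end of smol_matmul: it depends on the RDNA3 WMMA output layout, where each wave32 lane writes one column of the 16x16 C tile, the low half-wave (lIdx // 16 == 0) writing the even rows and the high half-wave the odd rows. The snippet below is a standalone NumPy sketch (not part of the test; writer_map is a hypothetical helper) that reproduces the same index math and checks that the 32 lanes cover every element of the tile exactly once.

import numpy as np

def writer_map(v_len=16, wave_size=32):
    # writers[r, c] records which lane (lIdx) writes C[r, c]; -1 means "not yet written"
    writers = -np.ones((v_len, v_len), dtype=int)
    for lIdx in range(wave_size):
        lane = lIdx % v_len                # column of C owned by this lane
        for ele in range(v_len // 2):
            r = ele * 2 + (lIdx // v_len)  # low half-wave: even rows; high half-wave: odd rows
            assert writers[r, lane] == -1  # each element is written by exactly one lane
            writers[r, lane] = lIdx
    assert (writers >= 0).all()            # the full 16x16 tile is covered
    return writers

print(writer_map())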