Skip to content

Commit 6f9f150

Browse files
authored
add rocdl.wmma_f16_16x16x16_f16 (#140)
1 parent f2cff77 commit 6f9f150

File tree

2 files changed

+194
-2
lines changed

2 files changed

+194
-2
lines changed

mlir/extras/dialects/ext/rocdl.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from . import arith
2+
from ...util import get_user_code_loc
3+
4+
from ....dialects._ods_common import (
5+
_dispatch_mixed_values,
6+
_cext,
7+
get_op_results_or_values,
8+
get_default_loc_context,
9+
get_op_result_or_op_results,
10+
get_default_loc_context,
11+
segmented_accessor,
12+
)
13+
14+
# noinspection PyUnresolvedReferences
15+
from ....dialects.rocdl import *
16+
from ....dialects._rocdl_ops_gen import _Dialect
17+
from .... import ir
18+
19+
20+
@_cext.register_operation(_Dialect, replace=True)
21+
class WMMA_F16_16X16X16_F16(ir.OpView):
22+
OPERATION_NAME = "rocdl.wmma.f16.16x16x16.f16"
23+
24+
_ODS_REGIONS = (0, True)
25+
26+
def __init__(self, res, args, *, loc=None, ip=None):
27+
operands = []
28+
results = []
29+
attributes = {}
30+
regions = None
31+
operands.extend(get_op_results_or_values(args))
32+
_ods_context = get_default_loc_context(loc)
33+
results.append(res)
34+
_ods_successors = None
35+
super().__init__(
36+
self.OPERATION_NAME,
37+
self._ODS_REGIONS,
38+
self._ODS_OPERAND_SEGMENTS,
39+
self._ODS_RESULT_SEGMENTS,
40+
attributes=attributes,
41+
results=results,
42+
operands=operands,
43+
successors=_ods_successors,
44+
regions=regions,
45+
loc=loc,
46+
ip=ip,
47+
)
48+
49+
@property
50+
def args(self):
51+
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
52+
return self.operation.operands[0 : 0 + _ods_variadic_group_length]
53+
54+
@property
55+
def res(self):
56+
return self.operation.results[0]
57+
58+
59+
def wmma_f16_16x16x16_f16(res, args, *, loc=None, ip=None) -> ir.Value:
60+
return WMMA_F16_16X16X16_F16(res=res, args=args, loc=loc, ip=ip).result

tests/test_gpu.py

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from mlir.dialects.memref import cast
1616

1717
from mlir.extras.ast.canonicalize import canonicalize
18-
from mlir.extras.dialects.ext import arith, scf, memref
18+
from mlir.extras.dialects.ext import arith, scf, memref, rocdl
1919
from mlir.extras.dialects.ext.func import func
2020

2121
# noinspection PyUnresolvedReferences
@@ -36,7 +36,7 @@
3636
)
3737
from mlir.extras.dialects.ext.llvm import llvm_ptr_t
3838
from mlir.extras.dialects.ext.scf import forall, in_parallel_
39-
from mlir.extras.dialects.ext.vector import outer, load, shuffle
39+
from mlir.extras.dialects.ext.vector import outer, load, shuffle, print_
4040
from mlir.extras.runtime.passes import run_pipeline, Pipeline
4141

4242
# noinspection PyUnresolvedReferences
@@ -1193,3 +1193,135 @@ def gpu_module():
11931193
times[all_bank_conflicts.__name__] /= runs
11941194
for k, v in times.items():
11951195
print(f"{k}: {v:.3e}ms")
1196+
1197+
1198+
# https://gpuopen.com/learn/wmma_on_rdna3/
1199+
@pytest.mark.skipif(hip_bindings_not_installed(), reason="hip not installed")
1200+
def test_amdgpu_vector_wmma(ctx: MLIRContext):
1201+
from hip import hip
1202+
1203+
set_container_module(ctx.module)
1204+
1205+
v_len = 16
1206+
M, K, N = v_len, v_len, v_len
1207+
v16f16 = T.vector(v_len, T.f16())
1208+
1209+
@gpu_func
1210+
@canonicalize(using=scf.canonicalizer)
1211+
def smol_matmul(
1212+
a: T.memref(M, K, T.f16()),
1213+
b: T.memref(K, N, T.f16()),
1214+
c: T.memref(M, N, T.f16()),
1215+
):
1216+
lIdx = thread_idx.x
1217+
# a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b
1218+
# a_frag will store one column of the 16x16 matrix A tile
1219+
# b_frag will store one row of the 16x16 matrix B tile
1220+
a_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
1221+
b_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
1222+
c_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
1223+
1224+
# lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA 3
1225+
lane = lIdx % v_len
1226+
for ele, [a_frag, b_frag], _ in scf.range_(v_len, iter_args=[a_frag, b_frag]):
1227+
b_frag[ele] = b[ele, lane]
1228+
a_frag[ele] = a[lane, ele]
1229+
a_frag, b_frag = yield a_frag, b_frag
1230+
1231+
# call the WMMA intrinsic
1232+
false = arith.constant(False, T.bool())
1233+
c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])
1234+
1235+
for ele in scf.range_(v_len // 2):
1236+
r = ele * 2 + (lIdx // v_len)
1237+
# store results from unpacked c_frag output
1238+
c[r, lane] = c_frag[ele * 2]
1239+
1240+
props = hip.hipDeviceProp_t()
1241+
hip_check(hip.hipGetDeviceProperties(props, 0))
1242+
arch = props.gcnArchName.decode()
1243+
1244+
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
1245+
def gpu_module():
1246+
smol_matmul.emit()
1247+
1248+
print(gpu_module)
1249+
1250+
lowered_module = run_pipeline(
1251+
gpu_module,
1252+
Pipeline()
1253+
.Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
1254+
.rocdl_attach_target(chip=arch, abi="500")
1255+
.gpu_to_llvm()
1256+
.lower_to_llvm()
1257+
.gpu_module_to_binary(),
1258+
)
1259+
1260+
hsaco = get_compile_object_bytes(lowered_module)
1261+
hip_module = hip_check(hip.hipModuleLoadData(hsaco))
1262+
function = hip_check(
1263+
hip.hipModuleGetFunction(hip_module, smol_matmul.__name__.encode())
1264+
)
1265+
1266+
a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np.float16)
1267+
b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np.float16)
1268+
c_h = -3 * np.ones((M, N), dtype=np.float16)
1269+
1270+
a_num_bytes = a_h.size * a_h.itemsize
1271+
b_num_bytes = b_h.size * b_h.itemsize
1272+
c_num_bytes = c_h.size * c_h.itemsize
1273+
1274+
a_d = hip_check(hip.hipMalloc(a_num_bytes))
1275+
b_d = hip_check(hip.hipMalloc(b_num_bytes))
1276+
c_d = hip_check(hip.hipMalloc(c_num_bytes))
1277+
1278+
hip_check(
1279+
hip.hipMemcpy(a_d, a_h, a_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
1280+
)
1281+
hip_check(
1282+
hip.hipMemcpy(b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
1283+
)
1284+
hip_check(
1285+
hip.hipMemcpy(c_d, c_h, c_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
1286+
)
1287+
1288+
gridX = 1
1289+
gridY = 1
1290+
gridZ = 1
1291+
warp_size = 32
1292+
num_warps = 1
1293+
stream = 0
1294+
shared_memory = 0
1295+
1296+
launch_kernel(
1297+
function.as_c_void_p(),
1298+
gridX,
1299+
gridY,
1300+
gridZ,
1301+
warp_size,
1302+
num_warps,
1303+
stream,
1304+
shared_memory,
1305+
a_d,
1306+
b_d,
1307+
c_d,
1308+
)
1309+
1310+
correct = a_h @ b_h
1311+
assert np.allclose(c_h, -3.0)
1312+
assert not np.allclose(correct, c_h)
1313+
hip_check(
1314+
hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)
1315+
)
1316+
1317+
if not np.allclose(c_h, correct):
1318+
with np.printoptions(threshold=np.inf, linewidth=200):
1319+
print(correct)
1320+
print(c_h)
1321+
assert False
1322+
1323+
hip_check(hip.hipFree(a_d))
1324+
hip_check(hip.hipFree(b_d))
1325+
hip_check(hip.hipFree(c_d))
1326+
1327+
hip_check(hip.hipModuleUnload(hip_module))

0 commit comments

Comments
 (0)