Skip to content

Commit ca63bbb

Browse files
authored
parameterize e2e sugar correctly (#134)
1 parent 448d4a5 commit ca63bbb

File tree

3 files changed

+178
-30
lines changed

3 files changed

+178
-30
lines changed

tests/test_gpu.py

Lines changed: 121 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import ctypes
21
import platform
2+
import random
33
import sys
44
import tempfile
5+
import time
56
from textwrap import dedent
67

78
import mlir.extras.types as T
@@ -40,7 +41,7 @@
4041

4142
# noinspection PyUnresolvedReferences
4243
from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext
43-
from util import hip_bindings_not_installed, hip_check, launch_kernel
44+
from util import hip_bindings_not_installed, hip_check, launch_kernel, hip_synchronize
4445

4546
# needed since the fixture isn't defined here nor in conftest.py
4647
pytest.mark.usefixtures("ctx")
@@ -962,6 +963,7 @@ def test_amdgpu_vector(ctx: MLIRContext):
962963

963964
scale = 2
964965
M, K, N = 2 * scale, 4 * scale, 6 * scale
966+
tz_a, tz_b, tz_c = [2, 2, 2]
965967
v2f32 = T.vector(2, T.f32())
966968

967969
@gpu_func
@@ -972,11 +974,11 @@ def smol_matmul(
972974
):
973975
cst = arith.constant(np.full([4], 0.0, np.float32), T.vector(4, T.f32()))
974976
cst_0 = arith.constant(
975-
np.full([2, 2], 0.0, np.float32), T.vector(2, 2, T.f32())
977+
np.full([tz_a, tz_b], 0.0, np.float32), T.vector(tz_a, tz_b, T.f32())
976978
)
977-
for i, C, v0 in scf.range_(0, M, 2, iter_args=[C]):
978-
for j, C, v1 in scf.range_(0, N, 2, iter_args=[C]):
979-
for k, C, v2 in scf.range_(0, K, 2, iter_args=[C]):
979+
for i, C, v0 in scf.range_(0, M, tz_a, iter_args=[C]):
980+
for j, C, v1 in scf.range_(0, N, tz_b, iter_args=[C]):
981+
for k, C, v2 in scf.range_(0, K, tz_c, iter_args=[C]):
980982
cst[0::1] = A @ load(v2f32) @ [i, k]
981983
cst[2::1] = A @ load(v2f32) @ [i + 1, k]
982984
cst_0[0] = C @ load(v2f32) @ [i, j]
@@ -1078,3 +1080,116 @@ def gpu_module():
10781080
hip_check(hip.hipFree(c_d))
10791081

10801082
hip_check(hip.hipModuleUnload(hip_module))
1083+
1084+
1085+
@pytest.mark.skipif(hip_bindings_not_installed(), reason="hip not installed")
def test_amdgpu_bank_conflicts(ctx: MLIRContext):
    """Time two kernels that differ only in how they index the output.

    Both kernels read ``A[i, thread_idx.x]`` and store its square.  One writes
    to ``B[i, thread_idx.x]``; the other writes to the transposed position
    ``B[thread_idx.x, i]``.  Each kernel is launched ``runs`` times in a
    shuffled order, the first round is treated as warm-up, and the mean
    wall-clock time per timed launch is printed in milliseconds.

    Nothing is asserted about the timings or about the contents of ``B``;
    the test only checks that the kernels compile, load and run.

    NOTE(review): the buffers come from ``hipMalloc``, so this presumably
    measures global-memory access patterns rather than shared-memory bank
    conflicts -- confirm that the test name matches the intent.
    """
    from hip import hip

    set_container_module(ctx.module)

    M = 1024

    @gpu_func
    def no_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
        for i in range(M):
            a = A[i, thread_idx.x]
            B[i, thread_idx.x] = a * a

    @gpu_func
    def all_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
        for i in range(M):
            a = A[i, thread_idx.x]
            B[thread_idx.x, i] = a * a

    # Compile for whatever chip device 0 reports.
    props = hip.hipDeviceProp_t()
    hip_check(hip.hipGetDeviceProperties(props, 0))
    arch = props.gcnArchName.decode()

    @module("naive", [f'#rocdl.target<chip = "{arch}">'])
    def gpu_module():
        no_bank_conflicts.emit()
        all_bank_conflicts.emit()

    lowered_module = run_pipeline(
        gpu_module,
        Pipeline()
        .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
        .rocdl_attach_target(chip=arch)
        .gpu_to_llvm()
        .lower_to_llvm()
        .gpu_module_to_binary(),
    )

    hsaco = get_compile_object_bytes(lowered_module)
    hip_module = hip_check(hip.hipModuleLoadData(hsaco))

    # Every row of A is 0..M-1; B starts zeroed.
    a_h = np.arange(M).astype(dtype=np.float32)
    a_h = np.tile(a_h, (M, 1))
    b_h = np.zeros((M, M), dtype=np.float32)

    a_num_bytes = a_h.size * a_h.itemsize
    b_num_bytes = b_h.size * b_h.itemsize

    a_d = hip_check(hip.hipMalloc(a_num_bytes))
    b_d = hip_check(hip.hipMalloc(b_num_bytes))

    gridX = max(M // 32, 1)
    gridY = max(M // 8, 1)
    gridZ = 1
    warp_size = 32
    num_warps = 8
    stream = 0
    shared_memory = 0

    # Accumulated wall-clock seconds per kernel over the timed rounds.
    times = {
        no_bank_conflicts.__name__: 0,
        all_bank_conflicts.__name__: 0,
    }
    runs = 10
    # Round 0 is a warm-up and is excluded from the totals, so the mean must
    # be taken over runs - 1 samples, not runs.
    timed_runs = runs - 1
    for i in range(runs):
        # Shuffle so neither kernel systematically benefits from running second.
        kernels = [no_bank_conflicts, all_bank_conflicts]
        random.shuffle(kernels)
        for kernel in kernels:
            # Reset both device buffers so every launch starts from the same state.
            hip_check(
                hip.hipMemcpy(
                    a_d, a_h, a_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice
                )
            )
            hip_check(
                hip.hipMemcpy(
                    b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice
                )
            )
            function = hip_check(
                hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
            )

            start = time.monotonic()
            launch_kernel(
                function.as_c_void_p(),
                gridX,
                gridY,
                gridZ,
                warp_size,
                num_warps,
                stream,
                shared_memory,
                a_d,
                b_d,
            )
            # Wait for the kernel before stopping the clock.
            hip_synchronize()
            elapsed = time.monotonic() - start
            if i > 0:
                times[kernel.__name__] += elapsed

            hip_check(
                hip.hipMemcpy(
                    b_h, b_d, b_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
                )
            )

    for name in times:
        # time.monotonic() reports seconds; convert so the value matches the
        # "ms" label in the output.
        times[name] = times[name] / timed_runs * 1e3
    for k, v in times.items():
        print(f"{k}: {v:.3e}ms")

    # Release device resources, mirroring the other HIP tests in this file.
    hip_check(hip.hipFree(a_d))
    hip_check(hip.hipFree(b_d))
    hip_check(hip.hipModuleUnload(hip_module))

tests/test_vector.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
ShapedType,
2121
AffineMap,
2222
AffineConstantExpr,
23+
Attribute,
24+
ArrayAttr,
2325
)
2426

2527
from mlir.extras import types as T
@@ -28,6 +30,7 @@
2830
# you need this to register the memref value caster
2931
# noinspection PyUnresolvedReferences
3032
from mlir.extras.dialects.ext import arith, linalg, memref, transform, vector, scf, func
33+
from mlir.dialects import affine
3134
from mlir.extras.dialects.ext.vector import outer, shuffle, load
3235
from mlir.extras.dialects.ext.transform import (
3336
get_parent_op,
@@ -71,7 +74,7 @@ def mod_transform():
7174
@named_sequence("main", [any_op_t()], [])
7275
def main(module_op: any_op_t()):
7376
matmul = match(module_op, ops=["linalg.matmul"])
74-
tiled_matmul, (_, _, inner_loop) = tile_to_scf_for(matmul, sizes=[2, 2, 2])
77+
tiled_matmul, (_, _, inner_loop) = tile_to_scf_for(matmul, sizes=[4, 3, 8])
7578
transform.structured.vectorize_children_and_apply_patterns(
7679
get_parent_op(
7780
transform_any_op_t(), tiled_matmul, isolated_from_above=True
@@ -136,40 +139,64 @@ def pats():
136139
assert np.allclose(A @ B, C)
137140

138141

139-
def test_e2e_sugar(ctx: MLIRContext):
142+
# (tz_a, tz_b, tz_c) tile sizes fed to test_e2e_sugar: tz_a is the step over
# M, tz_b the step over N and tz_c the step over K in the tiled matmul loops.
testdata = (
    (2, 2, 2),
    (2, 3, 2),
    (2, 3, 4),
    (2, 4, 8),
    (2, 4, 16),
    (4, 4, 16),
    (4, 8, 16),
)
151+
152+
153+
@pytest.mark.parametrize("tz_a, tz_b, tz_c", testdata)
154+
def test_e2e_sugar(ctx: MLIRContext, tz_a, tz_b, tz_c):
140155
backend = LLVMJITBackend()
141156

142157
scale = 16
143158
M, K, N = 2 * scale, 4 * scale, 6 * scale
144-
v2f32 = T.vector(2, T.f32())
159+
160+
vaf32 = T.vector(tz_a, T.f32())
161+
vbf32 = T.vector(tz_b, T.f32())
162+
vcf32 = T.vector(tz_c, T.f32())
163+
vacrossbf32 = T.vector(tz_a, tz_b, T.f32())
164+
vatimescf32 = T.vector(tz_a * tz_c, T.f32())
165+
166+
shuffle_mask = np.arange(tz_a * tz_c).reshape(tz_a, tz_c).reshape((-1,), order="F")
145167

146168
@func.func(emit=True)
147169
def smol_matmul(
148170
A: T.memref(M, K, T.f32()),
149171
B: T.memref(K, N, T.f32()),
150172
C: T.memref(M, N, T.f32()),
151173
):
152-
cst = arith.constant(np.full([4], 0.0, np.float32), T.vector(4, T.f32()))
153-
cst_0 = arith.constant(
154-
np.full([2, 2], 0.0, np.float32), T.vector(2, 2, T.f32())
155-
)
156-
for i, C, v0 in scf.range_(0, M, 2, iter_args=[C]):
157-
for j, C, v1 in scf.range_(0, N, 2, iter_args=[C]):
158-
for k, C, v2 in scf.range_(0, K, 2, iter_args=[C]):
159-
cst[0::1] = A @ load(v2f32) @ [i, k]
160-
cst[2::1] = A @ load(v2f32) @ [i + 1, k]
161-
cst_0[0] = C @ load(v2f32) @ [i, j]
162-
cst_0[1] = C @ load(v2f32) @ [i + 1, j]
163-
cst = cst @ shuffle(mask=[0, 2, 1, 3]) @ cst
164-
165-
v19 = cst[0:2:1] @ outer(cst_0) @ (B @ load(v2f32) @ [k, j])
166-
v20 = cst[2:4:1] @ outer(v19) @ (B @ load(v2f32) @ [k + 1, j])
167-
C[i, j] = v20[0]
168-
C[i + 1, j] = v20[1]
169-
170-
scf.yield_(C)
171-
scf.yield_(v2)
172-
scf.yield_(v1)
174+
cst = arith.constant(np.full([tz_a * tz_c], 0.0, np.float32), vatimescf32)
175+
acc = arith.constant(np.full([tz_a, tz_b], 0.0, np.float32), vacrossbf32)
176+
177+
for m, C, v0 in scf.range_(0, M, tz_a, iter_args=[C]):
178+
for n, C, v1 in scf.range_(0, N, tz_b, iter_args=[C]):
179+
for k, C, v2 in scf.range_(0, K, tz_c, iter_args=[C]):
180+
for i in range(tz_a):
181+
cst[tz_c * i :: 1] = A @ load(vcf32) @ [m + i, k]
182+
cst = cst @ shuffle(mask=shuffle_mask) @ cst
183+
184+
for i in range(tz_a):
185+
acc[i] = C @ load(vbf32) @ [m + i, n]
186+
187+
for i in range(tz_c):
188+
acc = (
189+
(cst[i * tz_a : (i + 1) * tz_a : 1])
190+
@ outer(acc)
191+
@ (B @ load(vbf32) @ [k + i, n])
192+
)
193+
194+
for i in range(tz_a):
195+
C[m + i, n] = acc[i]
196+
197+
scf.yield_(results_=[C])
198+
scf.yield_(results_=[v2])
199+
scf.yield_(results_=[v1])
173200

174201
compiled_module = backend.compile(
175202
ctx.module,

tests/util.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def hip_check(call_result):
5050
return result
5151

5252

53+
def hip_synchronize():
    """Wait for the device to finish all outstanding work.

    The status returned by ``hipDeviceSynchronize`` is passed through
    ``hip_check`` (as the other HIP calls in the tests are) instead of being
    dropped, so a failure reported at synchronization time is not ignored.
    """
    from hip import hip

    hip_check(hip.hipDeviceSynchronize())
57+
58+
5359
def hip_bindings_not_installed():
5460
try:
5561
from hip import hip

0 commit comments

Comments
 (0)