
Commit bf31525

cuda_matmul_opt_v2 (#84)

1 parent 4cd5119 · commit bf31525

17 files changed: +730, -177 lines

examples/cuda_matmul_opt.py

Lines changed: 243 additions & 38 deletions

@@ -15,8 +15,8 @@
 )
 from mlir.extras.dialects.ext import arith, memref, gpu, scf
 from mlir.extras.dialects.ext.gpu import (
-    block_id,
-    thread_id,
+    block_idx,
+    thread_idx,
     block_dim,
     get_compile_object_bytes,
 )
@@ -30,13 +30,44 @@
 _ = memref


-def build_cuda_func(compiled_module, kernel_name="mat_product_kernel"):
+def build_cuda_func(compiled_module, kernel_name="naive"):
     ptx = get_compile_object_bytes(compiled_module)
     mod = Module()
     mod.load(ptx)
     return mod.get_function(kernel_name)


+def print_ptx(compiled_module):
+    ptx = get_compile_object_bytes(compiled_module)
+    print(ptx.decode())
+
+
+def compile_module(module, enable_ir_printing=False, print_ptx_=False):
+    if enable_ir_printing:
+        print_ptx_ = True
+    mod = run_pipeline(
+        module,
+        Pipeline().add_pass(
+            "gpu-lower-to-nvvm-pipeline",
+            # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18
+            **{
+                "cubin-chip": "sm_80",
+                "cubin-features": "+ptx83",
+                "cubin-format": "isa",
+                "kernel-bare-ptr-calling-convention": "1",
+                "opt-level": "2",
+                # "cubin-format": "fatbin",
+                # "cubin-format": "bin",
+            },
+        ),
+        enable_ir_printing=enable_ir_printing,
+    )
+    if print_ptx_:
+        print_ptx(mod)
+
+    return mod
+
+
 @contextlib.contextmanager
 def time_cuda():
     start_gpu = cp.cuda.Event()
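
Note (not part of the diff): a minimal sketch of how the new compile_module / print_ptx / build_cuda_func helpers are meant to compose. It mirrors what main() does further down, assumes the same imports as examples/cuda_matmul_opt.py plus an sm_80-capable device, and assumes the kernel's symbol name matches its Python function name:

    # sketch only: assumes the imports and kernels from examples/cuda_matmul_opt.py
    with mlir_mod_ctx() as ctx:
        gpu.set_container_module(ctx.module)

        @gpu.module("matmul", ["#nvvm.target"])
        def matmul_mod():
            # emit one of the kernels defined below
            sgemm_naive[1024, 1024, 1024, T.f32()].emit()

        # lower through the NVVM pipeline (optionally dumping PTX), then load the
        # named kernel out of the compiled module
        compiled_module = compile_module(ctx.module, print_ptx_=True)
        cuda_func = build_cuda_func(compiled_module, kernel_name="sgemm_naive")
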
@@ -50,80 +81,254 @@ def time_cuda():

 @gpu.func
 @canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
-def mat_product_kernel[
+def sgemm_naive[
+    M, K, N, dtype
+](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+    one = arith.constant(1.0, type=dtype)
+    tmp = arith.constant(0, type=dtype)
+
+    # this is from the example and it's basically a mistake
+    # it increments the row for each adjacent thread id
+    # uncomment the print to see
+    r = block_dim.x * block_idx.x + thread_idx.x
+    c = block_dim.y * block_idx.y + thread_idx.y
+    # tid = gpu.thread_id()
+    # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)
+
+    for k, tmp in range_(K, iter_args=[tmp]):
+        tmp += A[r, k] * B[k, c]
+        tmp = yield tmp
+    C[r, c] = tmp + one
+
+
+@gpu.func
+@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+def sgemm_naive_row_order[
     M, K, N, dtype
 ](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
-    x = block_dim.x * block_id.x + thread_id.x
-    y = block_dim.y * block_id.y + thread_id.y
+    one = arith.constant(1.0, type=dtype)
+    tmp = arith.constant(0, type=dtype)
+
+    # increment along the cols (ie preserve row-order access)
+    c = block_dim.x * block_idx.x + thread_idx.x
+    r = block_dim.y * block_idx.y + thread_idx.y
+    # tid = gpu.thread_id()
+    # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)
+
+    for k, tmp in range_(K, iter_args=[tmp]):
+        tmp += A[r, k] * B[k, c]
+        tmp = yield tmp
+    C[r, c] = tmp + one
+
+
+@gpu.func
+@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+def sgemm_coalesce[
+    M, K, N, dtype, BLOCK_SIZE
+](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+
+    tid = gpu.thread_id()
+    # this is actually floordiv
+    r = block_idx.x * BLOCK_SIZE + (tid / BLOCK_SIZE)
+    c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE)
+    # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)
+
+    one = arith.constant(1.0, type=dtype)
+    tmp = arith.constant(0, type=dtype)
+
+    for k, tmp in range_(K, iter_args=[tmp]):
+        # k varies per core while c varies with tid
+        # apparently that's fine? i guess all the loads can happen
+        # because there's enough scratch per SM to prefetch all the data each thread needs?
+        tmp += A[r, k] * B[k, c]
+        tmp = yield tmp
+    C[r, c] = tmp + one
+
+
147+
148+
# So if you try to load something like:
149+
#
150+
# B.T:
151+
#
152+
# 0 0 0 0 0 0 0 0
153+
# 1 1 1 1 1 1 1 1
154+
# 2 2 2 2 2 2 2 2
155+
#
156+
# vs
157+
#
158+
# B:
159+
# 0 1 2 3 4 5 6 7 8
160+
# 0 1 2 3 4 5 6 7 8
161+
# 0 1 2 3 4 5 6 7 8
162+
#
163+
# In B, you are feeding all threads with a single load (say warp can load 8 elements at a time) and then you increment k
164+
#
165+
# in B.T, a single load is feeding only a single thread, so others are probably waiting for their load to happen
166+
# these are the issues by threads:
167+
#
168+
# 0: (0, 0), (1, 0), (2, 0)
169+
# 1: (0, 1), (1, 1), (2, 1)
170+
# 2: (0, 2), (1, 2), (2, 2)
171+
#
172+
# warp recieves these issues:
173+
#
174+
# (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)
175+
#
176+
# warp issues coalesced reads:
177+
#
178+
# (0, 0:2), (1, 0:2), (2,0:2)
179+
# so even though the threads have bad memory access pattern
180+
# the warp has good memory access pattern
181+
# and since the actual load happens at warp level
182+
# its good
+@gpu.func
+@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+def sgemm_coalesce_transpose_B[
+    M, K, N, dtype, BLOCK_SIZE
+](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+
+    tid = gpu.thread_id()
+    r = block_idx.x * BLOCK_SIZE + (tid / BLOCK_SIZE)
+    c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE)

     one = arith.constant(1.0, type=dtype)
     tmp = arith.constant(0, type=dtype)
+
     for k, tmp in range_(K, iter_args=[tmp]):
-        tmp += A[x, k] * B[k, y]
+        # this is slower because c is incremented with each tid
+        # so you break memory coalescing
+        # but k now being on the row order dim doesn't help?
+        tmp += A[r, k] * B[c, k]
+        tmp = yield tmp
+    C[r, c] = tmp + one
+
+
+@gpu.func
+@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+def sgemm_shared_mem_block[
+    M, K, N, dtype, BLOCK_SIZE
+](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+    # allocate buffer for current block in fast shared mem
+    # shared mem is shared between all threads in a block
+    base = gpu.dynamic_shared_memory()
+    A_shared = memref.view(base, (BLOCK_SIZE, BLOCK_SIZE), dtype=dtype)
+    B_shared = memref.view(
+        base, (BLOCK_SIZE, BLOCK_SIZE), dtype=dtype, shift=BLOCK_SIZE * BLOCK_SIZE
+    )
+
+    # the inner row & col that we're accessing in this thread
+    tid = gpu.thread_id()
+    thread_row = tid / BLOCK_SIZE
+    thread_col = tid % BLOCK_SIZE
+
+    # the output block that we want to compute in this threadblock
+    c_row = block_idx.x * BLOCK_SIZE
+    c_col = block_idx.y * BLOCK_SIZE
+
+    one = arith.constant(1.0, type=dtype)
+    tmp = arith.constant(0, type=dtype)
+
+    for bk_idx, tmp in range_(0, K, BLOCK_SIZE, iter_args=[tmp]):
+        A_ = A[c_row : c_row + BLOCK_SIZE, bk_idx : bk_idx + BLOCK_SIZE]
+        B_ = B[bk_idx : bk_idx + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE]
+
+        # Have each thread load one of the elements in A & B
+        # Make the threadCol (=threadIdx.x) the consecutive index
+        # to allow global memory access coalescing
+        A_shared[thread_row, thread_col] = A_[thread_row, thread_col]
+        B_shared[thread_row, thread_col] = B_[thread_row, thread_col]
+
+        # block threads in this block until cache is fully populated
+        gpu.barrier()
+
+        # execute the dotproduct on the currently cached block
+        for k, tmp in range_(BLOCK_SIZE, iter_args=[tmp]):
+            tmp += A_shared[thread_row, k] * B_shared[k, thread_col]
+            tmp = yield tmp
+
+        # need to sync again at the end, to avoid faster threads
+        # fetching the next block into the cache before slower threads are done
+        gpu.barrier()
+
         tmp = yield tmp
-    C[x, y] = tmp + one
+
+    C_ = C[c_row : c_row + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE]
+    C_[thread_row, thread_col] = tmp + one


-def main(ctx: MLIRContext, M, K, N, BLOCK_SIZE=32, repeat_times=50):
+def main(ctx: MLIRContext, M, K, N, BLOCK_SIZE=32, repeat_times=None):
+    if repeat_times is None:
+        repeat_times = 50
     dtype = T.f32()
     npy_dtype = np.float32

     gpu.set_container_module(ctx.module)

-    @gpu.module("naive", ["#nvvm.target"])
-    def _():
-        mat_product_kernel[M, K, N, dtype].emit()
+    @gpu.module("matmul", ["#nvvm.target"])
+    def matmul_mod():
+        sgemm_shared_mem_block[M, K, N, dtype, BLOCK_SIZE].emit()

     # print(ctx.module)
-    ctx.module.operation.verify()
+    # print(ctx.module.operation.verify())
+    # exit()

-    compiled_module = run_pipeline(
-        ctx.module,
-        Pipeline().add_pass(
-            "gpu-lower-to-nvvm-pipeline",
-            # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18
-            **{
-                "cubin-chip": "sm_80",
-                "cubin-features": "+ptx83",
-                "cubin-format": "isa",
-                "kernel-bare-ptr-calling-convention": "1",
-                # "cubin-format": "fatbin",
-                # "cubin-format": "bin",
-            },
-        ),
-    )
-    cuda_func = build_cuda_func(compiled_module)
-    # print(compiled_module)
+    kernel_name = matmul_mod.opview.body.operations[0].attributes["sym_name"].value
+    compiled_module = compile_module(ctx.module)
+    cuda_func = build_cuda_func(compiled_module, kernel_name)
     # print_ptx(compiled_module)

     A = np.random.randint(0, 10, (M, K)).astype(npy_dtype)
     B = np.random.randint(0, 10, (K, N)).astype(npy_dtype)
     C = np.zeros((M, N)).astype(npy_dtype)

     dA = cp.asarray(A)
-    dB = cp.asarray(B)
+    if "transpose_B" in kernel_name:
+        dB = cp.asarray(np.ascontiguousarray(B.T))
+    else:
+        dB = cp.asarray(B)
     dC = cp.asarray(C)

+    grid_dims = (math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE))
+    block_dims = (BLOCK_SIZE, BLOCK_SIZE)
+
+    if "shared" in kernel_name:
+        shared_mem = 2 * BLOCK_SIZE * BLOCK_SIZE * npy_dtype().nbytes
+    else:
+        shared_mem = None
+
+    cuda_func(
+        grid_dims,
+        block_dims,
+        (dA.data.ptr, dB.data.ptr, dC.data.ptr),
+        shared_mem=shared_mem,
+    )
+    C = cp.asnumpy(dC)
+    if not np.array_equal(C, A @ B + 1):
+        print(A @ B + 1)
+        print(C)
+        assert False
+    if repeat_times < 1:
+        return
+
     with time_cuda() as (start_gpu, end_gpu):
         for _ in range(repeat_times):
             cuda_func(
-                (math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE), 1),
-                (BLOCK_SIZE, BLOCK_SIZE, 1),
+                grid_dims,
+                block_dims,
                 (dA.data.ptr, dB.data.ptr, dC.data.ptr),
+                shared_mem=shared_mem,
             )

     t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu)

     print(f"t_gpu={t_gpu / repeat_times:.6f} ms")

-    if not cp.array_equal(dC, dA @ dB + 1):
-        print(dA @ dB + 1)
-        print(dC)

+sizes = [128, 256, 512, 1024]
+repeats = None

-for s in [128, 256, 512, 1024]:
+for s in sizes:
     with (
         mlir_mod_ctx() as ctx,
         # enable_debug()
     ):
-        main(ctx, s, s, s)
+        main(ctx, s, s, s, repeat_times=repeats)
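
Note (not part of the diff): the shared-memory kernel is easier to follow next to a host-side model of the same tiling. The NumPy sketch below (the helper name blocked_sgemm is hypothetical) mimics sgemm_shared_mem_block: it computes one BLOCK_SIZE x BLOCK_SIZE output tile per "threadblock", walks K in BLOCK_SIZE-wide slabs just like the bk_idx loop, and stores tmp + 1 so the same C == A @ B + 1 check from main() applies:

    import numpy as np

    def blocked_sgemm(A, B, BLOCK_SIZE=32):
        # host-side model of sgemm_shared_mem_block
        M, K = A.shape
        _, N = B.shape
        assert M % BLOCK_SIZE == 0 and K % BLOCK_SIZE == 0 and N % BLOCK_SIZE == 0
        C = np.zeros((M, N), dtype=A.dtype)
        for c_row in range(0, M, BLOCK_SIZE):          # block_idx.x
            for c_col in range(0, N, BLOCK_SIZE):      # block_idx.y
                tmp = np.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=A.dtype)
                for bk_idx in range(0, K, BLOCK_SIZE):
                    # A_shared / B_shared in the kernel: one cached tile of each operand
                    A_tile = A[c_row : c_row + BLOCK_SIZE, bk_idx : bk_idx + BLOCK_SIZE]
                    B_tile = B[bk_idx : bk_idx + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE]
                    tmp += A_tile @ B_tile
                # the kernels write tmp + 1, hence the A @ B + 1 reference result
                C[c_row : c_row + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE] = tmp + 1
        return C

    A = np.random.randint(0, 10, (128, 128)).astype(np.float32)
    B = np.random.randint(0, 10, (128, 128)).astype(np.float32)
    assert np.array_equal(blocked_sgemm(A, B), A @ B + 1)

For the launch in main(), the dynamic shared memory request is 2 * BLOCK_SIZE * BLOCK_SIZE * npy_dtype().nbytes, i.e. 2 * 32 * 32 * 4 = 8192 bytes at the defaults: one f32 tile each for A_shared and B_shared.
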

mlir/extras/ast/canonicalize.py

Lines changed: 5 additions & 1 deletion

@@ -14,7 +14,7 @@
 import astunparse
 from bytecode import ConcreteBytecode

-from ..ast.util import get_module_cst
+from ..ast.util import get_module_cst, set_lineno

 logger = logging.getLogger(__name__)

@@ -69,6 +69,7 @@ def insert_closed_vars(f, module):
         ),
         body=[],
         decorator_list=[],
+        type_params=[],
     )
     for var in f.__code__.co_freevars:
         enclosing_mod.body.append(
@@ -77,6 +78,9 @@ def insert_closed_vars(f, module):
                 value=ast.Constant(None, kind="None"),
             )
         )
+    enclosing_mod = set_lineno(enclosing_mod, module.body[0].lineno)
+    enclosing_mod = ast.fix_missing_locations(enclosing_mod)
+
     enclosing_mod.body.extend(module.body)
     module.body = [enclosing_mod]
     return module
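
Note (not part of the diff): both additions track what CPython expects of hand-built AST nodes. ast.FunctionDef grew a type_params field with PEP 695 (Python 3.12), and compile() needs every synthesized node to carry location info, which is what the set_lineno / ast.fix_missing_locations pair supplies here. A standalone sketch of the same pattern, independent of this repo:

    import ast

    # hand-built FunctionDef, analogous to enclosing_mod in insert_closed_vars
    fn = ast.FunctionDef(
        name="enclosing",
        args=ast.arguments(
            posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]
        ),
        body=[ast.Pass()],
        decorator_list=[],
        type_params=[],  # field expected by Python 3.12+
    )
    mod = ast.Module(body=[fn], type_ignores=[])
    # without this, compile() rejects nodes missing lineno/col_offset
    mod = ast.fix_missing_locations(mod)
    ns = {}
    exec(compile(mod, "<synthesized>", "exec"), ns)
    print(ns["enclosing"])  # <function enclosing at 0x...>
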

mlir/extras/ast/util.py

Lines changed: 1 addition & 2 deletions

@@ -32,8 +32,7 @@ def ast_call(name, args=None, keywords=None):


 def get_module_cst(f):
-    lines, _lnum = inspect.getsourcelines(f)
-    f_src = dedent("".join(list(dropwhile(lambda l: l.startswith("@"), lines))))
+    f_src = dedent(inspect.getsource(f))
     tree = ast.parse(f_src)
     assert isinstance(
         tree.body[0], ast.FunctionDef
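
Note (not part of the diff): inspect.getsource already returns a function's source including its decorator lines, and a decorated def still parses to a single top-level ast.FunctionDef, so the old getsourcelines + dropwhile(startswith("@")) handling is unnecessary. A small standalone illustration (run as a script so inspect can find the source file):

    import ast
    import inspect
    from textwrap import dedent

    def decorator(f):
        return f

    class Holder:
        @decorator
        def method(self):
            return 1

    # getsource keeps the decorator and the original indentation;
    # one dedent is enough before parsing
    src = dedent(inspect.getsource(Holder.method))
    print(src.splitlines()[0])          # @decorator
    tree = ast.parse(src)
    print(type(tree.body[0]).__name__)  # FunctionDef
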

mlir/extras/dialects/ext/arith.py

Lines changed: 2 additions & 1 deletion

@@ -337,7 +337,8 @@ def _binary_op(
     elif _is_integer_like_type(lhs.dtype):
         # TODO(max): this needs to all be regularized
         if "div" in op.lower() or "rem" in op.lower():
-            if not lhs.dtype.is_signless:
+            # TODO(max): should be using index ops here
+            if not _is_index_type(lhs.dtype) and not lhs.dtype.is_signless:
                 raise ValueError(f"{op.lower()}i not supported for {lhs=}")
             if op == "Floordiv":
                 op = "FloorDiv"
