)
from mlir.extras.dialects.ext import arith, memref, gpu, scf
from mlir.extras.dialects.ext.gpu import (
-     block_id,
-     thread_id,
+     block_idx,
+     thread_idx,
    block_dim,
    get_compile_object_bytes,
)
_ = memref


- def build_cuda_func(compiled_module, kernel_name="mat_product_kernel"):
+ def build_cuda_func(compiled_module, kernel_name="naive"):
    ptx = get_compile_object_bytes(compiled_module)
    mod = Module()
    mod.load(ptx)
    return mod.get_function(kernel_name)


+ def print_ptx(compiled_module):
+     ptx = get_compile_object_bytes(compiled_module)
+     print(ptx.decode())
+
+
+ def compile_module(module, enable_ir_printing=False, print_ptx_=False):
+     if enable_ir_printing:
+         print_ptx_ = True
+     mod = run_pipeline(
+         module,
+         Pipeline().add_pass(
+             "gpu-lower-to-nvvm-pipeline",
+             # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18
+             **{
+                 "cubin-chip": "sm_80",
+                 "cubin-features": "+ptx83",
+                 "cubin-format": "isa",
+                 "kernel-bare-ptr-calling-convention": "1",
+                 "opt-level": "2",
+                 # "cubin-format": "fatbin",
+                 # "cubin-format": "bin",
+             },
+         ),
+         enable_ir_printing=enable_ir_printing,
+     )
+     if print_ptx_:
+         print_ptx(mod)
+
+     return mod
+
+
@contextlib.contextmanager
def time_cuda():
    start_gpu = cp.cuda.Event()
@@ -50,80 +81,254 @@ def time_cuda():

@gpu.func
@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
- def mat_product_kernel[
+ def sgemm_naive[
+     M, K, N, dtype
+ ](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+     one = arith.constant(1.0, type=dtype)
+     tmp = arith.constant(0, type=dtype)
+
+     # this is from the example and it's basically a mistake
+     # it increments the row for each adjacent thread id
+     # uncomment the print to see
+     r = block_dim.x * block_idx.x + thread_idx.x
+     c = block_dim.y * block_idx.y + thread_idx.y
+     # tid = gpu.thread_id()
+     # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)
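+     # e.g. with 32x32 blocks (and the usual x-fastest thread linearization), threads
+     # 0..31 of a warp share c but get r = 0..31, so each warp access to A and C is
+     # strided across 32 rows instead of one contiguous row segment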
+
+     for k, tmp in range_(K, iter_args=[tmp]):
+         tmp += A[r, k] * B[k, c]
+         tmp = yield tmp
+     C[r, c] = tmp + one
+
+
+ @gpu.func
+ @canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+ def sgemm_naive_row_order[
    M, K, N, dtype
](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
-     x = block_dim.x * block_id.x + thread_id.x
-     y = block_dim.y * block_id.y + thread_id.y
+     one = arith.constant(1.0, type=dtype)
+     tmp = arith.constant(0, type=dtype)
+
+     # increment along the cols (ie preserve row-order access)
+     c = block_dim.x * block_idx.x + thread_idx.x
+     r = block_dim.y * block_idx.y + thread_idx.y
+     # tid = gpu.thread_id()
+     # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)
+
+     for k, tmp in range_(K, iter_args=[tmp]):
+         tmp += A[r, k] * B[k, c]
+         tmp = yield tmp
+     C[r, c] = tmp + one
+
+
+ @gpu.func
+ @canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+ def sgemm_coalesce[
+     M, K, N, dtype, BLOCK_SIZE
+ ](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+
+     tid = gpu.thread_id()
+     # this is actually floordiv
+     r = block_idx.x * BLOCK_SIZE + (tid / BLOCK_SIZE)
+     c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE)
+     # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)
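+     # consecutive tids now share r and get consecutive c, so a warp reads a contiguous
+     # run of B[k, :] and writes a contiguous run of C[r, :]
+     # (e.g. with BLOCK_SIZE=32 in block (0, 0): tid 0..31 -> r=0, c=0..31)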
+
+     one = arith.constant(1.0, type=dtype)
+     tmp = arith.constant(0, type=dtype)
+
+     for k, tmp in range_(K, iter_args=[tmp]):
+         # k varies per loop iteration (same across the warp) while c varies with tid
+         # apparently that's fine? i guess all the loads can happen
+         # because there's enough scratch per SM to prefetch all the data each thread needs?
+         tmp += A[r, k] * B[k, c]
+         tmp = yield tmp
+     C[r, c] = tmp + one
+
+
+ # So if you try to load something like:
+ #
+ # B.T:
+ #
+ # 0 0 0 0 0 0 0 0
+ # 1 1 1 1 1 1 1 1
+ # 2 2 2 2 2 2 2 2
+ #
+ # vs
+ #
+ # B:
+ # 0 1 2 3 4 5 6 7 8
+ # 0 1 2 3 4 5 6 7 8
+ # 0 1 2 3 4 5 6 7 8
+ #
+ # In B, you are feeding all threads with a single load (say a warp can load 8 elements at a time) and then you increment k
+ #
+ # in B.T, a single load is feeding only a single thread, so the others are probably waiting for their load to happen
+ # these are the issues by thread:
+ #
+ # 0: (0, 0), (1, 0), (2, 0)
+ # 1: (0, 1), (1, 1), (2, 1)
+ # 2: (0, 2), (1, 2), (2, 2)
+ #
+ # warp receives these issues:
+ #
+ # (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)
+ #
+ # warp issues coalesced reads:
+ #
+ # (0, 0:2), (1, 0:2), (2, 0:2)
+ # so even though the threads have a bad memory access pattern
+ # the warp has a good memory access pattern
+ # and since the actual load happens at warp level
+ # it's good
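+ # (roughly: 32 threads x 4-byte floats is one 128-byte transaction when the addresses are
+ # contiguous, vs. up to 32 separate transactions when each thread hits a different row)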
+ @gpu.func
+ @canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+ def sgemm_coalesce_transpose_B[
+     M, K, N, dtype, BLOCK_SIZE
+ ](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+
+     tid = gpu.thread_id()
+     r = block_idx.x * BLOCK_SIZE + (tid / BLOCK_SIZE)
+     c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE)

    one = arith.constant(1.0, type=dtype)
    tmp = arith.constant(0, type=dtype)
+
    for k, tmp in range_(K, iter_args=[tmp]):
-         tmp += A[x, k] * B[k, y]
+         # this is slower because c is incremented with each tid
+         # so you break memory coalescing
+         # but k now being on the row order dim doesn't help?
+         tmp += A[r, k] * B[c, k]
+         tmp = yield tmp
+     C[r, c] = tmp + one
+
+
+ @gpu.func
+ @canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+ def sgemm_shared_mem_block[
+     M, K, N, dtype, BLOCK_SIZE
+ ](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+     # allocate buffer for current block in fast shared mem
+     # shared mem is shared between all threads in a block
+     base = gpu.dynamic_shared_memory()
+     A_shared = memref.view(base, (BLOCK_SIZE, BLOCK_SIZE), dtype=dtype)
+     B_shared = memref.view(
+         base, (BLOCK_SIZE, BLOCK_SIZE), dtype=dtype, shift=BLOCK_SIZE * BLOCK_SIZE
+     )
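+     # A_shared and B_shared are two BLOCK_SIZE x BLOCK_SIZE views into the same dynamic
+     # shared-memory buffer (B_shared offset by BLOCK_SIZE*BLOCK_SIZE elements), so the
+     # launch below has to request at least 2 * BLOCK_SIZE**2 * sizeof(dtype) bytes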
+
+     # the inner row & col that we're accessing in this thread
+     tid = gpu.thread_id()
+     thread_row = tid / BLOCK_SIZE
+     thread_col = tid % BLOCK_SIZE
+
+     # the output block that we want to compute in this threadblock
+     c_row = block_idx.x * BLOCK_SIZE
+     c_col = block_idx.y * BLOCK_SIZE
+
+     one = arith.constant(1.0, type=dtype)
+     tmp = arith.constant(0, type=dtype)
+
+     for bk_idx, tmp in range_(0, K, BLOCK_SIZE, iter_args=[tmp]):
+         A_ = A[c_row : c_row + BLOCK_SIZE, bk_idx : bk_idx + BLOCK_SIZE]
+         B_ = B[bk_idx : bk_idx + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE]
+
+         # Have each thread load one of the elements in A & B
+         # Make the threadCol (=threadIdx.x) the consecutive index
+         # to allow global memory access coalescing
+         A_shared[thread_row, thread_col] = A_[thread_row, thread_col]
+         B_shared[thread_row, thread_col] = B_[thread_row, thread_col]
+
+         # block threads in this block until cache is fully populated
+         gpu.barrier()
+
+         # execute the dotproduct on the currently cached block
+         for k, tmp in range_(BLOCK_SIZE, iter_args=[tmp]):
+             tmp += A_shared[thread_row, k] * B_shared[k, thread_col]
+             tmp = yield tmp
+
+         # need to sync again at the end, to avoid faster threads
+         # fetching the next block into the cache before slower threads are done
+         gpu.barrier()
+
        tmp = yield tmp
-     C[x, y] = tmp + one
+
+     C_ = C[c_row : c_row + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE]
+     C_[thread_row, thread_col] = tmp + one


- def main(ctx: MLIRContext, M, K, N, BLOCK_SIZE=32, repeat_times=50):
+ def main(ctx: MLIRContext, M, K, N, BLOCK_SIZE=32, repeat_times=None):
+     if repeat_times is None:
+         repeat_times = 50
    dtype = T.f32()
    npy_dtype = np.float32

    gpu.set_container_module(ctx.module)

-     @gpu.module("naive", ["#nvvm.target"])
-     def _():
-         mat_product_kernel[M, K, N, dtype].emit()
+     @gpu.module("matmul", ["#nvvm.target"])
+     def matmul_mod():
+         sgemm_shared_mem_block[M, K, N, dtype, BLOCK_SIZE].emit()
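+         # emit one of the other kernels here instead (e.g. sgemm_naive[M, K, N, dtype] or
+         # sgemm_coalesce[M, K, N, dtype, BLOCK_SIZE]) to benchmark that variant; the host
+         # code below keys off the kernel name for the transpose_B / shared-mem handling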

    # print(ctx.module)
-     ctx.module.operation.verify()
+     # print(ctx.module.operation.verify())
+     # exit()

-     compiled_module = run_pipeline(
-         ctx.module,
-         Pipeline().add_pass(
-             "gpu-lower-to-nvvm-pipeline",
-             # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18
-             **{
-                 "cubin-chip": "sm_80",
-                 "cubin-features": "+ptx83",
-                 "cubin-format": "isa",
-                 "kernel-bare-ptr-calling-convention": "1",
-                 # "cubin-format": "fatbin",
-                 # "cubin-format": "bin",
-             },
-         ),
-     )
-     cuda_func = build_cuda_func(compiled_module)
-     # print(compiled_module)
+     kernel_name = matmul_mod.opview.body.operations[0].attributes["sym_name"].value
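+     # grabs the sym_name of the first op in the gpu.module; fine here since exactly one
+     # kernel is emitted above, but with several kernels you'd have to pick the right one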
+     compiled_module = compile_module(ctx.module)
+     cuda_func = build_cuda_func(compiled_module, kernel_name)
    # print_ptx(compiled_module)

    A = np.random.randint(0, 10, (M, K)).astype(npy_dtype)
    B = np.random.randint(0, 10, (K, N)).astype(npy_dtype)
    C = np.zeros((M, N)).astype(npy_dtype)

    dA = cp.asarray(A)
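+     # the transpose_B kernel indexes B as B[c, k], i.e. it expects B laid out transposed,
+     # so hand it a contiguous copy of B.T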
-     dB = cp.asarray(B)
+     if "transpose_B" in kernel_name:
+         dB = cp.asarray(np.ascontiguousarray(B.T))
+     else:
+         dB = cp.asarray(B)
    dC = cp.asarray(C)

+     grid_dims = (math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE))
+     block_dims = (BLOCK_SIZE, BLOCK_SIZE)
+
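+     # one BLOCK_SIZE x BLOCK_SIZE thread block per output tile of C; the shared-mem kernel
+     # additionally needs dynamic shared memory for its two BLOCK_SIZE x BLOCK_SIZE tiles
+     # (with f32 and BLOCK_SIZE=32 that's 2 * 32 * 32 * 4 B = 8 KiB per block)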
+     if "shared" in kernel_name:
+         shared_mem = 2 * BLOCK_SIZE * BLOCK_SIZE * npy_dtype().nbytes
+     else:
+         shared_mem = None
+
+     cuda_func(
+         grid_dims,
+         block_dims,
+         (dA.data.ptr, dB.data.ptr, dC.data.ptr),
+         shared_mem=shared_mem,
+     )
+     C = cp.asnumpy(dC)
+     if not np.array_equal(C, A @ B + 1):
+         print(A @ B + 1)
+         print(C)
+         assert False
+     if repeat_times < 1:
+         return
+
    with time_cuda() as (start_gpu, end_gpu):
        for _ in range(repeat_times):
            cuda_func(
-                 (math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE), 1),
-                 (BLOCK_SIZE, BLOCK_SIZE, 1),
+                 grid_dims,
+                 block_dims,
                (dA.data.ptr, dB.data.ptr, dC.data.ptr),
+                 shared_mem=shared_mem,
            )

    t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu)

    print(f"t_gpu={t_gpu / repeat_times:.6f} ms")

-     if not cp.array_equal(dC, dA @ dB + 1):
-         print(dA @ dB + 1)
-         print(dC)

+ sizes = [128, 256, 512, 1024]
+ repeats = None

- for s in [128, 256, 512, 1024]:
+ for s in sizes:
    with (
        mlir_mod_ctx() as ctx,
        # enable_debug()
    ):
-         main(ctx, s, s, s)
+         main(ctx, s, s, s, repeat_times=repeats)