Skip to content

Commit b2f7105

Browse files
mjulian31, vchuravy, DilumAluthge
authored
performant matmul example for KA (#208)
Co-authored-by: Valentin Churavy <[email protected]> Co-authored-by: Dilum Aluthge <[email protected]>
1 parent 050bb9c commit b2f7105

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed

examples/performant_matmul.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using KernelAbstractions
2+
using StaticArrays
3+
using Test
4+
5+
# Edge length of the square workgroup handed to the kernel constructor below as
# (TILE_DIM, TILE_DIM); the kernel reads its own tile size back from the group size.
const TILE_DIM = 32
6+
7+
# Tiled matrix multiply: writes input1 * input2 into `output`.
#
# Arguments
#   output          destination matrix; its element type is used for the tiles and
#                   for the accumulator
#   input1, input2  read-only operands (marked @Const); input1 is indexed as (row, r)
#                   and input2 as (r, column) with r running up to R
#   N, R, M         problem sizes: rows of output, inner dimension, columns of output.
#                   Note that N is re-derived from size(output, 1) inside the kernel,
#                   so the N argument is effectively ignored.
#   ::Val{BANK}     number of extra rows added to each tile (default 1) to avoid bank
#                   conflicts on shared memory
#
# The tile edge is taken from the first workgroup dimension and used for both index
# dimensions, so a square workgroup is assumed. Each workgroup computes one
# TILE_DIM x TILE_DIM patch of `output`, walking over the inner dimension one tile at
# a time; out-of-range tile entries are zero-filled so that non-square / non-multiple
# sizes work.
@kernel function coalesced_matmul_kernel!(output, @Const(input1), @Const(input2), N, R, M,
                                          ::Val{BANK}=Val(1)) where BANK
    gi, gj = @index(Group, NTuple)
    i, j = @index(Local, NTuple)

    TILE_DIM = @uniform groupsize()[1]

    # pad the first tile dimension by BANK to avoid bank conflicts on shared memory
    tile1 = @localmem eltype(output) (TILE_DIM+BANK, TILE_DIM)
    tile2 = @localmem eltype(output) (TILE_DIM+BANK, TILE_DIM)

    # private variable for tile output
    outval = @private eltype(output) 1
    @inbounds outval[1] = -zero(eltype(output))

    # rebinds N from the output itself, overriding the N argument
    @uniform N = size(output, 1)
    # number of tiles depends on inner dimension (ceiling division)
    @uniform NUM_TILES = div(R + TILE_DIM - 1, TILE_DIM)

    # loop over all tiles needed for this calculation
    for t in 0:NUM_TILES-1
        # Can't use @index(Global), because we use a smaller ndrange
        I = (gi-1) * TILE_DIM + i
        J = (gj-1) * TILE_DIM + j

        # load inputs into tiles, with bounds checking for non-square matrices;
        # entries outside the operands are filled with a zero of the tile's own
        # element type so they contribute nothing to the product
        if I <= N && t*TILE_DIM + j <= R
            @inbounds tile1[i, j] = input1[I, t*TILE_DIM + j]
        else
            @inbounds tile1[i, j] = zero(eltype(output))
        end
        if t*TILE_DIM + i <= R && J <= M
            @inbounds tile2[i, j] = input2[t*TILE_DIM + i, J]
        else
            @inbounds tile2[i, j] = zero(eltype(output))
        end

        # wait for all tiles to be loaded
        @synchronize

        # get global values again
        # NOTE(review): presumably recomputed because plain locals are not guaranteed
        # to survive @synchronize on every backend -- confirm before removing
        I = (gi-1) * TILE_DIM + i
        J = (gj-1) * TILE_DIM + j

        # calculate value of spot in output, use temporary value to allow for vectorization
        out = zero(eltype(output))
        @simd for k in 1:TILE_DIM
            @inbounds out += tile1[i, k] * tile2[k, j]
        end
        outval[1] += out

        # make sure every work-item is done reading the tiles before the next
        # iteration overwrites them
        @synchronize
    end

    # get global indices again
    I = (gi-1) * TILE_DIM + i
    J = (gj-1) * TILE_DIM + j

    # save if inbounds
    if I <= N && J <= M
        @inbounds output[I, J] = outval[1]
    end
end
70+
71+
# Problem sizes: A is N x R and B is R x M, so the product C is N x M.
N, R, M = 1024, 512, 2048
A = rand(N, R)
B = rand(R, M)
C = zeros(N, M)

# Instantiate the kernel for the CPU backend with a TILE_DIM x TILE_DIM workgroup.
kern = coalesced_matmul_kernel!(CPU(), (TILE_DIM, TILE_DIM))

# Launch over the full extent of C, then block on the value returned by the launch.
event = kern(C, A, B, N, R, M; ndrange=size(C))
wait(event)

# The tiled result must agree with the built-in matrix product.
@test A * B ≈ C

0 commit comments

Comments
 (0)