Skip to content

Commit b398435

Browse files
Update performant_matmul.jl
debug when N and M are not an integer multiple of TILE_DIM
1 parent de5dc99 commit b398435

File tree

1 file changed

+23
-30
lines changed

1 file changed

+23
-30
lines changed

examples/performant_matmul.jl

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,36 +21,31 @@ const TILE_DIM = 32
2121

2222
# private variable for tile output
2323
outval = @private eltype(output) 1
24-
@inbounds outval[1] = -zero(eltype(output))
24+
@inbounds outval[1] = zero(eltype(output))
2525

26-
@uniform N = size(output, 1)
2726
# number of tiles depends on inner dimension
28-
@uniform NUM_TILES = div(R + TILE_DIM - 1, TILE_DIM)
27+
@uniform NUM_TILES = cld(R, TILE_DIM)
2928

29+
# Can't use @index(Global), because we use a smaller ndrange
30+
I = (gi - 1) * TILE_DIM + i
31+
J = (gj - 1) * TILE_DIM + j
3032
# loop over all tiles needed for this calculation
3133
for t in 0:(NUM_TILES - 1)
32-
# Can't use @index(Global), because we use a smaller ndrange
33-
I = (gi - 1) * TILE_DIM + i
34-
J = (gj - 1) * TILE_DIM + j
35-
3634
# load inputs into tiles, with bounds checking for non-square matrices
3735
if I <= N && t * TILE_DIM + j <= R
3836
@inbounds tile1[i, j] = input1[I, t * TILE_DIM + j]
3937
else
4038
@inbounds tile1[i, j] = 0.0
4139
end
40+
4241
if t * TILE_DIM + i <= R && J <= M
4342
@inbounds tile2[i, j] = input2[t * TILE_DIM + i, J]
4443
else
4544
@inbounds tile2[i, j] = 0.0
4645
end
4746

4847
# wait for all tiles to be loaded
49-
@synchronize
50-
51-
# get global values again
52-
I = (gi - 1) * TILE_DIM + i
53-
J = (gj - 1) * TILE_DIM + j
48+
@synchronize(true)
5449

5550
# calculate value of spot in output, use temporary value to allow for vectorization
5651
out = zero(eltype(output))
@@ -59,29 +54,27 @@ const TILE_DIM = 32
5954
end
6055
outval[1] += out
6156

62-
@synchronize
57+
@synchronize(true)
6358
end
6459

65-
# get global indices again
66-
I = (gi - 1) * TILE_DIM + i
67-
J = (gj - 1) * TILE_DIM + j
68-
6960
# save if inbounds
7061
if I <= N && J <= M
7162
@inbounds output[I, J] = outval[1]
7263
end
7364
end
7465

75-
N = 1024
76-
R = 512
77-
M = 2048
78-
A = rand!(allocate(backend, Float32, N, R))
79-
B = rand!(allocate(backend, Float32, R, M))
80-
C = KernelAbstractions.zeros(backend, Float32, N, M)
81-
82-
kern = coalesced_matmul_kernel!(backend, (TILE_DIM, TILE_DIM))
83-
84-
kern(C, A, B, N, R, M, ndrange = size(C))
85-
KernelAbstractions.synchronize(backend)
86-
87-
@test isapprox(A * B, C)
66+
@testset "dims for $N, $R, $M" for (N,R,M) in [rand(500:1000,3) for _ in 1:10]
67+
A = rand!(allocate(backend, Float32, N, R))
68+
B = rand!(allocate(backend, Float32, R, M))
69+
C = KernelAbstractions.zeros(backend, Float32, N, M)
70+
71+
kern = coalesced_matmul_kernel!(backend, (TILE_DIM, TILE_DIM))
72+
73+
group_size_x = cld(N, TILE_DIM)
74+
group_size_y = cld(M, TILE_DIM)
75+
76+
kern(C, A, B, N, R, M, ndrange = (group_size_x * TILE_DIM, group_size_y * TILE_DIM))
77+
KernelAbstractions.synchronize(backend)
78+
79+
@test isapprox(A * B, C)
80+
end

0 commit comments

Comments
 (0)