@@ -21,36 +21,31 @@ const TILE_DIM = 32
2121
2222 # private variable for tile output
2323 outval = @private eltype (output) 1
24- @inbounds outval[1 ] = - zero (eltype (output))
24+ @inbounds outval[1 ] = zero (eltype (output))
2525
26- @uniform N = size (output, 1 )
2726 # number of tiles depends on inner dimension
28- @uniform NUM_TILES = div (R + TILE_DIM - 1 , TILE_DIM)
27+ @uniform NUM_TILES = cld (R , TILE_DIM)
2928
29+ # Can't use @index(Global), because we use a smaller ndrange
30+ I = (gi - 1 ) * TILE_DIM + i
31+ J = (gj - 1 ) * TILE_DIM + j
3032 # loop over all tiles needed for this calculation
3133 for t in 0 : (NUM_TILES - 1 )
32- # Can't use @index(Global), because we use a smaller ndrange
33- I = (gi - 1 ) * TILE_DIM + i
34- J = (gj - 1 ) * TILE_DIM + j
35-
3634 # load inputs into tiles, with bounds checking for non-square matrices
3735 if I <= N && t * TILE_DIM + j <= R
3836 @inbounds tile1[i, j] = input1[I, t * TILE_DIM + j]
3937 else
4038 @inbounds tile1[i, j] = 0.0
4139 end
40+
4241 if t * TILE_DIM + i <= R && J <= M
4342 @inbounds tile2[i, j] = input2[t * TILE_DIM + i, J]
4443 else
4544 @inbounds tile2[i, j] = 0.0
4645 end
4746
4847 # wait for all tiles to be loaded
49- @synchronize
50-
51- # get global values again
52- I = (gi - 1 ) * TILE_DIM + i
53- J = (gj - 1 ) * TILE_DIM + j
48+ @synchronize (true )
5449
5550 # calculate value of spot in output, use temporary value to allow for vectorization
5651 out = zero (eltype (output))
@@ -59,29 +54,27 @@ const TILE_DIM = 32
5954 end
6055 outval[1 ] += out
6156
62- @synchronize
57+ @synchronize ( true )
6358 end
6459
65- # get global indices again
66- I = (gi - 1 ) * TILE_DIM + i
67- J = (gj - 1 ) * TILE_DIM + j
68-
6960 # save if inbounds
7061 if I <= N && J <= M
7162 @inbounds output[I, J] = outval[1 ]
7263 end
7364end
7465
75- N = 1024
76- R = 512
77- M = 2048
78- A = rand! (allocate (backend, Float32, N, R))
79- B = rand! (allocate (backend, Float32, R, M))
80- C = KernelAbstractions. zeros (backend, Float32, N, M)
81-
82- kern = coalesced_matmul_kernel! (backend, (TILE_DIM, TILE_DIM))
83-
84- kern (C, A, B, N, R, M, ndrange = size (C))
85- KernelAbstractions. synchronize (backend)
86-
87- @test isapprox (A * B, C)
66+ @testset " dims for $N , $R , $M " for (N,R,M) in [rand (500 : 1000 ,3 ) for _ in 1 : 10 ]
67+ A = rand! (allocate (backend, Float32, N, R))
68+ B = rand! (allocate (backend, Float32, R, M))
69+ C = KernelAbstractions. zeros (backend, Float32, N, M)
70+
71+ kern = coalesced_matmul_kernel! (backend, (TILE_DIM, TILE_DIM))
72+
73+ group_size_x = cld (N, TILE_DIM)
74+ group_size_y = cld (M, TILE_DIM)
75+
76+ kern (C, A, B, N, R, M, ndrange = (group_size_x * TILE_DIM, group_size_y * TILE_DIM))
77+ KernelAbstractions. synchronize (backend)
78+
79+ @test isapprox (A * B, C)
80+ end
0 commit comments