Skip to content

Commit b48b970

Browse files
authored
Merge pull request #14 from JuliaGPU/vc/shmem
Fix and test local memory
2 parents 50efa41 + 10fa470 commit b48b970

File tree

4 files changed

+42
-7
lines changed

4 files changed

+42
-7
lines changed

src/KernelAbstractions.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,15 @@ Query the workgroupsize on the device.
7979
"""
8080
function groupsize end
8181

82-
const shmem_id = Ref(0)
83-
8482
"""
8583
@localmem T dims
8684
"""
8785
macro localmem(T, dims)
88-
id = (shmem_id[]+= 1)
86+
# Stay in sync with CUDAnative
87+
id = gensym("static_shmem")
8988

9089
quote
91-
$SharedMemory($(esc(T)), Val($(esc(dims))), Val($id))
90+
$SharedMemory($(esc(T)), Val($(esc(dims))), Val($(QuoteNode(id))))
9291
end
9392
end
9493

@@ -281,11 +280,11 @@ include("macros.jl")
281280
###
282281

283282
function Scratchpad(::Type{T}, ::Val{Dims}) where {T, Dims}
284-
throw(MethodError(ScratchArray, (T, Val(Dims))))
283+
throw(MethodError(Scratchpad, (T, Val(Dims))))
285284
end
286285

287286
function SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
288-
throw(MethodError(ScratchArray, (T, Val(Dims), Val(Id))))
287+
throw(MethodError(SharedMemory, (T, Val(Dims), Val(Id))))
289288
end
290289

291290
function __synchronize()

src/backends/cuda.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ end
203203
###
204204
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
205205
ptr = CUDAnative._shmem(Val(Id), T, Val(prod(Dims)))
206-
CUDAnative.CuDeviceArray(Dims, CUDAnative.DevicePtr{T, CUDAnative.AS.Shared}(ptr))
206+
CUDAnative.CuDeviceArray(Dims, ptr)
207207
end
208208

209209
###

test/localmem.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using KernelAbstractions
2+
using Test
3+
using CUDAapi
4+
if has_cuda_gpu()
5+
using CuArrays
6+
CuArrays.allowscalar(false)
7+
end
8+
9+
@kernel function localmem(A)
10+
I = @index(Global, Linear)
11+
i = @index(Local, Linear)
12+
lmem = @localmem Int (groupsize(),) # Ok iff groupsize is static
13+
lmem[i] = i
14+
@synchronize
15+
A[I] = lmem[groupsize() - i + 1]
16+
end
17+
18+
function harness(backend, ArrayT)
19+
A = ArrayT{Int}(undef, 64)
20+
wait(localmem(backend, 16)(A, ndrange=size(A)))
21+
@test all(A[1:16] .== 16:-1:1)
22+
@test all(A[17:32] .== 16:-1:1)
23+
@test all(A[33:48] .== 16:-1:1)
24+
@test all(A[49:64] .== 16:-1:1)
25+
end
26+
27+
@testset "kernels" begin
28+
harness(CPU(), Array)
29+
if has_cuda_gpu()
30+
harness(CUDA(), CuArray)
31+
end
32+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,8 @@ using Test
55
include("test.jl")
66
end
77

8+
@testset "Localmem" begin
9+
include("localmem.jl")
10+
end
11+
812
include("examples.jl")

0 commit comments

Comments
 (0)