diff --git a/lib/CUDAKernels/src/CUDAKernels.jl b/lib/CUDAKernels/src/CUDAKernels.jl index 5e49ed666..d193c002a 100644 --- a/lib/CUDAKernels/src/CUDAKernels.jl +++ b/lib/CUDAKernels/src/CUDAKernels.jl @@ -324,6 +324,7 @@ import KernelAbstractions: CompilerMetadata, DynamicCheck, LinearIndices import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds + function mkcontext(kernel::Kernel{<:CUDADevice}, _ndrange, iterspace) CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace) end @@ -398,6 +399,14 @@ end CUDA._cuprint(args...) end +import KernelAbstractions: __test + +@device_override @inline function __test(__ctx__, conf) + KernelAbstractions.@localmem Float64 conf.threads_per_block + + KernelAbstractions.@print("dit werkt") +end + ### # GPU implementation of const memory ### @@ -408,3 +417,4 @@ Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental KernelAbstractions.argconvert(k::Kernel{<:CUDADevice}, arg) = CUDA.cudaconvert(arg) end + diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 30d1cd75f..7db68c362 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -4,6 +4,7 @@ export @kernel export @Const, @localmem, @private, @uniform, @synchronize export @index, @groupsize, @ndrange export @print +export @reduce export Device, GPU, CPU, Event, MultiEvent, NoneEvent export async_copy! @@ -329,6 +330,18 @@ macro index(locale, args...) Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...) 
end +macro reduce(op, val, neutral) + quote + $__reduce($(esc(:__ctx__)),$(esc(op)), $(esc(val)), $(esc(neutral)), typeof($(esc(val)))) + end +end + +macro test(conf) + quote + $__test($(esc(:__ctx__)),$(esc(conf))) + end +end + ### # Internal kernel functions ### @@ -493,6 +506,7 @@ function __synchronize() error("@synchronize used outside kernel or not captured") end + @generated function __print(items...) str = "" args = [] @@ -515,15 +529,28 @@ end __size(args::Tuple) = Tuple{args...} __size(i::Int) = Tuple{i} + +# reduction +function __reduce(__ctx__, op, val, neutral, ::Type{T}) where T + error("@reduce used outside kernel or not captured") +end + +function __test(__ctx__, conf) + error("@test used outside kernel or not captured") +end + ### # Extras # - LoopInfo ### + include("extras/extras.jl") include("reflection.jl") +include("reduce.jl") + # CPU backend include("cpu.jl") diff --git a/src/reduce.jl b/src/reduce.jl new file mode 100644 index 000000000..71b19a4a2 --- /dev/null +++ b/src/reduce.jl @@ -0,0 +1,52 @@ +struct Config{ + THREADS_PER_WARP, # size of warp + THREADS_PER_BLOCK # size of blocks + } +end + +@inline function Base.getproperty(conf::Type{Config{ THREADS_PER_WARP, THREADS_PER_BLOCK}}, sym::Symbol) where { THREADS_PER_WARP, THREADS_PER_BLOCK} + if sym == :threads_per_warp + THREADS_PER_WARP + elseif sym == :threads_per_block + THREADS_PER_BLOCK + else + # fallback for nothing + getfield(conf, sym) + end +end + +# TODO: make variable block size possible +# TODO: figure out where to place this +# reduction functionality for a group +@inline function __reduce(__ctx__ , op, val, neutral, ::Type{T}) where {T} + threads = KernelAbstractions.@groupsize()[1] + threadIdx = KernelAbstractions.@index(Local) + + # shared mem for a complete reduction + shared = KernelAbstractions.@localmem(T, 1024) + @inbounds shared[threadIdx] = val + + # perform the reduction + d = 1 + while d < threads + KernelAbstractions.@synchronize() + index = 2 * d * (threadIdx-1) + 1 + 
@inbounds if index <= threads + other_val = if index + d <= threads + shared[index+d] + else + neutral + end + shared[index] = op(shared[index], other_val) + end + d *= 2 + end + + # load the final value on the first thread + if threadIdx == 1 + val = @inbounds shared[threadIdx] + end + + # every thread will return the reduced value of the group + return val +end