38 changes: 38 additions & 0 deletions lib/CUDAKernels/src/CUDAKernels.jl
@@ -323,6 +323,8 @@ import CUDA: @device_override
import KernelAbstractions: CompilerMetadata, DynamicCheck, LinearIndices
import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds
import KernelAbstractions: __reduce


function mkcontext(kernel::Kernel{<:CUDADevice}, _ndrange, iterspace)
    CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
@@ -407,4 +409,40 @@ Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental
# Argument conversion
KernelAbstractions.argconvert(k::Kernel{<:CUDADevice}, arg) = CUDA.cudaconvert(arg)

# TODO: make variable block size possible
# TODO: figure out where to place this
# reduction functionality for a group
@device_override @inline function __reduce(__ctx__, op, val, neutral, ::Type{T}) where T
    threads = KernelAbstractions.@groupsize()[1]
    threadIdx = KernelAbstractions.@index(Local)

    # shared memory for a complete reduction
    shared = KernelAbstractions.@localmem(T, 1024)
Member:
Maybe this is the moment we need dynamic shared memory support?

Member:
x-ref: #11
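A minimal sketch of what dynamic shared memory could look like at this allocation site, assuming CUDA.jl's CuDynamicSharedArray is available and that the kernel launch reserves shmem = threads * sizeof(T) bytes; this is illustrative only and not part of the PR:

# sketch: size the scratchpad from the actual group size instead of a fixed 1024 elements
# (assumes the launch passes the matching `shmem` byte count)
shared = CUDA.CuDynamicSharedArray(T, threads)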

    @inbounds shared[threadIdx] = val

    # perform the reduction
    d = 1
    while d < threads
        KernelAbstractions.@synchronize()
Member:
You are inside CUDAKernels here, and as such you can use CUDA.jl functionality directly.

Author:
That's correct! But an implementation built on KA.jl macros would allow a single implementation that runs on all supported back-ends, which is why I am not sure where the best place for this code is.

Also, the main difference between back-ends would be the size of local memory, and using dynamic memory would be a solution to that.
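For reference, a minimal illustration of the reviewer's suggestion at this call site (an illustrative swap, not a change made in this PR): inside CUDAKernels the back-end-agnostic macro can be replaced by the native CUDA.jl intrinsic.

# back-end-agnostic form used in this PR:
KernelAbstractions.@synchronize()
# CUDA.jl-direct equivalent available inside CUDAKernels:
CUDA.sync_threads()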

        index = 2 * d * (threadIdx - 1) + 1
        @inbounds if index <= threads
            other_val = if index + d <= threads
                shared[index + d]
            else
                neutral
            end
            shared[index] = op(shared[index], other_val)
        end
        d *= 2
    end

    # load the final value on the first thread
    if threadIdx == 1
        val = @inbounds shared[threadIdx]
    end
    # note: only the first thread of the group returns the fully reduced value;
    # the other threads return their original `val`
    return val
end

end
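For review purposes, here is a plain-Julia reference of the same tree reduction that can be run on the CPU to sanity-check the indexing scheme; the function name is illustrative and not part of the PR:

# CPU reference of the group reduction above: `shared` plays the role of the group-local
# scratchpad and the inner loop over `t` mimics the work-items of one step.
function tree_reduce_reference(op, vals::Vector{T}, neutral::T) where T
    shared = copy(vals)
    threads = length(shared)
    d = 1
    while d < threads
        for t in 1:threads
            index = 2 * d * (t - 1) + 1
            if index <= threads
                other = index + d <= threads ? shared[index + d] : neutral
                shared[index] = op(shared[index], other)
            end
        end
        d *= 2
    end
    return shared[1]
end

tree_reduce_reference(+, collect(1.0:8.0), 0.0) == 36.0  # true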

16 changes: 16 additions & 0 deletions src/KernelAbstractions.jl
@@ -4,6 +4,7 @@ export @kernel
export @Const, @localmem, @private, @uniform, @synchronize
export @index, @groupsize, @ndrange
export @print
export @reduce
export Device, GPU, CPU, Event, MultiEvent, NoneEvent
export async_copy!

@@ -329,6 +330,14 @@ macro index(locale, args...)
    Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
end

# TODO: Where should we handle the logic of `neutral`? Adding it to the macro would reduce
# the complexity of using the macro, but it may also introduce some overhead.
macro reduce(op, val, neutral)
    quote
        $__reduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), typeof($(esc(val))))
    end
end
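For context, a hedged usage sketch of the new macro inside a kernel; the kernel name and arguments are illustrative and not part of this PR. Note that with the CUDA implementation above, only the first work-item of each group ends up with the fully reduced value:

@kernel function group_sum!(out, @Const(a))
    val = a[@index(Global)]
    # combine `val` across the work-group with `+`, using zero as the neutral element
    red = @reduce(+, val, zero(eltype(a)))
    # only the first work-item of the group holds the reduced result
    if @index(Local) == 1
        out[@index(Group)] = red
    end
end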

###
# Internal kernel functions
###
@@ -493,6 +502,7 @@ function __synchronize()
    error("@synchronize used outside kernel or not captured")
end


@generated function __print(items...)
    str = ""
    args = []
@@ -515,6 +525,12 @@ end
__size(args::Tuple) = Tuple{args...}
__size(i::Int) = Tuple{i}


# group reduction: back-ends override this (see lib/CUDAKernels for the CUDA implementation)
function __reduce(ctx, op, val, neutral, ::Type{T}) where T
    error("@reduce used outside kernel or not captured")
end

###
# Extras
# - LoopInfo