diff --git a/lib/intrinsics/src/SPIRVIntrinsics.jl b/lib/intrinsics/src/SPIRVIntrinsics.jl
index bd15fdd9..b2bca59d 100644
--- a/lib/intrinsics/src/SPIRVIntrinsics.jl
+++ b/lib/intrinsics/src/SPIRVIntrinsics.jl
@@ -23,6 +23,7 @@ include("printf.jl")
 include("math.jl")
 include("integer.jl")
 include("atomic.jl")
+include("shuffle.jl")
 
 # helper macro to import all names from this package, even non-exported ones.
 macro import_all()
diff --git a/lib/intrinsics/src/shuffle.jl b/lib/intrinsics/src/shuffle.jl
new file mode 100644
index 00000000..804b0e18
--- /dev/null
+++ b/lib/intrinsics/src/shuffle.jl
@@ -0,0 +1,13 @@
+export sub_group_shuffle, sub_group_shuffle_xor
+
+const gentypes = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float16, Float32, Float64]
+
+for gentype in gentypes
+@eval begin
+
+# `i` is a 1-based lane index; the OpenCL builtin expects a 0-based one
+@device_function sub_group_shuffle(x::$gentype, i::Integer) = @builtin_ccall("sub_group_shuffle", $gentype, ($gentype, Int32), x, i % Int32 - 1i32)
+@device_function sub_group_shuffle_xor(x::$gentype, mask::Integer) = @builtin_ccall("sub_group_shuffle_xor", $gentype, ($gentype, UInt32), x, mask % UInt32)
+
+end
+end
diff --git a/lib/intrinsics/src/utils.jl b/lib/intrinsics/src/utils.jl
index e1a5a939..2c12db8a 100644
--- a/lib/intrinsics/src/utils.jl
+++ b/lib/intrinsics/src/utils.jl
@@ -26,6 +26,8 @@ macro builtin_ccall(name, ret, argtypes, args...)
             "c"
         elseif T == UInt8
             "h"
+        elseif T == Float16
+            "Dh"
         elseif T == Float32
             "f"
         elseif T == Float64
diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index e9a3f979..60064d06 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -5,8 +5,85 @@
 # - group-stride loop to delay need for second kernel launch
 # - let the driver choose the local size
 
+function shuffle_expr(::Type{T}) where {T}
+    if T in SPIRVIntrinsics.gentypes
+        return :(sub_group_shuffle(val, i))
+    elseif Base.isstructtype(T)
+        # shuffle structs field by field, provided every field can be shuffled
+        ex = Expr(:new, T)
+        for f in fieldnames(T)
+            ex_f = shuffle_expr(fieldtype(T, f))
+            ex_f === nothing && return nothing
+            push!(ex.args, :(let val = getfield(val, $(QuoteNode(f)))
+                $ex_f
+            end))
+        end
+        return ex
+    else
+        return nothing
+    end
+end
+
+@inline @generated function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
+    ex = shuffle_expr(T)
+    if ex === nothing
+        # this type cannot be shuffled; fall back to local memory
+        return :(reduce_group_fallback(op, val, neutral, Val(maxitems)))
+    end
+
+    quote
+        # shuffle-based reduction within each sub-group
+        lane = get_sub_group_local_id()
+        width = get_sub_group_size()
+
+        offset = 1
+        while offset < width
+            i = lane + offset
+            other = $ex
+            if i <= width
+                val = op(val, other)
+            end
+            offset <<= 1
+        end
+
+        # tree reduction of the per-sub-group values, in local memory
+        items = get_num_sub_groups()
+        item = get_sub_group_id()
+
+        shared = CLLocalArray(T, (maxitems,))
+        if items > 1
+            if lane == 1
+                @inbounds shared[item] = val
+            end
+
+            d = 1
+            while d < items
+                # the barrier must be executed by every work item in the
+                # group, so keep it outside of the lane-divergent branches
+                work_group_barrier(LOCAL_MEM_FENCE)
+                index = 2 * d * (item-1) + 1
+                @inbounds if lane == 1 && index <= items
+                    other_val = if index + d <= items
+                        shared[index+d]
+                    else
+                        neutral
+                    end
+                    shared[index] = op(shared[index], other_val)
+                end
+                d *= 2
+            end
+
+            if item == 1 && lane == 1
+                val = @inbounds shared[item]
+            end
+        end
+
+        return val
+    end
+end
+
 # Reduce a value across a group, using local memory for communication
-@inline function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
+@inline function reduce_group_fallback(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
     items = get_local_size()
     item = get_local_id()
 
@@ -45,12 +122,13 @@ Base.@propagate_inbounds _map_getindex(args::Tuple{}, I) = ()
 
 # Reduce an array across the grid. All elements to be processed can be addressed by the
 # product of the two iterators `Rreduce` and `Rother`, where the latter iterator will have
 # singleton entries for the dimensions that should be reduced (and vice versa).
-function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R, As...)
+function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R, A)
+    As = (A,)
     # decompose the 1D hardware indices into separate ones for reduction (across items
     # and possibly groups if it doesn't fit) and other elements (remaining groups)
     localIdx_reduce = get_local_id()
     localDim_reduce = get_local_size()
-    groupIdx_reduce, groupIdx_other = fldmod1(get_group_id(), length(Rother))
+    groupIdx_reduce, groupIdx_other = @inline fldmod1(get_group_id(), length(Rother))
     groupDim_reduce = get_num_groups() ÷ length(Rother)
 
@@ -67,7 +145,7 @@ function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R,
         neutral
     end
 
-    val = op(neutral, neutral)
+    val = neutral
 
     # reduce serially across chunks of input vector that don't fit in a group
     ireduce = localIdx_reduce + (groupIdx_reduce - 1) * localDim_reduce
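Usage note (not part of the diff): the sketch below shows how the new intrinsic composes into the same shuffle reduction that the generated `reduce_group` emits above. It is a minimal device-side sketch under this PR's conventions, i.e. `get_sub_group_local_id()` returning a 1-based lane index and `sub_group_shuffle` taking a 1-based one; the `sub_group_sum` name is hypothetical.

    # Sum `val` across the sub-group; lane 1 ends up holding the full sum.
    function sub_group_sum(val)
        lane = get_sub_group_local_id()
        width = get_sub_group_size()
        offset = 1
        while offset < width
            # every lane executes the shuffle; out-of-range reads yield an
            # undefined value, which the bounds check below discards
            other = sub_group_shuffle(val, lane + offset)
            if lane + offset <= width
                val += other
            end
            offset <<= 1
        end
        return val
    end

Unconditionally executing the shuffle and only guarding its *use* keeps all lanes convergent at the shuffle, which is the same pattern the generated `reduce_group` relies on.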