Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/intrinsics/src/SPIRVIntrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include("printf.jl")
include("math.jl")
include("integer.jl")
include("atomic.jl")
include("shuffle.jl")

# helper macro to import all names from this package, even non-exported ones.
macro import_all()
Expand Down
12 changes: 12 additions & 0 deletions lib/intrinsics/src/shuffle.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
export sub_group_shuffle, sub_group_shuffle_xor

const gentypes = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float16, Float32, Float64]

for gentype in gentypes
@eval begin

@device_function sub_group_shuffle(x::$gentype, i::Integer) = @builtin_ccall("sub_group_shuffle", $gentype, ($gentype, Int32), x, i % Int32 - 1i32)
@device_function sub_group_shuffle_xor(x::$gentype, mask::Integer) = @builtin_ccall("sub_group_shuffle_xor", $gentype, ($gentype, UInt32), x, mask % UInt32)

end
end
2 changes: 2 additions & 0 deletions lib/intrinsics/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ macro builtin_ccall(name, ret, argtypes, args...)
"c"
elseif T == UInt8
"h"
elseif T == Float16
"Dh"
elseif T == Float32
"f"
elseif T == Float64
Expand Down
79 changes: 75 additions & 4 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,78 @@
# - group-stride loop to delay need for second kernel launch
# - let the driver choose the local size

function shuffle_expr(::Type{T}) where {T}
if T in SPIRVIntrinsics.gentypes
return :(sub_group_shuffle(val, i))
elseif Base.isstructtype(T)
ex = Expr(:new, T)
for f in fieldnames(T)
ex_f = shuffle_expr(fieldtype(T, f))
ex_f === nothing && return nothing
push!(ex.args, :(let val = getfield(val, $(QuoteNode(f)))
$ex_f
end))
end
return ex
else
return nothing
end
end

@inline @generated function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
ex = shuffle_expr(T)
if ex === nothing
return :(reduce_group_fallback(op, val, neutral, Val(maxitems)))
end

quote
# Subgroup shuffle-based warp reduction
lane = get_sub_group_local_id()
width = get_sub_group_size()

offset = 1
while offset < width
i = lane + offset
other = $ex
if i <= width
val = op(val, other)
end
offset <<= 1
end

items = get_num_sub_groups()
item = get_sub_group_id()

shared = CLLocalArray(T, (maxitems,))
if items > 1 && lane == 1
@inbounds shared[item] = val

d = 1
while d < items
work_group_barrier(LOCAL_MEM_FENCE)
index = 2 * d * (item-1) + 1
@inbounds if index <= items
other_val = if index + d <= items
shared[index+d]
else
neutral
end
shared[index] = op(shared[index], other_val)
end
d *= 2
end

if item == 1
val = @inbounds shared[item]
end
end

return val
end
end

# Reduce a value across a group, using local memory for communication
@inline function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
@inline function reduce_group_fallback(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
items = get_local_size()
item = get_local_id()

Expand Down Expand Up @@ -45,12 +115,13 @@ Base.@propagate_inbounds _map_getindex(args::Tuple{}, I) = ()
# Reduce an array across the grid. All elements to be processed can be addressed by the
# product of the two iterators `Rreduce` and `Rother`, where the latter iterator will have
# singleton entries for the dimensions that should be reduced (and vice versa).
function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R, As...)
function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R, A)
As = (A,)
# decompose the 1D hardware indices into separate ones for reduction (across items
# and possibly groups if it doesn't fit) and other elements (remaining groups)
localIdx_reduce = get_local_id()
localDim_reduce = get_local_size()
groupIdx_reduce, groupIdx_other = fldmod1(get_group_id(), length(Rother))
groupIdx_reduce, groupIdx_other = @inline fldmod1(get_group_id(), length(Rother))
groupDim_reduce = get_num_groups() ÷ length(Rother)

# group-based indexing into the values outside of the reduction dimension
Expand All @@ -67,7 +138,7 @@ function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R,
neutral
end

val = op(neutral, neutral)
val = neutral

# reduce serially across chunks of input vector that don't fit in a group
ireduce = localIdx_reduce + (groupIdx_reduce - 1) * localDim_reduce
Expand Down
Loading