10 changes: 10 additions & 0 deletions src/ROCKernels.jl
@@ -166,4 +166,14 @@ end
     # TODO
 end
 
+# Reduction.
+
+@device_override @inline KA.supports_warp_reduction() = true
+
+KA.supports_warp_reduction(::ROCBackend) = true
+
+function KA.shfl_down(val, offset)
+    AMDGPU.Device.shfl_down(val, offset)
+end
+
 end
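
For readers unfamiliar with shuffle-based reductions, the following is a minimal CPU-only sketch of the tree reduction that a `shfl_down` primitive is typically used for. It does not call AMDGPU.jl or KernelAbstractions.jl. The helpers `simulated_shfl_down` and `simulated_warp_reduce` are hypothetical names introduced for illustration, and the out-of-range behavior (a lane keeps its own value) is an assumption of this model, not a statement about the hardware.

```julia
# CPU-only model of a shuffle-down tree reduction across one wavefront.
# `lanes` holds one value per lane; nothing here runs on a GPU.

# Hypothetical helper: lane i reads the value held by lane i + offset.
# Lanes that would read past the end keep their own value (model assumption).
function simulated_shfl_down(lanes::Vector{T}, offset::Int) where T
    n = length(lanes)
    return [i + offset <= n ? lanes[i + offset] : lanes[i] for i in 1:n]
end

# Hypothetical helper: halve the offset each step. For a power-of-two lane
# count, lane 1 holds the combined value after log2(n) steps.
function simulated_warp_reduce(op, lanes::Vector{T}) where T
    vals = copy(lanes)
    offset = length(vals) ÷ 2
    while offset > 0
        shuffled = simulated_shfl_down(vals, offset)
        vals = map(op, vals, shuffled)
        offset ÷= 2
    end
    return vals[1]
end

lanes = collect(1:64)   # 64 lanes, a common wavefront size
total = simulated_warp_reduce(+, lanes)
@assert total == sum(lanes)
```

Only lane 1's result is meaningful in this scheme; the higher lanes accumulate partial values that are discarded. The sketch assumes the lane count is a power of two.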
16 changes: 7 additions & 9 deletions src/device/gcn/wavefront.jl
@@ -57,22 +57,20 @@ end
 """
     wfred(op::Function, val::T) where T -> T
 
-Performs a wavefront-wide reduction on `val` in each lane, and returns the
-result. A limited subset of functions are available to be passed as `op`. When
-`op` is one of `(+, max, min, &, |, ⊻)`, `T` may be
-`<:Union{Cint, Clong, Cuint, Culong}`. When `op` is one of `(+, max, min)`,
-`T` may also be `<:Union{Float32, Float64}`.
+Performs a wavefront-wide reduction on `val` in each lane, and returns the result.
+A limited subset of functions are available to be passed as `op`.
+When `op` is one of `(+, max, min, &, |, ⊻)`, `T` may be `<:Union{Cint, Clong, Cuint, Culong}`.
+When `op` is one of `(+, max, min)`, `T` may also be `<:Union{Float32, Float64}`.
 """
 wfred
 
 """
     wfscan(op::Function, val::T) where T -> T
 
 Performs a wavefront-wide scan on `val` in each lane, and returns the
-result. A limited subset of functions are available to be passed as `op`. When
-`op` is one of `(+, max, min, &, |, ⊻)`, `T` may be
-`<:Union{Cint, Clong, Cuint, Culong}`. When `op` is one of `(+, max, min)`,
-`T` may also be `<:Union{Float32, Float64}`.
+result. A limited subset of functions are available to be passed as `op`.
+When `op` is one of `(+, max, min, &, |, ⊻)`, `T` may be `<:Union{Cint, Clong, Cuint, Culong}`.
+When `op` is one of `(+, max, min)`, `T` may also be `<:Union{Float32, Float64}`.
 """
 wfscan
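
To make the difference between a reduction and a scan concrete, here is a small CPU-only sketch using Base Julia's `reduce` and `accumulate` on an ordinary vector that stands in for the per-lane values. It does not call `wfred` or `wfscan`. The docstrings above do not say whether the scan is inclusive or exclusive, so the inclusive form shown here is an assumption made for illustration.

```julia
# One value per lane; 8 lanes are used here for brevity.
lane_values = [3, 1, 4, 1, 5, 9, 2, 6]

# Reduction: all lane values are combined into a single result, modeled
# here as every lane receiving that same value.
reduced = fill(reduce(+, lane_values), length(lane_values))

# Inclusive scan (assumed variant): lane i receives the combination of
# lanes 1 through i.
scanned = accumulate(+, lane_values)

@assert reduced[1] == scanned[end]
```

With an inclusive scan, the last lane's value equals the reduction result, which is a quick sanity check when comparing the two operations.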
