diff --git a/src/ROCKernels.jl b/src/ROCKernels.jl
index 67eb0b2ea..c9f4e4971 100644
--- a/src/ROCKernels.jl
+++ b/src/ROCKernels.jl
@@ -166,4 +166,19 @@ end
     # TODO
 end
 
+# Reduction.
+
+# Device-side override: kernels compiled for this backend may use warp reductions.
+@device_override @inline KA.supports_warp_reduction() = true
+
+# Host-side query, dispatched on the backend type.
+KA.supports_warp_reduction(::ROCBackend) = true
+
+# Registered as a device override (like `supports_warp_reduction()` above) so the
+# definition does not add a host method on fully untyped arguments to a function
+# owned by KernelAbstractions.
+@device_override @inline function KA.shfl_down(val, offset)
+    return AMDGPU.Device.shfl_down(val, offset)
+end
+
 end
diff --git a/src/device/gcn/wavefront.jl b/src/device/gcn/wavefront.jl
index f4f652a81..47ad389e3 100644
--- a/src/device/gcn/wavefront.jl
+++ b/src/device/gcn/wavefront.jl
@@ -57,11 +57,10 @@ end
 """
     wfred(op::Function, val::T) where T -> T
 
-Performs a wavefront-wide reduction on `val` in each lane, and returns the
-result. A limited subset of functions are available to be passed as `op`. When
-`op` is one of `(+, max, min, &, |, ⊻)`, `T` may be
-`<:Union{Cint, Clong, Cuint, Culong}`. When `op` is one of `(+, max, min)`,
-`T` may also be `<:Union{Float32, Float64}`.
+Performs a wavefront-wide reduction on `val` in each lane, and returns the result.
+A limited subset of functions are available to be passed as `op`.
+When `op` is one of `(+, max, min, &, |, ⊻)`, `T` may be `<:Union{Cint, Clong, Cuint, Culong}`.
+When `op` is one of `(+, max, min)`, `T` may also be `<:Union{Float32, Float64}`.
 """
 wfred
 
@@ -69,10 +68,9 @@ wfred
     wfscan(op::Function, val::T) where T -> T
 
-Performs a wavefront-wide scan on `val` in each lane, and returns the
-result. A limited subset of functions are available to be passed as `op`. When
-`op` is one of `(+, max, min, &, |, ⊻)`, `T` may be
-`<:Union{Cint, Clong, Cuint, Culong}`. When `op` is one of `(+, max, min)`,
-`T` may also be `<:Union{Float32, Float64}`.
+Performs a wavefront-wide scan on `val` in each lane, and returns the result.
+A limited subset of functions are available to be passed as `op`.
+When `op` is one of `(+, max, min, &, |, ⊻)`, `T` may be `<:Union{Cint, Clong, Cuint, Culong}`.
+When `op` is one of `(+, max, min)`, `T` may also be `<:Union{Float32, Float64}`.
 """
 wfscan
 