diff --git a/Project.toml b/Project.toml index 0051650a..d52e498d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" version = "11.2.3" [deps] +AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" @@ -22,6 +23,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JLD2Ext = "JLD2" [compat] +AcceleratedKernels = "0.4" Adapt = "4.0" GPUArraysCore = "= 0.2.0" JLD2 = "0.4, 0.5" diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl index 8c1fc14e..af8b956e 100644 --- a/src/GPUArrays.jl +++ b/src/GPUArrays.jl @@ -16,6 +16,7 @@ using Reexport @reexport using GPUArraysCore using KernelAbstractions +import AcceleratedKernels as AK # device functionality include("device/abstractarray.jl") diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 8aee8b9d..889e41f6 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -228,7 +228,7 @@ function findminmax(binop, A::AnyGPUArray; init, dims) (x, i), (y, j) = t1, t2 binop(x, y) && return t2 - x == y && return (x, min(i, j)) + isequal(x, y) && return (x, min(i, j)) return t1 end diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl index 4a128313..a1ec458b 100644 --- a/src/host/mapreduce.jl +++ b/src/host/mapreduce.jl @@ -27,9 +27,9 @@ neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = t # resolve ambiguities Base.mapreduce(f, op, A::AnyGPUArray, As::AbstractArrayOrBroadcasted...; - dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init) + dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims, init) Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::AbstractArrayOrBroadcasted...; - dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init) + dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims, init) 
function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,N,D} # figure out the destination container type by looking at the initializer element, @@ -40,7 +40,7 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP, (ET === Union{} || ET === Any) && error("mapreduce cannot figure the output element type, please pass an explicit init value") - init = neutral_element(op, ET) + init = AK.neutral_element(op, ET) else ET = typeof(init) end @@ -66,9 +66,25 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP, end # allocate an output container + block_size = 256 # Hard-code AK default to prevent mismatches sz = size(A) red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], length(sz)) - R = similar(A, ET, red) + R = if dims isa Colon + num_per_block = 2 * block_size + blocks = (prod(sz) + num_per_block - 1) ÷ num_per_block + similar(A, ET, 2 * blocks) + else + similar(A, ET, red) + end + + # Use AcceleratedKernels if possible + if dims isa Colon || dims isa Integer + return AK.mapreduce(f, op, Base.materialize(A), get_backend(R); + block_size, init, + neutral=init, + dims=dims isa Colon ? nothing : dims, + temp = R) + end # perform the reduction if prod(sz) == 0 @@ -85,14 +101,14 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP, end end -Base.any(A::AnyGPUArray{Bool}) = mapreduce(identity, |, A) -Base.all(A::AnyGPUArray{Bool}) = mapreduce(identity, &, A) +Base.any(A::AnyGPUArray{Bool}) = AK.any(identity, A) +Base.all(A::AnyGPUArray{Bool}) = AK.all(identity, A) -Base.any(f::Function, A::AnyGPUArray) = mapreduce(f, |, A) -Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A) +Base.any(f::Function, A::AnyGPUArray) = AK.any(f, A) +Base.all(f::Function, A::AnyGPUArray) = AK.all(f, A) Base.count(pred::Function, A::AnyGPUArray; dims=:, init=0) = - mapreduce(pred, Base.add_sum, A; init=init, dims=dims) + AK.count(pred, A; init, dims=dims isa Colon ? 
nothing : dims) # avoid calling into `initarray!` for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), @@ -101,7 +117,7 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), fname! = Symbol(fname, '!') @eval begin Base.$(fname!)(f::Function, r::AnyGPUArray, A::AnyGPUArray{T}) where T = - GPUArrays.mapreducedim!(f, $(op), r, A; init=neutral_element($(op), T)) + GPUArrays.mapreducedim!(f, $(op), r, A; init=AK.neutral_element($(op), T)) end end