Skip to content

Commit 25627df

Browse files
committed
AK Broadcasted
1 parent 92cccf6 commit 25627df

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

src/host/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ function findminmax(binop, A::AnyGPUArray; init, dims)
228228
(x, i), (y, j) = t1, t2
229229

230230
binop(x, y) && return t2
231-
x == y && return (x, min(i, j))
231+
isequal(x, y) && return (x, min(i, j))
232232
return t1
233233
end
234234

src/host/mapreduce.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,9 @@ neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = t
2727

2828
# resolve ambiguities
2929
Base.mapreduce(f, op, A::AnyGPUArray, As::AbstractArrayOrBroadcasted...;
30-
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
31-
# dims=:, init=nothing) = AK._mapreduce(f, op, A, As...; dims=dims, init=init)
30+
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims, init)
3231
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::AbstractArrayOrBroadcasted...;
33-
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
34-
# dims=:, init=nothing) = AK.mapreduce(f, op, #_mapreduce(f, op, A, As...; dims=dims, init=init)
35-
Base.mapreduce(f, op, A::AnyGPUArray;
36-
dims=:, init=nothing) = AK.mapreduce(f, op, A; init, dims=dims isa Colon ? nothing : dims)
37-
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle};
38-
dims=:, init=nothing) = AK.mapreduce(f, op, A; init, dims=dims isa Colon ? nothing : dims)
32+
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims, init)
3933

4034
function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,N,D}
4135
# figure out the destination container type by looking at the initializer element,
@@ -72,9 +66,25 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
7266
end
7367

7468
# allocate an output container
69+
block_size = 256 # Hard-code AK default to prevent mismatches
7570
sz = size(A)
7671
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], length(sz))
77-
R = similar(A, ET, red)
72+
R = if dims isa Colon
73+
num_per_block = 2 * block_size
74+
blocks = (prod(sz) + num_per_block - 1) ÷ num_per_block
75+
similar(A, ET, 2 * blocks)
76+
else
77+
similar(A, ET, red)
78+
end
79+
80+
# Use AcceleratedKernels (AK) when reducing over everything (dims=:) or over a single integer dimension
81+
if dims isa Colon || dims isa Integer
82+
return AK.mapreduce(f, op, Base.materialize(A), get_backend(R);
83+
block_size, init,
84+
neutral=init,
85+
dims=dims isa Colon ? nothing : dims,
86+
temp = R)
87+
end
7888

7989
# perform the reduction
8090
if prod(sz) == 0

0 commit comments

Comments
 (0)