@@ -27,15 +27,9 @@ neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = t
27
27
28
28
# resolve ambiguities
29
29
Base. mapreduce (f, op, A:: AnyGPUArray , As:: AbstractArrayOrBroadcasted... ;
30
- dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
31
- # dims=:, init=nothing) = AK._mapreduce(f, op, A, As...; dims=dims, init=init)
30
+ dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims, init)
32
31
Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} , As:: AbstractArrayOrBroadcasted... ;
33
- dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
34
- # dims=:, init=nothing) = AK.mapreduce(f, op, #_mapreduce(f, op, A, As...; dims=dims, init=init)
35
- Base. mapreduce (f, op, A:: AnyGPUArray ;
36
- dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
37
- Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} ;
38
- dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
32
+ dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims, init)
39
33
40
34
function _mapreduce (f:: F , op:: OP , As:: Vararg{Any,N} ; dims:: D , init) where {F,OP,N,D}
41
35
# figure out the destination container type by looking at the initializer element,
@@ -72,9 +66,25 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
72
66
end
73
67
74
68
# allocate an output container
69
+ block_size = 256 # Hard-code AK default to prevent mismatches
75
70
sz = size (A)
76
71
red = ntuple (i-> (dims== Colon () || i in dims) ? 1 : sz[i], length (sz))
77
- R = similar (A, ET, red)
72
+ R = if dims isa Colon
73
+ num_per_block = 2 * block_size
74
+ blocks = (prod (sz) + num_per_block - 1 ) ÷ num_per_block
75
+ similar (A, ET, 2 * blocks)
76
+ else
77
+ similar (A, ET, red)
78
+ end
79
+
80
+ # Use AcceleratedKernels if possible
81
+ if dims isa Colon || dims isa Integer
82
+ return AK. mapreduce (f, op, Base. materialize (A), get_backend (R);
83
+ block_size, init,
84
+ neutral= init,
85
+ dims= dims isa Colon ? nothing : dims,
86
+ temp = R)
87
+ end
78
88
79
89
# perform the reduction
80
90
if prod (sz) == 0
0 commit comments