Skip to content

Commit c755f89

Browse files
committed
AK Broadcasted
1 parent a11ff60 commit c755f89

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/host/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ function findminmax(binop, A::AnyGPUArray; init, dims)
228228
(x, i), (y, j) = t1, t2
229229

230230
binop(x, y) && return t2
231-
x == y && return (x, min(i, j))
231+
isequal(x, y) && return (x, min(i, j))
232232
return t1
233233
end
234234

src/host/mapreduce.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
6969
block_size = 256 # Hard-code AK default to prevent mismatches
7070
sz = size(A)
7171
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], length(sz))
72-
R = if !(A isa Broadcast.Broadcasted) && dims isa Colon
72+
R = if dims isa Colon
7373
num_per_block = 2 * block_size
7474
blocks = (prod(sz) + num_per_block - 1) ÷ num_per_block
7575
similar(A, ET, 2 * blocks)
@@ -78,8 +78,8 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
7878
end
7979

8080
# Use AcceleratedKernels if possible
81-
if !(A isa Broadcast.Broadcasted) && (dims isa Colon || dims isa Integer)
82-
return AK.mapreduce(f, op, A, get_backend(R);
81+
if dims isa Colon || dims isa Integer
82+
return AK.mapreduce(f, op, Base.materialize(A), get_backend(R);
8383
block_size, init,
8484
neutral=init,
8585
dims=dims isa Colon ? nothing : dims,

0 commit comments

Comments
 (0)