Skip to content

Commit 8d14f4e

Browse files
authored
Merge pull request #226 from JuliaGPU/tb/sumprod_abs
Properly handle reducing complex numbers with abs.
2 parents 41661f0 + c0469d3 commit 8d14f4e

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/mapreduce.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ gpu_promote_type(::typeof(Base.add_sum), ::Type{T}) where {T<:Number} = typeof(B
5757
gpu_promote_type(::typeof(Base.mul_prod), ::Type{T}) where {T<:Number} = typeof(Base.mul_prod(one(T), one(T)))
5858
gpu_promote_type(::typeof(max), ::Type{T}) where {T<: WidenReduceResult} = T
5959
gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T
60+
gpu_promote_type(::typeof(abs), ::Type{Complex{T}}) where {T} = T
6061

6162
import Base.Broadcast: Broadcasted, ArrayStyle
6263
const GPUSrcArray = Union{Broadcasted{ArrayStyle{AT}}, GPUArray{T, N}} where {T, N, AT<:GPUArray}
@@ -66,7 +67,7 @@ function Base.mapreduce(f::Function, op::Function, A::GPUSrcArray; dims = :, ini
6667
end
6768

6869
function mapreduce_impl(f, op, ::NamedTuple{()}, A::GPUSrcArray, ::Colon)
69-
OT = gpu_promote_type(op, eltype(A))
70+
OT = gpu_promote_type(op, gpu_promote_type(f, eltype(A)))
7071
v0 = startvalue(op, OT) # TODO do this better
7172
acc_mapreduce(f, op, v0, A, ())
7273
end

test/testsuite/mapreduce.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ function test_mapreduce(AT)
2828
for dims in ((4048,), (1024,1024), (77,), (1923,209))
2929
@test compare(sum, AT, rand(range, dims))
3030
@test compare(prod, AT, rand(range, dims))
31+
@test compare(x -> sum(abs, x), AT, rand(range, dims))
32+
@test compare(x -> prod(abs, x), AT, rand(range, dims))
3133
ET <: Complex || @test compare(maximum, AT,rand(range, dims))
3234
ET <: Complex || @test compare(minimum, AT,rand(range, dims))
3335
end

0 commit comments

Comments
 (0)