Skip to content

Commit e526007

Browse files
authored
Merge pull request #99 from JuliaGPU/sd/mapreducedim
implement mapreducedim
2 parents a9a395a + d5e78ce commit e526007

File tree

2 files changed

+31
-21
lines changed

2 files changed

+31
-21
lines changed

src/mapreduce.jl

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,30 +48,39 @@ function Base.mapreduce(f, op, v0, A::GPUArray)
4848
acc_mapreduce(f, op, v0, A, ())
4949
end
5050

51-
52-
function mapreducedim_kernel(state, f, op, R::AbstractArray{T1, N}, A::AbstractArray{T, N}, slice_size, sizeA, dim) where {T1, T, N}
53-
ilin = UInt32(linear_index(state))
54-
ilin > length(R) && return
55-
accum = zero(T1)
56-
@inbounds for i = UInt32(1):slice_size
57-
idx = N == dim ? (ilin, i) : (i, ilin)
58-
i2d = gpu_sub2ind(sizeA, idx)
59-
accum = op(accum, f(A[i2d]))
51+
@generated function mapreducedim_kernel(state, f, op, R, A, range::NTuple{N, Any}) where N
52+
types = (range.parameters...,)
53+
indices = ntuple(i-> Symbol("I_$i"), N)
54+
Iexpr = ntuple(i-> :(I[$i]), N)
55+
body = :(@inbounds R[$(Iexpr...)] = op(R[$(Iexpr...)], f(A[$(indices...)])))
56+
for i = N:-1:1
57+
idxsym = indices[i]
58+
if types[i] == Void
59+
body = quote
60+
$idxsym = I[$i]
61+
$body
62+
end
63+
else
64+
rsym = Symbol("r_$i")
65+
body = quote
66+
$(rsym) = range[$i]
67+
for $idxsym in UInt32(first($rsym)):UInt32(last($rsym))
68+
$body
69+
end
70+
end
71+
end
72+
body
73+
end
74+
quote
75+
I = @cartesianidx R state
76+
$body
77+
return
6078
end
61-
R[ilin] = accum
62-
return
6379
end
80+
6481
function Base._mapreducedim!(f, op, R::GPUArray, A::GPUArray)
65-
sizeR = size(R)
66-
if all(x-> x == 1, sizeR)
67-
x = mapreduce(f, op, A)
68-
copy!(R, reshape([x], sizeR))
69-
return R
70-
end
71-
@assert count(x-> x == 1, sizeR) == (ndims(R) - 1) "Not implemented"
72-
dim = findfirst(x-> x == 1, sizeR)
73-
slice_size = size(A, dim)
74-
gpu_call(mapreducedim_kernel, R, (f, op, R, A, UInt32(slice_size), UInt32.(size(A)), UInt32(dim)))
82+
range = ifelse.(length.(indices(R)) .== 1, indices(A), nothing)
83+
gpu_call(mapreducedim_kernel, R, (f, op, R, A, range))
7584
return R
7685
end
7786

src/testsuite/mapreduce.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ function run_mapreduce(Typ)
1313
x = T(y)
1414
@test sum(y, 2) Array(sum(x, 2))
1515
@test sum(y, 1) Array(sum(x, 1))
16+
@test sum(y, (1, 2)) Array(sum(x, (1, 2)))
1617

1718
y = rand(range, N, 10)
1819
x = T(y)

0 commit comments

Comments
 (0)