@@ -48,30 +48,39 @@ function Base.mapreduce(f, op, v0, A::GPUArray)
48
48
acc_mapreduce (f, op, v0, A, ())
49
49
end
50
50
51
-
52
- function mapreducedim_kernel (state, f, op, R:: AbstractArray{T1, N} , A:: AbstractArray{T, N} , slice_size, sizeA, dim) where {T1, T, N}
53
- ilin = UInt32 (linear_index (state))
54
- ilin > length (R) && return
55
- accum = zero (T1)
56
- @inbounds for i = UInt32 (1 ): slice_size
57
- idx = N == dim ? (ilin, i) : (i, ilin)
58
- i2d = gpu_sub2ind (sizeA, idx)
59
- accum = op (accum, f (A[i2d]))
51
+ @generated function mapreducedim_kernel (state, f, op, R, A, range:: NTuple{N, Any} ) where N
52
+ types = (range. parameters... ,)
53
+ indices = ntuple (i-> Symbol (" I_$i " ), N)
54
+ Iexpr = ntuple (i-> :(I[$ i]), N)
55
+ body = :(@inbounds R[$ (Iexpr... )] = op (R[$ (Iexpr... )], f (A[$ (indices... )])))
56
+ for i = N: - 1 : 1
57
+ idxsym = indices[i]
58
+ if types[i] == Void
59
+ body = quote
60
+ $ idxsym = I[$ i]
61
+ $ body
62
+ end
63
+ else
64
+ rsym = Symbol (" r_$i " )
65
+ body = quote
66
+ $ (rsym) = range[$ i]
67
+ for $ idxsym in UInt32 (first ($ rsym)): UInt32 (last ($ rsym))
68
+ $ body
69
+ end
70
+ end
71
+ end
72
+ body
73
+ end
74
+ quote
75
+ I = @cartesianidx R state
76
+ $ body
77
+ return
60
78
end
61
- R[ilin] = accum
62
- return
63
79
end
80
+
64
81
function Base. _mapreducedim! (f, op, R:: GPUArray , A:: GPUArray )
65
- sizeR = size (R)
66
- if all (x-> x == 1 , sizeR)
67
- x = mapreduce (f, op, A)
68
- copy! (R, reshape ([x], sizeR))
69
- return R
70
- end
71
- @assert count (x-> x == 1 , sizeR) == (ndims (R) - 1 ) " Not implemented"
72
- dim = findfirst (x-> x == 1 , sizeR)
73
- slice_size = size (A, dim)
74
- gpu_call (mapreducedim_kernel, R, (f, op, R, A, UInt32 (slice_size), UInt32 .(size (A)), UInt32 (dim)))
82
+ range = ifelse .(length .(indices (R)) .== 1 , indices (A), nothing )
83
+ gpu_call (mapreducedim_kernel, R, (f, op, R, A, range))
75
84
return R
76
85
end
77
86
0 commit comments