Skip to content

Commit cd9a975

Browse files
fix mapreduce recursion
1 parent 2094a78 commit cd9a975

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/array_partition.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
166166

167167
Base.map(f, A::ArrayPartition) = ArrayPartition(map(x -> map(f, x), A.x))
168168
function Base.mapreduce(f, op, A::ArrayPartition{T}; kwargs...) where {T}
169-
mapreduce(f, op, (i for i in A); kwargs...)
169+
mapreduce(x->mapreduce(f, op, x; kwargs...), op, (i for i in A.x); kwargs...)
170170
end
171171
Base.filter(f, A::ArrayPartition) = ArrayPartition(map(x -> filter(f, x), A.x))
172172
Base.any(f, A::ArrayPartition) = any((any(f, x) for x in A.x))

test/gpu/arraypartition_gpu.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,7 @@ a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu))
2222
b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu))
2323
@. a + b
2424

25-
@test ArrayInterface.zeromatrix(ArrayPartition((CUDA.zeros(2),CUDA.zeros(2)))) isa CuMatrix
26-
@test size(ArrayInterface.zeromatrix(ArrayPartition((CUDA.zeros(2),CUDA.zeros(2))))) == (4,4)
25+
x = ArrayPartition((CUDA.zeros(2),CUDA.zeros(2)))
26+
@test ArrayInterface.zeromatrix(x) isa CuMatrix
27+
@test size(ArrayInterface.zeromatrix(x)) == (4,4)
28+
@test maximum(abs, x) == 0f0

0 commit comments

Comments
 (0)