Skip to content

Commit 873a072

Browse files
fix: ArrayPartition mapreduce type inference, add tests
1 parent 309bf97 commit 873a072

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/array_partition.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
165165
## Iterable Collection Constructs
166166

167167
Base.map(f, A::ArrayPartition) = ArrayPartition(map(x -> map(f, x), A.x))
168-
function Base.mapreduce(f, op, A::ArrayPartition)
169-
mapreduce(f, op, (mapreduce(f, op, x) for x in A.x))
168+
function Base.mapreduce(f, op, A::ArrayPartition{T}; kwargs...) where {T}
169+
mapreduce(f, op, (i for i in A); kwargs...)
170170
end
171171
Base.filter(f, A::ArrayPartition) = ArrayPartition(map(x -> filter(f, x), A.x))
172172
Base.any(f, A::ArrayPartition) = any(f, (any(f, x) for x in A.x))

test/partitions_test.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
104104
@inferred recursive_one(x)
105105
@inferred recursive_bottom_eltype(x)
106106

107+
# mapreduce
108+
@inferred Union{Int, Float64} sum(x)
109+
@inferred sum(ArrayPartition(ArrayPartition(zeros(4,4))))
110+
@inferred sum(ArrayPartition(ArrayPartition(zeros(4))))
111+
@inferred sum(ArrayPartition(zeros(4,4)))
112+
@inferred mapreduce(string, *, x)
113+
@test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q"
114+
107115
# broadcasting
108116
_scalar_op(y) = y + 1
109117
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:

0 commit comments

Comments
 (0)