Skip to content

Commit d5b657c

Browse files
committed
Define recursive_eltype
1 parent 5352af9 commit d5b657c

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/array_partition.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,12 @@ function recursivecopy!(A::ArrayPartition, B::ArrayPartition)
182182
end
183183
end
184184

185-
recursive_one(A::ArrayPartition) = recursive_one(first(A.x))
186-
187185
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
188186

187+
# note: consider only first partition for recursive one and eltype
188+
recursive_one(A::ArrayPartition) = recursive_one(first(A.x))
189+
recursive_eltype(A::ArrayPartition) = recursive_eltype(first(A.x))
190+
189191
## iteration
190192

191193
Base.start(A::ArrayPartition) = start(Chain(A.x))

test/partitions_test.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
6060
@inferred first(x)
6161
@inferred last(x)
6262

63+
# recursive
64+
@inferred recursive_mean(x)
65+
@inferred recursive_one(x)
66+
@inferred recursive_eltype(x)
67+
6368
# broadcasting
6469
_scalar_op(y) = y + 1
6570
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:

0 commit comments

Comments
 (0)