Skip to content

Commit 3f956ef

Browse files
recursive bottoming
1 parent 30698d2 commit 3f956ef

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ module RecursiveArrayTools
1515
vecarr_to_arr, vecarr_to_vectors, tuples
1616

1717
export recursivecopy, recursivecopy!, vecvecapply, copyat_or_push!,
18-
vecvec_to_mat, recursive_one, recursive_mean, recursive_eltype
18+
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
19+
recursive_unitless_bottom_eltype, recursive_unitless_eltype
1920

2021
export ArrayPartition
2122

src/utils.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,16 @@ end
9292
recursive_one(a) = recursive_one(a[1])
9393
recursive_one(a::T) where {T<:Number} = one(a)
9494

95-
recursive_eltype(a) = recursive_eltype(eltype(a))
96-
recursive_eltype(a::Type{T}) where {T<:Number} = eltype(a)
95+
recursive_bottom_eltype(a) = recursive_bottom_eltype(eltype(a))
96+
recursive_bottom_eltype(a::Type{T}) where {T<:Number} = eltype(a)
97+
98+
recursive_unitless_bottom_eltype(a) = recursive_unitless_bottom_eltype(eltype(a))
99+
recursive_unitless_bottom_eltype(a::Type{T}) where {T<:Number} = typeof(one(eltype(a)))
100+
101+
Base.@pure recursive_unitless_eltype(a) = recursive_unitless_eltype(eltype(a))
102+
Base.@pure recursive_unitless_eltype{T<:StaticArray}(a::Type{T}) = similar_type(a,recursive_unitless_eltype(eltype(a)))
103+
Base.@pure recursive_unitless_eltype{T<:Array}(a::Type{T}) = Array{recursive_unitless_eltype(eltype(a)),ndims(a)}
104+
Base.@pure recursive_unitless_eltype{T<:Number}(a::Type{T}) = typeof(one(eltype(a)))
97105

98106
recursive_mean(x...) = mean(x...)
99107
function recursive_mean(vecvec::Vector{T}) where T<:AbstractArray

test/utils_test.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using OrdinaryDiffEq, ParameterizedFunctions,
2-
DiffEqBase, RecursiveArrayTools
2+
DiffEqBase, RecursiveArrayTools, Unitful, StaticArrays
33
using Base.Test
44

55
# Here's the problem to solve
@@ -35,3 +35,18 @@ ans = [[1 2; 3 4],[1 4; 4 4.5],[5 7; 4.5 4.5]]
3535
ans = [[2.333333333333 4.666666666666; 3.6666666666666 6.0], [2.3333333 3.0; 5.0 2.6666666]]
3636
@test recursive_mean(B,2)[1] ans[1]
3737
@test recursive_mean(B,2)[2] ans[2]
38+
39+
A = zeros(5,5)
40+
recursive_unitless_eltype(A) == Float64
41+
A = zeros(5,5)*1u"kg"
42+
recursive_unitless_eltype(A) == Float64
43+
AA = [zeros(5,5) for i in 1:5]
44+
recursive_unitless_eltype(AA) == Array{Float64,2}
45+
AofA = [copy(A) for i in 1:5]
46+
recursive_unitless_eltype(AofA) == Array{Float64,2}
47+
AofSA = [@SVector [2.0,3.0] for i in 1:5]
48+
recursive_unitless_eltype(AofSA) == SVector{2,Float64}
49+
AofuSA = [@SVector [2.0u"kg",3.0u"kg"] for i in 1:5]
50+
recursive_unitless_eltype(AofuSA) == SVector{2,Float64}
51+
52+
@inferred recursive_unitless_eltype(AofuSA)

0 commit comments

Comments
 (0)