Skip to content

Commit ccfdc13

Browse files
authored
Merge pull request #67 from JuliaGaussianProcesses/fix-flatten-tuple-type-instability
2 parents 3d496be + f5bdf67 commit ccfdc13

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ParameterHandling"
22
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.4.8"
4+
version = "0.4.9"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/flatten.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,20 @@ function flatten(::Type{T}, x::SparseMatrixCSC) where {T<:Real}
7979
end
8080

8181
function flatten(::Type{T}, x::Tuple) where {T<:Real}
82-
x_vecs_and_backs = map(val -> flatten(T, val), x)
83-
x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
84-
lengths = map(length, x_vecs)
85-
sz = _cumsum(lengths)
82+
vec1, back1 = flatten(T, first(x))
83+
vec2, back2 = flatten(T, Base.tail(x))
84+
l1 = length(vec1)
85+
l2 = length(vec2)
8686
function unflatten_to_Tuple(v::Vector{T})
87-
map(x_backs, lengths, sz) do x_back, l, s
88-
return x_back(v[(s - l + 1):s])
89-
end
87+
return (back1(v[1:l1]), back2(v[(l1 + 1):(l1 + l2)])...)
9088
end
91-
return reduce(vcat, x_vecs), unflatten_to_Tuple
89+
return vcat(vec1, vec2), unflatten_to_Tuple
90+
end
91+
92+
function flatten(::Type{T}, x::Tuple{}) where {T<:Real}
93+
v = T[]
94+
unflatten_to_empty_Tuple(::Vector{T}) = x
95+
return v, unflatten_to_empty_Tuple
9296
end
9397

9498
function flatten(::Type{T}, x::NamedTuple) where {T<:Real}

test/flatten.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@
3939
test_flatten_interface((1.0, 2.0); check_inferred=tuple_infers)
4040

4141
test_flatten_interface((1.0, (2.0, 3.0), randn(5)); check_inferred=tuple_infers)
42+
43+
# Prevent regression of PR #67
44+
@testset "Type stability of unflatten" begin
45+
θ = (1.0, ((2.0, 3.0), 4.0))
46+
x, unflatten = flatten(θ)
47+
@test (@inferred unflatten(x)) == θ
48+
end
4249
end
4350

4451
@testset "NamedTuple" begin

0 commit comments

Comments
 (0)