Skip to content

Commit 6437801

Browse files
committed
Make unflatten more type stable, and add a test for it
1 parent 048178b commit 6437801

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

src/simple_varinfo.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,12 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector)
264264
vals = unflatten(svi.values, x)
265265
# TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is
266266
# required but undesireable.
267-
T = float_type_with_fallback(eltype(x))
268-
accs = map(acc -> convert_eltype(T, acc), getaccs(svi))
267+
# The below line is finicky for type stability. For instance, assigning the eltype to
268+
# convert to into an intermediate variable makes this unstable (constant propagation)
269+
# fails. Take care when editing.
270+
accs = map(
271+
acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi)
272+
)
269273
return SimpleVarInfo(vals, accs, svi.transformation)
270274
end
271275

src/varinfo.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,13 @@ function unflatten(vi::VarInfo, x::AbstractVector)
444444
# element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
445445
# messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just
446446
# plain ugly and hacky.
447-
T = float_type_with_fallback(eltype(x))
448-
accs = map(acc -> convert_eltype(T, acc), deepcopy(getaccs(vi)))
447+
# The below line is finicky for type stability. For instance, assigning the eltype to
448+
# convert to into an intermediate variable makes this unstable (constant propagation)
449+
# fails. Take care when editing.
450+
accs = map(
451+
acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc),
452+
deepcopy(getaccs(vi)),
453+
)
449454
return VarInfo(md, accs)
450455
end
451456

test/varinfo.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,29 @@ end
732732
end
733733
end
734734

735+
@testset "unflatten type stability" begin
736+
@model function demo(y)
737+
x ~ Normal()
738+
y ~ Normal(x, 1)
739+
return nothing
740+
end
741+
742+
model = demo(0.0)
743+
varinfos = DynamicPPL.TestUtils.setup_varinfos(
744+
model, (; x=1.0), (@varname(x),); include_threadsafe=true
745+
)
746+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
747+
# Skip the severely inconcrete `SimpleVarInfo` types, since checking for type
748+
# stability for them doesn't make much sense anyway.
749+
if varinfo isa SimpleVarInfo{OrderedDict{Any,Any}} ||
750+
varinfo isa
751+
DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{OrderedDict{Any,Any}}}
752+
continue
753+
end
754+
@inferred DynamicPPL.unflatten(varinfo, varinfo[:])
755+
end
756+
end
757+
735758
@testset "subset" begin
736759
@model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV}
737760
s ~ InverseGamma(2, 3)

0 commit comments

Comments
 (0)