Skip to content

Commit 956ed54

Browse files
committed
Fix default_varinfo/initialisation for odd models
1 parent c641923 commit 956ed54

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

src/sampler.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ end
6868
6969
Return a default varinfo object for the given `model` and `sampler`.
7070
71-
The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
71+
The default method for this returns a NTVarInfo (i.e. 'typed varinfo').
7272
7373
# Arguments
7474
- `rng::Random.AbstractRNG`: Random number generator.
@@ -78,10 +78,14 @@ The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
7878
# Returns
7979
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
8080
"""
81-
function default_varinfo(::Random.AbstractRNG, ::Model, ::AbstractSampler)
82-
# Note that variable values are unconditionally initialized later, so no
83-
# point putting them in now.
84-
return typed_varinfo(VarInfo())
81+
function default_varinfo(rng::Random.AbstractRNG, model::Model, ::AbstractSampler)
82+
# Note that in `AbstractMCMC.step`, the values in the varinfo returned here are
83+
# immediately overwritten by a subsequent call to `init!!`. The reason why we
84+
# _do_ create a varinfo with parameters here (as opposed to simply returning
85+
# an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty
86+
# typed VarInfo would fail. This can happen if two VarNames have different types
87+
# but share the same symbol (e.g. `x.a` and `x.b`).
88+
return typed_varinfo(VarInfo(rng, model))
8589
end
8690

8791
"""
@@ -131,8 +135,8 @@ function AbstractMCMC.step(
131135
initial_params::AbstractInitStrategy=init_strategy(spl),
132136
kwargs...,
133137
)
134-
# Generate the default varinfo (usually this just makes an empty VarInfo
135-
# with NamedTuple of Metadata).
138+
# Generate the default varinfo. Note that any parameters inside this varinfo
139+
# will be immediately overwritten by the next call to `init!!`.
136140
vi = default_varinfo(rng, model, spl)
137141

138142
# Fill it with initial parameters. Note that, if `InitFromParams` is used, the

test/sampler.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
11
@testset "sampler.jl" begin
2+
@testset "varnames with same symbol but different type" begin
3+
struct S <: AbstractMCMC.AbstractSampler end
4+
DynamicPPL.initialstep(rng, model, ::DynamicPPL.Sampler{S}, vi; kwargs...) = vi
5+
@model function g()
6+
y = (; a=1, b=2)
7+
y.a ~ Normal()
8+
return y.b ~ Normal()
9+
end
10+
model = g()
11+
spl = DynamicPPL.Sampler(S())
12+
@test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any
13+
end
14+
215
@testset "initial_state and resume_from kwargs" begin
316
# Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our
417
# overloaded method.

0 commit comments

Comments
 (0)