Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
485a525
Replace `evaluate_and_sample!!` -> `init!!`
penelopeysm Jul 10, 2025
7a05ec5
Use `ParamsInit` for `predict`; remove `setval_and_resample!` and fri…
penelopeysm Jul 10, 2025
b00e284
Use `init!!` for initialisation
penelopeysm Jul 10, 2025
5ed975c
Paper over the `Sampling->Init` context stack (pending removal of Sam…
penelopeysm Jul 10, 2025
2706239
Remove SamplingContext from JETExt to avoid triggering `Sampling->Ini…
penelopeysm Jul 10, 2025
84e5e55
Remove `predict` on vector of VarInfo
penelopeysm Jul 26, 2025
7f188b9
Fix some tests
penelopeysm Jul 20, 2025
f7ac1b1
Remove duplicated test
penelopeysm Jul 20, 2025
2041927
Simplify context testing
penelopeysm Aug 10, 2025
d9292ad
Rename FooInit -> InitFromFoo
penelopeysm Aug 13, 2025
70bb2c4
Fix JETExt
penelopeysm Aug 13, 2025
bc04355
Fix JETExt properly
penelopeysm Aug 13, 2025
2cfc297
Fix tests
penelopeysm Aug 13, 2025
891b4b3
Improve comments
penelopeysm Aug 13, 2025
3bb7ade
Remove duplicated tests
penelopeysm Aug 13, 2025
1bdb76e
Merge branch 'breaking' into py/actually-use-init
penelopeysm Sep 15, 2025
39b958d
Docstring improvements
penelopeysm Sep 16, 2025
907d24f
Concretise `chain_sample_to_varname_dict` using chain value type
penelopeysm Sep 16, 2025
07946a7
Clarify testset name
penelopeysm Sep 16, 2025
afdb173
Re-add comment that shouldn't have vanished
penelopeysm Sep 16, 2025
c641923
Fix stale Requires dep
penelopeysm Sep 16, 2025
956ed54
Fix default_varinfo/initialisation for odd models
penelopeysm Sep 17, 2025
8d13c30
Add comment to src/sampler.jl
penelopeysm Sep 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down Expand Up @@ -71,7 +70,6 @@ Mooncake = "0.4.147"
OrderedCollections = "1"
Printf = "1.10"
Random = "1.6"
Requires = "1"
Statistics = "1"
Test = "1.6"
julia = "1.10.8"
6 changes: 4 additions & 2 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
function chain_sample_to_varname_dict(
c::MCMCChains.Chains{Tval}, sample_idx, chain_idx
) where {Tval}
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Any}()
d = Dict{DynamicPPL.VarName,Tval}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
Expand Down
4 changes: 0 additions & 4 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,6 @@ include("test_utils.jl")
include("experimental.jl")
include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end

# Better error message if users forget to load JET
if isdefined(Base.Experimental, :register_error_hint)
function __init__()
Expand Down
4 changes: 2 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -873,10 +873,10 @@ end

Evaluate the `model` and replace the values of the model's random variables
in the given `varinfo` with new values, using a specified initialisation strategy.
If the values in `varinfo` are not set, they will be added.
If the values in `varinfo` are not set, they will be added
using a specified initialisation strategy.

If `init_strategy` is not provided, defaults to InitFromPrior().
If `init_strategy` is not provided, defaults to `InitFromPrior()`.

Returns a tuple of the model's return value, plus the updated `varinfo` object.
"""
Expand Down
51 changes: 36 additions & 15 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Generic sampler type for inference algorithms of type `T` in DynamicPPL.
provided that supports resuming sampling from a previous state and setting initial
parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref)
for loading previous states and actually performing the initial sampling step,
respectively. Additionally, sometimes one might want to implement [`init_strategy`](@ref)
respectively. Additionally, sometimes one might want to implement an [`init_strategy`](@ref)
that specifies how the initial parameter values are sampled if they are not provided.
By default, values are sampled from the prior.
"""
Expand All @@ -68,7 +68,7 @@ end

Return a default varinfo object for the given `model` and `sampler`.

The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
The default method for this returns a NTVarInfo (i.e. 'typed varinfo').

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
Expand All @@ -78,12 +78,24 @@ The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
# Returns
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
"""
function default_varinfo(::Random.AbstractRNG, ::Model, ::AbstractSampler)
# Note that variable values are unconditionally initialized later, so no
# point putting them in now.
return typed_varinfo(VarInfo())
function default_varinfo(rng::Random.AbstractRNG, model::Model, ::AbstractSampler)
# Note that in `AbstractMCMC.step`, the values in the varinfo returned here are
# immediately overwritten by a subsequent call to `init!!`. The reason why we
# _do_ create a varinfo with parameters here (as opposed to simply returning
# an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty
# typed VarInfo would fail. This can happen if two VarNames have different types
# but share the same symbol (e.g. `x.a` and `x.b`).
return typed_varinfo(VarInfo(rng, model))
end

"""
init_strategy(sampler)

Define the initialisation strategy used for generating initial values when
sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden.
"""
init_strategy(::Sampler) = InitFromPrior()

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::Model,
Expand All @@ -99,13 +111,22 @@ function AbstractMCMC.sample(
)
end

"""
init_strategy(sampler)

Define the initialisation strategy used for generating initial values when
sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden.
"""
init_strategy(::Sampler) = InitFromPrior()
function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::Model,
sampler::Sampler,
parallel::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
chain_type=default_chain_type(sampler),
resume_from=nothing,
initial_state=loadstate(resume_from),
kwargs...,
)
return AbstractMCMC.mcmcsample(
rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs...
)
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
Expand All @@ -114,8 +135,8 @@ function AbstractMCMC.step(
initial_params::AbstractInitStrategy=init_strategy(spl),
kwargs...,
)
# Generate the default varinfo (usually this just makes an empty VarInfo
# with NamedTuple of Metadata).
# Generate the default varinfo. Note that any parameters inside this varinfo
# will be immediately overwritten by the next call to `init!!`.
vi = default_varinfo(rng, model, spl)

# Fill it with initial parameters. Note that, if `InitFromParams` is used, the
Expand Down
2 changes: 1 addition & 1 deletion src/test_utils/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model)
@test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent

@testset "{set,}{leaf,child}context" begin
@testset "get/set leaf and child contexts" begin
# Ensure we're using a different leaf context than the current.
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
DynamicPPL.DynamicTransformationContext{false}()
Expand Down
3 changes: 3 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ module Issue537 end
varinfo = VarInfo(model)
@test getlogjoint(varinfo) == lp
@test varinfo_ isa AbstractVarInfo
# During the model evaluation, its leaf context is changed to an InitContext, so
# `model_` is not going to be equal to `model`. We can still check equality of `f`
# though.
@test model_.f === model.f
@test model_.context isa DynamicPPL.InitContext
@test model_.context.rng isa Random.AbstractRNG
Expand Down
13 changes: 13 additions & 0 deletions test/sampler.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
@testset "sampler.jl" begin
@testset "varnames with same symbol but different type" begin
struct S <: AbstractMCMC.AbstractSampler end
DynamicPPL.initialstep(rng, model, ::DynamicPPL.Sampler{S}, vi; kwargs...) = vi
@model function g()
y = (; a=1, b=2)
y.a ~ Normal()
return y.b ~ Normal()
end
model = g()
spl = DynamicPPL.Sampler(S())
@test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any
end

@testset "initial_state and resume_from kwargs" begin
# Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our
# overloaded method.
Expand Down
Loading