From ef6522f8e32c2739a495de758bcec1b377cf8017 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 25 May 2025 19:46:09 +0100 Subject: [PATCH 1/4] Remove initialstep, rework default_varinfo --- HISTORY.md | 6 ++++++ Project.toml | 2 +- src/sampler.jl | 40 ++++++++++------------------------------ 3 files changed, 17 insertions(+), 31 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 40a671dc1..09eb1df76 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,11 @@ # DynamicPPL Changelog +## 0.36.9 + +Removed the `DynamicPPL.initialstep` method. This method was unexported. If you were relying on this, you should directly use `AbstractMCMC.step`. + +`DynamicPPL.default_varinfo` now takes two additional optional arguments, `initial_params` (an AbstractVector or nothing) and `link` (a Bool). These are used to generate the initial varinfo. + ## 0.36.8 Made `LogDensityFunction` a subtype of `AbstractMCMC.AbstractModel`. diff --git a/Project.toml b/Project.toml index 2fc1d984c..fd5d20c9b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.8" +version = "0.36.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/sampler.jl b/src/sampler.jl index 49d910fec..88327c1d0 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -71,45 +71,25 @@ Return a default varinfo object for the given `model` and `sampler`. - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. - `sampler::AbstractSampler`: Sampler which will make use of the varinfo object. -- `context::AbstractContext`: Context in which the model is evaluated. +- `initial_params::Union{AbstractVector,Nothing}`: Initial parameter values to +be set in the varinfo object. +- `link::Bool`: Whether to link the varinfo. +- `context::AbstractContext`: Context in which the model is evaluated. Defaults +to `DefaultContext()`. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. """ -function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - return default_varinfo(rng, model, sampler, DefaultContext()) -end function default_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler, - context::AbstractContext, + initial_params::Union{AbstractVector,Nothing}=nothing, + link::Bool=false, + context::AbstractContext=DefaultContext(), ) init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler, context) -end - -function AbstractMCMC.sample( - rng::Random.AbstractRNG, - model::Model, - sampler::Sampler, - N::Integer; - chain_type=default_chain_type(sampler), - resume_from=nothing, - initial_state=loadstate(resume_from), - kwargs..., -) - return AbstractMCMC.mcmcsample( - rng, model, sampler, N; chain_type, initial_state, kwargs... - ) -end - -# initial step: general interface for resuming and -function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... -) - # Sample initial values. - vi = default_varinfo(rng, model, spl) + vi = typed_varinfo(rng, model, init_sampler, context) # Update the parameters if provided. if initial_params !== nothing @@ -122,7 +102,7 @@ function AbstractMCMC.step( vi = last(evaluate!!(model, vi, DefaultContext())) end - return initialstep(rng, model, spl, vi; initial_params, kwargs...) + return vi end """ From 1b87b4896ccef6785b43ec140fc08a939fe3f8d8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 25 May 2025 19:59:54 +0100 Subject: [PATCH 2/4] Update docs --- docs/src/api.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 08522e2ce..bb1db81d9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -458,7 +458,10 @@ DynamicPPL.loadstate DynamicPPL.initialsampler ``` -Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. +Finally, [`DynamicPPL.default_varinfo`](@ref) specifies how an initial varinfo should be generated when sampling from a given model with a given sampler. +The default behaviour creates a typed varinfo, inserts the initial parameters (if specified), and links the varinfo if requested. + +Overriding this can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. ```@docs DynamicPPL.default_varinfo From 54ab40648f4337c835a8f71445f2b1d914184336 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 27 May 2025 11:35:36 +0100 Subject: [PATCH 3/4] Fix tests --- docs/src/api.md | 1 - src/DynamicPPL.jl | 2 +- src/sampler.jl | 28 +++++++++------------------- test/sampler.jl | 9 +++++---- 4 files changed, 15 insertions(+), 25 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index bb1db81d9..6766f4de3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -453,7 +453,6 @@ Sampler The default implementation of [`Sampler`](@ref) uses the following unexported functions. ```@docs -DynamicPPL.initialstep DynamicPPL.loadstate DynamicPPL.initialsampler ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 21f9044cd..63654d1d5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -161,7 +161,6 @@ abstract type AbstractVarInfo <: AbstractModelTrace end include("utils.jl") include("chains.jl") include("model.jl") -include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") @@ -176,6 +175,7 @@ include("pointwise_logdensities.jl") include("submodel_macro.jl") include("transforming.jl") include("logdensityfunction.jl") +include("sampler.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/sampler.jl b/src/sampler.jl index 88327c1d0..faebfb575 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -38,12 +38,12 @@ Generic sampler type for inference algorithms of type `T` in DynamicPPL. `Sampler` should implement the AbstractMCMC interface, and in particular `AbstractMCMC.step`. A default implementation of the initial sampling step is -provided that supports resuming sampling from a previous state and setting initial -parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref) -for loading previous states and actually performing the initial sampling step, -respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref) -that specifies how the initial parameter values are sampled if they are not provided. -By default, values are sampled from the prior. +provided that supports resuming sampling from a previous state and setting +initial parameter values. It requires you to to overload [`loadstate`](@ref) +for loading previous states. Additionally, sometimes one might want to +implement [`initialsampler`](@ref) that specifies how the initial parameter +values are sampled if they are not provided. By default, values are sampled +from the prior. """ struct Sampler{T} <: AbstractSampler alg::T @@ -52,13 +52,13 @@ end # AbstractMCMC interface for SampleFromUniform and SampleFromPrior function AbstractMCMC.step( rng::Random.AbstractRNG, - model::Model, + ldf::LogDensityFunction, sampler::Union{SampleFromUniform,SampleFromPrior}, state=nothing; kwargs..., ) - vi = VarInfo() - model(rng, vi, sampler) + ctx = SamplingContext(rng, sampler) + _, vi = DynamicPPL.evaluate!!(ldf.model, ldf.varinfo, ctx) return vi, nothing end @@ -223,13 +223,3 @@ function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Mod return vi end - -""" - initialstep(rng, model, sampler, varinfo; kwargs...) - -Perform the initial sampling step of the `sampler` for the `model`. - -The `varinfo` contains the initial samples, which can be provided by the user or -sampled randomly. -""" -function initialstep end diff --git a/test/sampler.jl b/test/sampler.jl index 8c4f1ed96..f6172dd0e 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -58,14 +58,15 @@ abstract type OnlyInitAlg end struct OnlyInitAlgDefault <: OnlyInitAlg end struct OnlyInitAlgUniform <: OnlyInitAlg end - function DynamicPPL.initialstep( + + function AbstractMCMC.step( rng::Random.AbstractRNG, - model::Model, + ldf::DynamicPPL.LogDensityFunction, ::Sampler{<:OnlyInitAlg}, - vi::AbstractVarInfo; + state=nothing; kwargs..., ) - return vi, nothing + return ldf.vi, nothing end # initial samplers From d1cf5d63c43bbf7def4d9bf8ca05f6a1e6c06596 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 20:41:02 +0100 Subject: [PATCH 4/4] Fix formatting --- src/DynamicPPL.jl | 4 ++-- src/sampler.jl | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 63654d1d5..bd9ae3617 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -166,6 +166,8 @@ include("distribution_wrappers.jl") include("contexts.jl") include("varnamedvector.jl") include("abstract_varinfo.jl") +include("logdensityfunction.jl") +include("sampler.jl") include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") @@ -174,8 +176,6 @@ include("compiler.jl") include("pointwise_logdensities.jl") include("submodel_macro.jl") include("transforming.jl") -include("logdensityfunction.jl") -include("sampler.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/sampler.jl b/src/sampler.jl index faebfb575..8a37c9f41 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -223,3 +223,8 @@ function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Mod return vi end + +# TODO: Get rid of this +function initialstep(args...; kwargs...) + return error("no initialstep") +end