Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# DynamicPPL Changelog

## 0.36.9

Removed the `DynamicPPL.initialstep` method. This method was unexported. If you were relying on this, you should directly use `AbstractMCMC.step`.

`DynamicPPL.default_varinfo` now takes two additional optional arguments, `initial_params` (an AbstractVector or nothing) and `link` (a Bool). These are used to generate the initial varinfo.

## 0.36.8

Made `LogDensityFunction` a subtype of `AbstractMCMC.AbstractModel`.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.36.8"
version = "0.36.9"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
6 changes: 4 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -453,12 +453,14 @@ Sampler
The default implementation of [`Sampler`](@ref) uses the following unexported functions.

```@docs
DynamicPPL.initialstep
DynamicPPL.loadstate
DynamicPPL.initialsampler
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
Finally, [`DynamicPPL.default_varinfo`](@ref) specifies how an initial varinfo should be generated when sampling from a given model with a given sampler.
The default behaviour creates a typed varinfo, inserts the initial parameters (if specified), and links the varinfo if requested.

Overriding this can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.

```@docs
DynamicPPL.default_varinfo
Expand Down
4 changes: 2 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,13 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
include("utils.jl")
include("chains.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varnamedvector.jl")
include("abstract_varinfo.jl")
include("logdensityfunction.jl")
include("sampler.jl")
include("threadsafe.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
Expand All @@ -175,7 +176,6 @@ include("compiler.jl")
include("pointwise_logdensities.jl")
include("submodel_macro.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
Expand Down
71 changes: 23 additions & 48 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ Generic sampler type for inference algorithms of type `T` in DynamicPPL.

`Sampler` should implement the AbstractMCMC interface, and in particular
`AbstractMCMC.step`. A default implementation of the initial sampling step is
provided that supports resuming sampling from a previous state and setting initial
parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref)
for loading previous states and actually performing the initial sampling step,
respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref)
that specifies how the initial parameter values are sampled if they are not provided.
By default, values are sampled from the prior.
provided that supports resuming sampling from a previous state and setting
initial parameter values. It requires you to to overload [`loadstate`](@ref)
for loading previous states. Additionally, sometimes one might want to
implement [`initialsampler`](@ref) that specifies how the initial parameter
values are sampled if they are not provided. By default, values are sampled
from the prior.
"""
struct Sampler{T} <: AbstractSampler
alg::T
Expand All @@ -52,13 +52,13 @@ end
# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::Model,
ldf::LogDensityFunction,
sampler::Union{SampleFromUniform,SampleFromPrior},
state=nothing;
kwargs...,
)
vi = VarInfo()
model(rng, vi, sampler)
ctx = SamplingContext(rng, sampler)
_, vi = DynamicPPL.evaluate!!(ldf.model, ldf.varinfo, ctx)
return vi, nothing
end

Expand All @@ -71,45 +71,25 @@ Return a default varinfo object for the given `model` and `sampler`.
- `rng::Random.AbstractRNG`: Random number generator.
- `model::Model`: Model for which we want to create a varinfo object.
- `sampler::AbstractSampler`: Sampler which will make use of the varinfo object.
- `context::AbstractContext`: Context in which the model is evaluated.
- `initial_params::Union{AbstractVector,Nothing}`: Initial parameter values to
be set in the varinfo object.
- `link::Bool`: Whether to link the varinfo.
- `context::AbstractContext`: Context in which the model is evaluated. Defaults
to `DefaultContext()`.

# Returns
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
"""
function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler)
return default_varinfo(rng, model, sampler, DefaultContext())
end
function default_varinfo(
rng::Random.AbstractRNG,
model::Model,
sampler::AbstractSampler,
context::AbstractContext,
initial_params::Union{AbstractVector,Nothing}=nothing,
link::Bool=false,
context::AbstractContext=DefaultContext(),
)
init_sampler = initialsampler(sampler)
return typed_varinfo(rng, model, init_sampler, context)
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::Model,
sampler::Sampler,
N::Integer;
chain_type=default_chain_type(sampler),
resume_from=nothing,
initial_state=loadstate(resume_from),
kwargs...,
)
return AbstractMCMC.mcmcsample(
rng, model, sampler, N; chain_type, initial_state, kwargs...
)
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
)
# Sample initial values.
vi = default_varinfo(rng, model, spl)
vi = typed_varinfo(rng, model, init_sampler, context)

# Update the parameters if provided.
if initial_params !== nothing
Expand All @@ -122,7 +102,7 @@ function AbstractMCMC.step(
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; initial_params, kwargs...)
return vi
end

"""
Expand Down Expand Up @@ -244,12 +224,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Mod
return vi
end

"""
initialstep(rng, model, sampler, varinfo; kwargs...)

Perform the initial sampling step of the `sampler` for the `model`.

The `varinfo` contains the initial samples, which can be provided by the user or
sampled randomly.
"""
function initialstep end
# TODO: Get rid of this
function initialstep(args...; kwargs...)
return error("no initialstep")
end
9 changes: 5 additions & 4 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,15 @@
abstract type OnlyInitAlg end
struct OnlyInitAlgDefault <: OnlyInitAlg end
struct OnlyInitAlgUniform <: OnlyInitAlg end
function DynamicPPL.initialstep(

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::Model,
ldf::DynamicPPL.LogDensityFunction,
::Sampler{<:OnlyInitAlg},
vi::AbstractVarInfo;
state=nothing;
kwargs...,
)
return vi, nothing
return ldf.vi, nothing
end

# initial samplers
Expand Down