Skip to content

Commit 27b0096

Browse files
committed
[no ci] More fixes, reexport InitFrom
1 parent 02d1d0e commit 27b0096

File tree

6 files changed

+71
-40
lines changed

6 files changed

+71
-40
lines changed

docs/src/api.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
7575
| `RepeatSampler` | [`Turing.Inference.RepeatSampler`](@ref) | A sampler that runs multiple times on the same variable |
7676
| `externalsampler` | [`Turing.Inference.externalsampler`](@ref) | Wrap an external sampler for use in Turing |
7777

78+
### Initialisation strategies
79+
80+
Turing.jl provides several strategies to initialise parameters for models.
81+
82+
| Exported symbol | Documentation | Description |
83+
|:----------------- |:--------------------------------------- |:--------------------------------------------------------------- |
84+
| `InitFromPrior` | [`DynamicPPL.InitFromPrior`](@extref) | Obtain initial parameters from the prior distribution |
85+
| `InitFromUniform` | [`DynamicPPL.InitFromUniform`](@extref) | Obtain initial parameters by sampling uniformly in linked space |
86+
| `InitFromParams` | [`DynamicPPL.InitFromParams`](@extref) | Manually specify (possibly a subset of) initial parameters |
87+
7888
### Variational inference
7989

8090
See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for detailed usage and the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a basic walkthrough.

src/Turing.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ using DynamicPPL:
7373
conditioned,
7474
to_submodel,
7575
LogDensityFunction,
76-
@addlogprob!
76+
@addlogprob!,
77+
InitFromPrior,
78+
InitFromUniform,
79+
InitFromParams
7780
using StatsBase: predict
7881
using OrderedCollections: OrderedDict
7982

@@ -148,6 +151,10 @@ export
148151
fix,
149152
unfix,
150153
OrderedDict, # OrderedCollections
154+
# Initialisation strategies for models
155+
InitFromPrior,
156+
InitFromUniform,
157+
InitFromParams,
151158
# Point estimates - Turing.Optimisation
152159
# The MAP and MLE exports are only needed for the Optim.jl interface.
153160
maximum_a_posteriori,

src/mcmc/emcee.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,16 @@ struct EmceeState{V<:AbstractVarInfo,S}
3131
states::S
3232
end
3333

34+
# Utility function to retrieve the number of walkers
35+
_get_n_walkers(e::Emcee) = e.ensemble.n_walkers
36+
_get_n_walkers(spl::Sampler{<:Emcee}) = _get_n_walkers(spl.alg)
37+
3438
function AbstractMCMC.step(
3539
rng::Random.AbstractRNG,
3640
model::Model,
3741
spl::Sampler{<:Emcee};
3842
resume_from=nothing,
39-
initial_params=nothing,
43+
initial_params=fill(DynamicPPL.init_strategy(spl), _get_n_walkers(spl)),
4044
kwargs...,
4145
)
4246
if resume_from !== nothing
@@ -45,21 +49,19 @@ function AbstractMCMC.step(
4549
end
4650

4751
# Sample from the prior
48-
n = spl.alg.ensemble.n_walkers
52+
n = _get_n_walkers(spl)
4953
vis = [VarInfo(rng, model) for _ in 1:n]
5054

5155
# Update the parameters if provided.
52-
if initial_params !== nothing
53-
if !(
54-
initial_params isa AbstractVector{<:DynamicPPL.AbstractInitStrategy} &&
55-
length(initial_params) == n
56-
)
57-
err_msg = "initial_params for `Emcee` must be a vector of `DynamicPPL.AbstractInitStrategy`, with length equal to the number of walkers ($n)"
58-
throw(ArgumentError(err_msg))
59-
end
60-
vis = map(vis, initial_params) do vi, strategy
61-
DynamicPPL.init!!(rng, model, vi, strategy)
62-
end
56+
if !(
57+
initial_params isa AbstractVector{<:DynamicPPL.AbstractInitStrategy} &&
58+
length(initial_params) == n
59+
)
60+
err_msg = "initial_params for `Emcee` must be a vector of `DynamicPPL.AbstractInitStrategy`, with length equal to the number of walkers ($n)"
61+
throw(ArgumentError(err_msg))
62+
end
63+
vis = map(vis, initial_params) do vi, strategy
64+
last(DynamicPPL.init!!(rng, model, vi, strategy))
6365
end
6466

6567
# Compute initial transition and states.

src/mcmc/external_sampler.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,25 +117,25 @@ function AbstractMCMC.step(
117117
model::DynamicPPL.Model,
118118
sampler_wrapper::Sampler{<:ExternalSampler};
119119
initial_state=nothing,
120-
initial_params=nothing,
120+
initial_params=DynamicPPL.init_strategy(sampler_wrapper.alg.sampler),
121121
kwargs...,
122122
)
123123
alg = sampler_wrapper.alg
124124
sampler = alg.sampler
125125

126126
# Initialise varinfo with initial params and link the varinfo if needed.
127127
varinfo = DynamicPPL.VarInfo(model)
128+
_, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params)
129+
128130
if requires_unconstrained_space(alg)
129-
if initial_params !== nothing
130-
# If we have initial parameters, we need to set the varinfo before linking.
131-
varinfo = DynamicPPL.link(DynamicPPL.unflatten(varinfo, initial_params), model)
132-
# Extract initial parameters in unconstrained space.
133-
initial_params = varinfo[:]
134-
else
135-
varinfo = DynamicPPL.link(varinfo, model)
136-
end
131+
varinfo = DynamicPPL.link(varinfo, model)
137132
end
138133

134+
# We need to extract the vectorised initial_params, because the later call to
135+
# AbstractMCMC.step only sees a `LogDensityModel` which expects `initial_params`
136+
# to be a vector.
137+
initial_params_vector = varinfo[:]
138+
139139
# Construct LogDensityFunction
140140
f = DynamicPPL.LogDensityFunction(
141141
model, DynamicPPL.getlogjoint_internal, varinfo; adtype=alg.adtype
@@ -144,15 +144,19 @@ function AbstractMCMC.step(
144144
# Then just call `AbstractMCMC.step` with the right arguments.
145145
if initial_state === nothing
146146
transition_inner, state_inner = AbstractMCMC.step(
147-
rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params, kwargs...
147+
rng,
148+
AbstractMCMC.LogDensityModel(f),
149+
sampler;
150+
initial_params=initial_params_vector,
151+
kwargs...,
148152
)
149153
else
150154
transition_inner, state_inner = AbstractMCMC.step(
151155
rng,
152156
AbstractMCMC.LogDensityModel(f),
153157
sampler,
154158
initial_state;
155-
initial_params,
159+
initial_params=initial_params_vector,
156160
kwargs...,
157161
)
158162
end

src/mcmc/hmc.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ function find_initial_params(
146146
rng::Random.AbstractRNG,
147147
model::DynamicPPL.Model,
148148
varinfo::DynamicPPL.AbstractVarInfo,
149-
hamiltonian::AHMC.Hamiltonian;
149+
hamiltonian::AHMC.Hamiltonian,
150+
init_strategy::DynamicPPL.AbstractInitStrategy;
150151
max_attempts::Int=1000,
151152
)
152153
varinfo = deepcopy(varinfo) # Don't mutate
@@ -157,10 +158,10 @@ function find_initial_params(
157158
isfinite(z) && return varinfo, z
158159

159160
attempts == 10 &&
160-
@warn "failed to find valid initial parameters in $(attempts) tries; consider providing explicit initial parameters using the `initial_params` keyword"
161+
@warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword"
161162

162163
# Resample and try again.
163-
varinfo = DynamicPPL.init!!(rng, model, varinfo, DynamicPPL.InitFromUniform())
164+
varinfo = DynamicPPL.init!!(rng, model, varinfo, init_strategy)
164165
end
165166

166167
# if we failed to find valid initial parameters, error
@@ -174,7 +175,9 @@ function DynamicPPL.initialstep(
174175
model::AbstractModel,
175176
spl::Sampler{<:Hamiltonian},
176177
vi_original::AbstractVarInfo;
177-
initial_params=nothing,
178+
# the initial_params kwarg is always passed on from sample(), cf. DynamicPPL
179+
# src/sampler.jl, so we don't need to provide a default value here
180+
initial_params::DynamicPPL.AbstractInitStrategy,
178181
nadapts=0,
179182
verbose::Bool=true,
180183
kwargs...,
@@ -195,13 +198,15 @@ function DynamicPPL.initialstep(
195198
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
196199
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
197200

198-
# If no initial parameters are provided, resample until the log probability
199-
# and its gradient are finite. Otherwise, just use the existing parameters.
200-
vi, z = if initial_params === nothing
201-
find_initial_params(rng, model, vi, hamiltonian)
202-
else
203-
vi, AHMC.phasepoint(rng, theta, hamiltonian)
204-
end
201+
# Note that there is already one round of 'initialisation' before we reach this step,
202+
# inside DynamicPPL's `AbstractMCMC.step` implementation. That leads to a possible issue
203+
# that this `find_initial_params` function might override the parameters set by the
204+
# user.
205+
# Luckily for us, `find_initial_params` always checks if the logp and its gradient are
206+
# finite. If it is already finite with the params inside the current `vi`, it doesn't
207+
# attempt to find new ones. This means that the parameters passed to `sample()` will be
208+
# respected instead of being overridden here.
209+
vi, z = find_initial_params(rng, model, vi, hamiltonian, initial_params)
205210
theta = vi[:]
206211

207212
# Find good eps if not provided one

test/mcmc/emcee.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,21 @@ using Turing
3434
nwalkers = 250
3535
spl = Emcee(nwalkers, 2.0)
3636

37-
# No initial parameters, with im- and explicit `initial_params=nothing`
3837
Random.seed!(1234)
3938
chain1 = sample(gdemo_default, spl, 1)
4039
Random.seed!(1234)
41-
chain2 = sample(gdemo_default, spl, 1; initial_params=nothing)
40+
chain2 = sample(gdemo_default, spl, 1)
4241
@test Array(chain1) == Array(chain2)
4342

43+
initial_nt = DynamicPPL.InitFromParams((s=2.0, m=1.0))
4444
# Initial parameters have to be specified for every walker
45-
@test_throws ArgumentError sample(gdemo_default, spl, 1; initial_params=[2.0, 1.0])
45+
@test_throws ArgumentError sample(gdemo_default, spl, 1; initial_params=initial_nt)
46+
@test_throws r"must be a vector of" sample(
47+
gdemo_default, spl, 1; initial_params=initial_nt
48+
)
4649

4750
# Initial parameters
48-
chain = sample(gdemo_default, spl, 1; initial_params=fill([2.0, 1.0], nwalkers))
51+
chain = sample(gdemo_default, spl, 1; initial_params=fill(initial_nt, nwalkers))
4952
@test chain[:s] == fill(2.0, 1, nwalkers)
5053
@test chain[:m] == fill(1.0, 1, nwalkers)
5154
end

0 commit comments

Comments
 (0)