Skip to content

Commit dea5d19

Browse files
authored
Refactor HMC initialisation code (#2567)
1 parent 411a341 commit dea5d19

File tree

1 file changed

+36
-25
lines changed

1 file changed

+36
-25
lines changed

src/mcmc/hmc.jl

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,36 @@ function AbstractMCMC.sample(
138138
end
139139
end
140140

141+
function find_initial_params(
142+
rng::Random.AbstractRNG,
143+
model::DynamicPPL.Model,
144+
varinfo::DynamicPPL.AbstractVarInfo,
145+
hamiltonian::AHMC.Hamiltonian;
146+
max_attempts::Int=1000,
147+
)
148+
varinfo = deepcopy(varinfo) # Don't mutate
149+
150+
for attempts in 1:max_attempts
151+
theta = varinfo[:]
152+
z = AHMC.phasepoint(rng, theta, hamiltonian)
153+
isfinite(z) && return varinfo, z
154+
155+
attempts == 10 &&
156+
@warn "failed to find valid initial parameters in $(attempts) tries; consider providing explicit initial parameters using the `initial_params` keyword"
157+
158+
# Resample and try again.
159+
# NOTE: varinfo has to be linked to make sure this samples in unconstrained space
160+
varinfo = last(
161+
DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform())
162+
)
163+
end
164+
165+
# if we failed to find valid initial parameters, error
166+
return error(
167+
"failed to find valid initial parameters in $(max_attempts) tries. This may indicate an error with the model or AD backend; please open an issue at https://github.com/TuringLang/Turing.jl/issues",
168+
)
169+
end
170+
141171
function DynamicPPL.initialstep(
142172
rng::AbstractRNG,
143173
model::AbstractModel,
@@ -170,33 +200,14 @@ function DynamicPPL.initialstep(
170200
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
171201
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
172202

173-
# Compute phase point z.
174-
z = AHMC.phasepoint(rng, theta, hamiltonian)
175-
176203
# If no initial parameters are provided, resample until the log probability
177-
# and its gradient are finite.
178-
if initial_params === nothing
179-
init_attempt_count = 1
180-
while !isfinite(z)
181-
if init_attempt_count == 10
182-
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
183-
end
184-
if init_attempt_count == 1000
185-
error(
186-
"failed to find valid initial parameters in $(init_attempt_count) tries. This may indicate an error with the model or AD backend; please open an issue at https://github.com/TuringLang/Turing.jl/issues",
187-
)
188-
end
189-
190-
# NOTE: This will sample in the unconstrained space.
191-
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
192-
theta = vi[:]
193-
194-
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
195-
z = AHMC.phasepoint(rng, theta, hamiltonian)
196-
197-
init_attempt_count += 1
198-
end
204+
# and its gradient are finite. Otherwise, just use the existing parameters.
205+
vi, z = if initial_params === nothing
206+
find_initial_params(rng, model, vi, hamiltonian)
207+
else
208+
vi, AHMC.phasepoint(rng, theta, hamiltonian)
199209
end
210+
theta = vi[:]
200211

201212
# Cache current log density.
202213
log_density_old = getlogp(vi)

0 commit comments

Comments
 (0)