Skip to content

Commit 9f76d75

Browse files
torfjeldedevmotion
andauthored
Warn user if we're struggling to find good init params (#1999)
* added warning message in case we cant find good initialization point in a reasonable number of tries * version bump * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * added test for HMC initial params warning * Update test/inference/hmc.jl Co-authored-by: David Widmann <[email protected]> * fixed the warning test * Update test/inference/hmc.jl Co-authored-by: David Widmann <[email protected]> * relax prior tests a bit * further relaxation --------- Co-authored-by: David Widmann <[email protected]>
1 parent 861ae37 commit 9f76d75

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
33
version = "0.25.2"
44

5-
65
[deps]
76
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
87
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"

src/inference/hmc.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,20 @@ function DynamicPPL.initialstep(
178178
# If no initial parameters are provided, resample until the log probability
179179
# and its gradient are finite.
180180
if init_params === nothing
181+
init_attempt_count = 1
181182
while !isfinite(z)
183+
if init_attempt_count == 10
184+
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `init_params` keyword"
185+
end
186+
182187
# NOTE: This will sample in the unconstrained space.
183188
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
184189
theta = vi[spl]
185190

186191
hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)
187192
z = AHMC.phasepoint(rng, theta, hamiltonian)
193+
194+
init_attempt_count += 1
188195
end
189196
end
190197

test/inference/hmc.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,23 @@
221221
alg = NUTS(1000, 0.8)
222222
gdemo_default_prior = DynamicPPL.contextualize(gdemo_default, DynamicPPL.PriorContext())
223223
chain = sample(gdemo_default_prior, alg, 10_000)
224-
check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.3)
224+
check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.45)
225+
end
226+
227+
@turing_testset "warning for difficult init params" begin
228+
attempt = 0
229+
@model function demo_warn_init_params()
230+
x ~ Normal()
231+
if (attempt += 1) < 30
232+
Turing.@addlogprob! -Inf
233+
end
234+
end
235+
236+
@test_logs (
237+
:warn,
238+
"failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `init_params` keyword",
239+
) (:info,) match_mode=:any begin
240+
sample(demo_warn_init_params(), NUTS(), 5)
241+
end
225242
end
226243
end

0 commit comments

Comments
 (0)