diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl
index aeba71be26..6963f9c7c5 100644
--- a/src/mcmc/abstractmcmc.jl
+++ b/src/mcmc/abstractmcmc.jl
@@ -17,6 +17,60 @@ parameters for sampling are chosen if not specified by the user. By default, thi
 """
 init_strategy(::AbstractSampler) = DynamicPPL.InitFromPrior()
 
+"""
+    find_initial_params(rng, model, varinfo, init_strategy, validator; max_attempts=1000)
+
+Search for valid initial parameters for MCMC sampling.
+
+The parameters already stored in `varinfo` are checked first. Whenever `validator` rejects
+the current parameters, new ones are generated with `init_strategy` and checked again, for
+at most `max_attempts` checks in total. The input `varinfo` is not mutated; a copy holding
+the accepted parameters is returned.
+
+The function `validator` should take the (updated) VarInfo with new parameters, and return
+a Bool indicating whether the parameters are valid.
+
+A warning is logged after 10 failed checks, provided further attempts remain.
+
+# Throws
+- `ArgumentError` if `max_attempts < 1`.
+- `ErrorException` if all `max_attempts` checks fail.
+"""
+function find_initial_params(
+    rng::AbstractRNG,
+    model::DynamicPPL.Model,
+    varinfo::DynamicPPL.AbstractVarInfo,
+    init_strategy::DynamicPPL.AbstractInitStrategy,
+    validator::Any;
+    max_attempts::Int=1000,
+)
+    # An empty attempt budget is a caller error rather than a sampling failure.
+    max_attempts >= 1 ||
+        throw(ArgumentError("`max_attempts` must be at least 1, got $max_attempts"))
+
+    varinfo = deepcopy(varinfo) # Don't mutate the input
+
+    for attempt in 1:max_attempts
+        validator(varinfo) && return varinfo
+
+        # Out of attempts: stop before warning or generating parameters nobody will check.
+        attempt == max_attempts && break
+
+        if attempt == 10
+            @warn "failed to find valid initial parameters in $(attempt) tries; consider providing a different initialisation strategy with the `initial_params` keyword"
+        end
+
+        # Regenerate parameters for the next attempt
+        _, varinfo = DynamicPPL.init!!(rng, model, varinfo, init_strategy)
+    end
+
+    return error(
+        "Failed to find valid initial parameters after $max_attempts attempts. " *
+        "See https://turinglang.org/docs/uri/initial-parameters for common causes and solutions. " *
+        "If the issue persists, please open an issue at https://github.com/TuringLang/Turing.jl/issues",
+    )
+end
+
 """
     _convert_initial_params(initial_params)
 
diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl
index aa4d984c13..3b0db8fb9c 100644
--- a/src/mcmc/external_sampler.jl
+++ b/src/mcmc/external_sampler.jl
@@ -156,16 +156,28 @@ function AbstractMCMC.step(
         varinfo = DynamicPPL.link(varinfo, model)
     end
 
-    # We need to extract the vectorised initial_params, because the later call to
-    # AbstractMCMC.step only sees a `LogDensityModel` which expects `initial_params`
-    # to be a vector.
-    initial_params_vector = varinfo[:]
-
-    # Construct LogDensityFunction
+    # Construct the LogDensityFunction first: the validator below needs it to evaluate the
+    # log density and its gradient for each candidate.
     f = DynamicPPL.LogDensityFunction(
         model, DynamicPPL.getlogjoint_internal, varinfo; adtype=sampler_wrapper.adtype
     )
 
+    # Accept candidate parameters only if both the log density and its gradient are finite.
+    # NOTE(review): this computes a gradient even if the wrapped sampler never uses one;
+    # confirm that is acceptable for gradient-free external samplers.
+    validator = vi -> begin
+        logp, grad = LogDensityProblems.logdensity_and_gradient(f, vi[:])
+        return isfinite(logp) && all(isfinite, grad)
+    end
+    varinfo = find_initial_params(
+        rng, model, varinfo, initial_params, validator; max_attempts=10
+    )
+
+    # We need to extract the vectorised initial_params, because the later call to
+    # AbstractMCMC.step only sees a `LogDensityModel` which expects `initial_params`
+    # to be a vector.
+    initial_params_vector = varinfo[:]
+
     # Then just call `AbstractMCMC.step` with the right arguments.
     _, state_inner = if initial_state === nothing
         AbstractMCMC.step(
diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl
index 7209325127..4bcde1d1ba 100644
--- a/src/mcmc/hmc.jl
+++ b/src/mcmc/hmc.jl
@@ -157,24 +157,19 @@ function find_initial_params(
     init_strategy::DynamicPPL.AbstractInitStrategy;
     max_attempts::Int=1000,
 )
-    varinfo = deepcopy(varinfo) # Don't mutate
-
-    for attempts in 1:max_attempts
-        theta = varinfo[:]
-        z = AHMC.phasepoint(rng, theta, hamiltonian)
-        isfinite(z) && return varinfo, z
-
-        attempts == 10 &&
-            @warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword"
-
-        # Resample and try again.
-        _, varinfo = DynamicPPL.init!!(rng, model, varinfo, init_strategy)
+    # Keep the phasepoint built by the validator so that the phasepoint returned is exactly
+    # the one that was checked, instead of building a second one from fresh randomness.
+    checked_z = Ref{Any}(nothing)
+    validator = vi -> begin
+        checked_z[] = AHMC.phasepoint(rng, vi[:], hamiltonian)
+        return isfinite(checked_z[])
     end
-
-    # if we failed to find valid initial parameters, error
-    return error(
-        "failed to find valid initial parameters in $(max_attempts) tries. See https://turinglang.org/docs/uri/initial-parameters for common causes and solutions. If the issue persists, please open an issue at https://github.com/TuringLang/Turing.jl/issues",
+
+    varinfo = find_initial_params(
+        rng, model, varinfo, init_strategy, validator; max_attempts=max_attempts
     )
+
+    return varinfo, checked_z[]
 end
 
 function Turing.Inference.initialstep(
diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl
index 957d33acfb..e4cc3ed918 100644
--- a/test/mcmc/abstractmcmc.jl
+++ b/test/mcmc/abstractmcmc.jl
@@ -2,7 +2,7 @@ module TuringAbstractMCMCTests
 
 using AbstractMCMC: AbstractMCMC
 using DynamicPPL: DynamicPPL
-using Random: AbstractRNG
-using Test: @test, @testset, @test_throws
+using Random: AbstractRNG, Random
+using Test: @test, @test_logs, @testset, @test_throws
 using Turing
 
@@ -150,4 +150,105 @@ end
     end
 end
 
+@testset "Initial parameter retry logic" begin
+    @testset "NUTS sampler" begin
+        # Every model evaluation bumps the counter; the first 5 evaluations are invalid.
+        init_counter_1 = Ref(0)
+        @model function bad_init_model()
+            init_counter_1[] += 1
+            x ~ Normal(0, 1)
+            Turing.@addlogprob! (init_counter_1[] > 5) ? 0.0 : -Inf
+        end
+
+        Random.seed!(1234)
+        chain = sample(bad_init_model(), NUTS(), 10)
+        @test size(chain, 1) == 10
+        @test init_counter_1[] >= 6
+    end
+
+    @testset "HMC sampler" begin
+        # Only a narrow band around zero is valid, so most prior draws are rejected.
+        init_counter_2 = Ref(0)
+        @model function very_bad_init_model()
+            init_counter_2[] += 1
+            x ~ Normal(0, 1)
+            Turing.@addlogprob! (abs(x) < 0.1) ? 0.0 : -Inf
+        end
+
+        Random.seed!(1234)
+        chain = sample(very_bad_init_model(), HMC(0.1, 5), 10)
+        @test size(chain, 1) == 10
+        @test init_counter_2[] >= 6
+    end
+
+    @testset "Impossible initialization" begin
+        # The log density is always -Inf, so no parameters can ever be valid.
+        @model function impossible_model()
+            x ~ Normal(0, 1)
+            Turing.@addlogprob! -Inf
+        end
+        model = impossible_model()
+
+        # Check both the exception type and that the message is the informative one.
+        @test_throws ErrorException sample(model, NUTS(), 10)
+        @test_throws "Failed to find valid initial parameters" sample(model, NUTS(), 10)
+    end
+
+    @testset "Warning at attempt 10" begin
+        # The first 30 model evaluations are invalid; afterwards the model is valid.
+        attempt_counter = Ref(0)
+        @model function counter_model()
+            attempt_counter[] += 1
+            x ~ Normal(0, 1)
+            Turing.@addlogprob! (attempt_counter[] > 30) ? 0.0 : -Inf
+        end
+
+        # match_mode=:any because sampling may emit other log messages as well.
+        @test_logs (:warn, r"valid initial parameters in 10 tries") match_mode=:any sample(
+            counter_model(), NUTS(), 10
+        )
+
+        # The model must have been evaluated more than 30 times to become valid.
+        @test attempt_counter[] > 30
+    end
+
+    @testset "Direct find_initial_params test" begin
+        @model function simple_model()
+            x ~ Normal(0, 1)
+        end
+
+        model = simple_model()
+        vi = DynamicPPL.VarInfo(model)
+        _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior())
+        rng = Random.default_rng()
+        strategy = DynamicPPL.InitFromPrior()
+
+        # Validator that always succeeds
+        always_valid = _ -> true
+        result_vi = Turing.Inference.find_initial_params(
+            rng, model, vi, strategy, always_valid
+        )
+        @test result_vi isa DynamicPPL.AbstractVarInfo
+
+        # Validator that succeeds on its fourth call
+        counter = Ref(0)
+        validator_counter = _ -> (counter[] += 1) > 3
+        Turing.Inference.find_initial_params(rng, model, vi, strategy, validator_counter)
+        @test counter[] == 4
+
+        # Validator that never succeeds: every attempt is used, then an error is thrown
+        calls = Ref(0)
+        never_valid = _ -> (calls[] += 1; false)
+        @test_throws ErrorException Turing.Inference.find_initial_params(
+            rng, model, vi, strategy, never_valid; max_attempts=5
+        )
+        @test calls[] == 5
+
+        # A non-positive attempt budget is rejected up front
+        @test_throws ArgumentError Turing.Inference.find_initial_params(
+            rng, model, vi, strategy, always_valid; max_attempts=0
+        )
+    end
+end
+
 end # module
diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl
index c6b5af2162..6cacc08ece 100644
--- a/test/mcmc/hmc.jl
+++ b/test/mcmc/hmc.jl
@@ -197,30 +197,5 @@ using Turing
         end
     end
 
-    @testset "warning for difficult init params" begin
-        attempt = 0
-        @model function demo_warn_initial_params()
-            x ~ Normal()
-            if (attempt += 1) < 30
-                @addlogprob! -Inf
-            end
-        end
-
-        # verbose=false to suppress the initial step size notification, which messes with
-        # the test
-        @test_logs (:warn, r"consider providing a different initialisation strategy") sample(
-            demo_warn_initial_params(), NUTS(), 5; verbose=false
-        )
-    end
-
-    @testset "error for impossible model" begin
-        @model function demo_impossible()
-            x ~ Normal()
-            @addlogprob! -Inf
-        end
-
-        @test_throws ErrorException sample(demo_impossible(), NUTS(), 5)
-    end
-
     @testset "NUTS initial parameters" begin
         @model function f()