Skip to content
44 changes: 44 additions & 0 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,50 @@ parameters for sampling are chosen if not specified by the user. By default, thi
"""
init_strategy(::AbstractSampler) = DynamicPPL.InitFromPrior()

"""
    find_initial_params(rng, model, varinfo, init_strategy, validator; max_attempts=1000)

Search for initial parameters that `validator` accepts, for use as the starting point of
MCMC sampling.

The parameters already stored in `varinfo` are checked first. Each time a check fails, new
parameters are generated with `DynamicPPL.init!!(rng, model, varinfo, init_strategy)` and
checked again, for at most `max_attempts` checks in total. The search works on a deep copy,
so the `varinfo` passed in is not mutated.

# Arguments
- `rng`: random number generator forwarded to `DynamicPPL.init!!`.
- `model`: model forwarded to `DynamicPPL.init!!`.
- `varinfo`: VarInfo holding the first candidate parameters.
- `init_strategy`: strategy forwarded to `DynamicPPL.init!!` to generate new candidates.
- `validator`: function that takes the (updated) VarInfo and returns a `Bool` indicating
  whether its parameters are valid.

# Keywords
- `max_attempts`: maximum number of candidates to check; must be at least 1.

# Returns
The first VarInfo for which `validator` returned `true`.

# Throws
- `ArgumentError` if `max_attempts < 1`.
- `ErrorException` if no candidate is accepted within `max_attempts` attempts.

A warning is logged once if the first 10 candidates have all been rejected.
"""
function find_initial_params(
    rng::AbstractRNG,
    model::Model,
    varinfo::AbstractVarInfo,
    init_strategy::DynamicPPL.AbstractInitStrategy,
    validator::Function;
    max_attempts::Int=1000,
)
    # Without this guard the loop below would not run at all and the function would
    # silently return `nothing` instead of a VarInfo.
    max_attempts >= 1 ||
        throw(ArgumentError("max_attempts must be at least 1, got $max_attempts"))

    candidate = deepcopy(varinfo) # Don't mutate the input

    for attempt in 1:max_attempts
        validator(candidate) && return candidate

        if attempt == 10
            @warn "failed to find valid initial parameters in $(attempt) tries; consider providing a different initialisation strategy with the `initial_params` keyword"
        end

        # No point in drawing new parameters that would never be validated.
        attempt == max_attempts && break

        # Regenerate parameters for the next attempt.
        _, candidate = DynamicPPL.init!!(rng, model, candidate, init_strategy)
    end

    return error(
        "Failed to find valid initial parameters after $max_attempts attempts. " *
        "See https://turinglang.org/docs/uri/initial-parameters for common causes and solutions. " *
        "If the issue persists, please open an issue at https://github.com/TuringLang/Turing.jl/issues",
    )
end

"""
_convert_initial_params(initial_params)

Expand Down
21 changes: 16 additions & 5 deletions src/mcmc/external_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,27 @@ function AbstractMCMC.step(
varinfo = DynamicPPL.link(varinfo, model)
end

# We need to extract the vectorised initial_params, because the later call to
# AbstractMCMC.step only sees a `LogDensityModel` which expects `initial_params`
# to be a vector.
initial_params_vector = varinfo[:]

# Construct LogDensityFunction
# Construct LogDensityFunction FIRST (we need this for validation)
f = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, varinfo; adtype=sampler_wrapper.adtype
)

# Use shared function to find valid initial parameters with gradient checking
validator = vi -> begin
θ = vi[:]
logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ)
return isfinite(logp) && all(isfinite, grad)

end

varinfo = find_initial_params(
rng, model, varinfo, initial_params, validator; max_attempts=10
)

initial_params_vector = varinfo[:]


# Then just call `AbstractMCMC.step` with the right arguments.
_, state_inner = if initial_state === nothing
AbstractMCMC.step(
Expand Down
30 changes: 14 additions & 16 deletions src/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,24 +157,22 @@ function find_initial_params(
init_strategy::DynamicPPL.AbstractInitStrategy;
max_attempts::Int=1000,
)
varinfo = deepcopy(varinfo) # Don't mutate

for attempts in 1:max_attempts
theta = varinfo[:]
z = AHMC.phasepoint(rng, theta, hamiltonian)
isfinite(z) && return varinfo, z

attempts == 10 &&
@warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword"

# Resample and try again.
_, varinfo = DynamicPPL.init!!(rng, model, varinfo, init_strategy)
# Create validator function for HMC
validator = vi -> begin
θ = vi[:]
z = AHMC.phasepoint(rng, θ, hamiltonian)
return (isfinite(z))
end

# if we failed to find valid initial parameters, error
return error(
"failed to find valid initial parameters in $(max_attempts) tries. See https://turinglang.org/docs/uri/initial-parameters for common causes and solutions. If the issue persists, please open an issue at https://github.com/TuringLang/Turing.jl/issues",

varinfo = find_initial_params(
rng, model, varinfo, init_strategy, validator; max_attempts=max_attempts
)

# Construct the final phasepoint
θ = varinfo[:]
z = AHMC.phasepoint(rng, θ, hamiltonian)

return varinfo, z
end

function Turing.Inference.initialstep(
Expand Down
118 changes: 117 additions & 1 deletion test/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module TuringAbstractMCMCTests

using AbstractMCMC: AbstractMCMC
using DynamicPPL: DynamicPPL
using Random: AbstractRNG
using Random: AbstractRNG, Random
using Test: @test, @testset, @test_throws
using Turing

Expand Down Expand Up @@ -150,4 +150,120 @@ end
end
end

@testset "Initial parameter retry logic" begin
    # The log density is finite only for |x| < 0.3, so roughly three out of four draws
    # from the Normal(0, 1) prior are unusable as a starting point and must be resampled.
    @model function bad_init_model()
        x ~ Normal(0, 1)
        Turing.@addlogprob! (abs(x) < 0.3) ? 0.0 : -Inf
    end

    # Internal samplers should retry initialisation and then only return samples from
    # the region where the log density is finite.
    @testset "NUTS sampler" begin
        chain = sample(bad_init_model(), NUTS(), 10; progress=false)
        @test size(chain, 1) == 10
        @test all(abs.(chain[:x]) .< 0.3)
    end

    @testset "HMC sampler" begin
        chain = sample(bad_init_model(), HMC(0.1, 5), 10; progress=false)
        @test size(chain, 1) == 10
        @test all(abs.(chain[:x]) .< 0.3)
    end

    # Only about 1.6% of draws from the Normal(0, 5) prior satisfy |x| < 0.1, so many
    # retries are typically needed, yet failing 1000 times in a row is practically
    # impossible. (A much wider prior would make this test fail at random.)
    # NOTE(review): assumes the sampler keeps the default `max_attempts=1000` -- confirm.
    @model function very_bad_init_model()
        x ~ Normal(0, 5)
        Turing.@addlogprob! (abs(x) < 0.1) ? 0.0 : -Inf
    end

    @testset "Very narrow valid region" begin
        chain = sample(very_bad_init_model(), NUTS(), 10; progress=false)
        @test size(chain, 1) == 10
        @test all(abs.(chain[:x]) .< 0.1)
    end

    # The log density is always -Inf, so no initial parameters can ever be accepted.
    @model function impossible_model()
        x ~ Normal(0, 1)
        Turing.@addlogprob! -Inf
    end

    @testset "Impossible initialization" begin
        @test_throws ErrorException sample(impossible_model(), NUTS(), 10; progress=false)
    end

    @testset "Warning at attempt 10" begin
        # Counts evaluations of the model body; the first 30 evaluations all yield a
        # log density of -Inf, which forces more than 10 failed initialisation attempts.
        attempt_counter = Ref(0)

        @model function counter_model()
            attempt_counter[] += 1
            x ~ Normal(0, 1)
            Turing.@addlogprob! (attempt_counter[] > 30) ? 0.0 : -Inf
        end

        # `match_mode = :any` because sampling may emit other log messages as well.
        @test_logs (:warn, r"failed to find valid initial parameters in 10 tries") match_mode =
            :any sample(counter_model(), NUTS(), 10; progress=false)

        # Sampling can only have succeeded after the model stopped returning -Inf.
        @test attempt_counter[] > 30
    end

    @testset "Direct find_initial_params test" begin
        @model function simple_model()
            x ~ Normal(0, 1)
        end

        model = simple_model()
        vi = DynamicPPL.VarInfo(model)
        _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior())
        strategy = DynamicPPL.InitFromPrior()

        # A validator that accepts immediately returns a VarInfo.
        accepted = Turing.Inference.find_initial_params(
            Random.default_rng(), model, vi, strategy, _ -> true
        )
        @test accepted isa DynamicPPL.AbstractVarInfo

        # A validator that rejects three candidates is called exactly four times.
        calls = Ref(0)
        accepted = Turing.Inference.find_initial_params(
            Random.default_rng(), model, vi, strategy, _ -> (calls[] += 1; calls[] > 3)
        )
        @test accepted isa DynamicPPL.AbstractVarInfo
        @test calls[] == 4

        # A validator that never accepts exhausts the attempts and raises an error.
        @test_throws ErrorException Turing.Inference.find_initial_params(
            Random.default_rng(), model, vi, strategy, _ -> false; max_attempts=5
        )
    end
end

end # module
16 changes: 0 additions & 16 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,22 +197,6 @@ using Turing
end
end

@testset "warning for difficult init params" begin
    # Number of times the model body has been evaluated so far.
    n_evals = Ref(0)

    # The first 29 evaluations add -Inf to the log density, so the early attempts at
    # finding initial parameters are all rejected.
    @model function demo_warn_initial_params()
        x ~ Normal()
        n_evals[] += 1
        if n_evals[] < 30
            @addlogprob! -Inf
        end
    end

    # `verbose=false` silences the initial step size notification, which would
    # otherwise show up as an extra log record and break the exact match below.
    @test_logs(
        (:warn, r"consider providing a different initialisation strategy"),
        sample(demo_warn_initial_params(), NUTS(), 5; verbose=false),
    )
end

@testset "error for impossible model" begin
@model function demo_impossible()
x ~ Normal()
Expand Down
Loading