diff --git a/Project.toml b/Project.toml index 28fec35..21a79c2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SliceSampling" uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf" -version = "0.7.7" +version = "0.7.8" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -21,7 +21,7 @@ Distributions = "0.25" LinearAlgebra = "1" LogDensityProblems = "2" Random = "1" -Turing = "0.39.5" +Turing = "0.40" julia = "1.10" [extras] diff --git a/docs/Project.toml b/docs/Project.toml index 65b7f91..845e9b0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -28,5 +28,5 @@ Random = "1" SliceSampling = "0.7.1" StableRNGs = "1" Statistics = "1" -Turing = "0.37, 0.38, 0.39" +Turing = "0.37, 0.38, 0.39, 0.40" julia = "1.10" diff --git a/ext/SliceSamplingTuringExt.jl b/ext/SliceSamplingTuringExt.jl index 44f9f6a..d20fc71 100644 --- a/ext/SliceSamplingTuringExt.jl +++ b/ext/SliceSamplingTuringExt.jl @@ -1,6 +1,7 @@ module SliceSamplingTuringExt +using LogDensityProblems using Random using SliceSampling using Turing @@ -32,32 +33,38 @@ const SliceSamplingStates = Union{ function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates) return sample.transition.params end - -function Turing.Inference.getlogp_external( - ::Turing.DynamicPPL.Model, t::SliceSampling.Transition, state -) - return t.lp -end # end function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction) + n_max_attempts = 1000 + model = ℓ.model vi = Turing.DynamicPPL.VarInfo(rng, model, Turing.SampleFromUniform()) - vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform())) + vi_spl = last(Turing.DynamicPPL.evaluate_and_sample!!(rng, model, vi, Turing.SampleFromUniform())) θ = vi_spl[:] + ℓp = LogDensityProblems.logdensity(ℓ, θ) init_attempt_count = 1 - while !all(isfinite.(θ)) - if init_attempt_count == 10 - @warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword" + for attempts in 1:n_max_attempts + if attempts == 10 + @warn "Failed to find valid initial parameters after $(init_attempt_count) attempts; consider providing explicit initial parameters using the `initial_params` keyword" end # NOTE: This will sample in the unconstrained space. - vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform())) - θ = vi_spl[:] + vi_spl = last( + Turing.DynamicPPL.evaluate_and_sample!!( + rng, model, vi, Turing.SampleFromUniform() + ), + ) + θ = vi_spl[:] + ℓp = LogDensityProblems.logdensity(ℓ, θ) - init_attempt_count += 1 + if all(isfinite.(θ)) && isfinite(ℓp) + return θ + end end + + @error "Failed to find valid initial parameters after $(n_max_attempts) attempts; consider providing explicit initial parameters using the `initial_params` keyword" return θ end diff --git a/src/SliceSampling.jl b/src/SliceSampling.jl index b8df323..c43875c 100644 --- a/src/SliceSampling.jl +++ b/src/SliceSampling.jl @@ -1,4 +1,4 @@ -ls + module SliceSampling using AbstractMCMC diff --git a/test/Project.toml b/test/Project.toml index cadc38f..2c43400 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,5 +18,5 @@ MCMCTesting = "0.3" Random = "1" StableRNGs = "1" Test = "1" -Turing = "0.37, 0.38, 0.39" +Turing = "0.37, 0.38, 0.39, 0.40" julia = "1.10" diff --git a/test/turing.jl b/test/turing.jl index f5c6104..b0e1886 100644 --- a/test/turing.jl +++ b/test/turing.jl @@ -8,11 +8,30 @@ return nothing end + @model function illbehavedmodel() + @addlogprob! -Inf + return nothing + end + @model function logp_check() a ~ Normal() return b ~ Normal() end + rng = Random.default_rng() + @test begin + init = SliceSampling.initial_sample(rng, LogDensityFunction(demo())) + all(isfinite.(init)) + end + + @test_warn "Warning: Failed" SliceSampling.initial_sample( + rng, LogDensityFunction(illbehavedmodel()) + ) + + @test_warn "Error: Failed" SliceSampling.initial_sample( + rng, LogDensityFunction(illbehavedmodel()) + ) + n_samples = 1000 model = demo()