diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml index 475399d..024f8e1 100644 --- a/.github/workflows/Docs.yml +++ b/.github/workflows/Docs.yml @@ -24,6 +24,8 @@ jobs: steps: - name: Build and deploy Documenter.jl docs uses: TuringLang/actions/DocsDocumenter@main + with: + julia-version: 1.11 - name: Run doctests shell: julia --project=docs --color=yes {0} diff --git a/Project.toml b/Project.toml index 21a79c2..3ee0dd5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SliceSampling" uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf" -version = "0.7.8" +version = "0.7.9" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -21,7 +21,7 @@ Distributions = "0.25" LinearAlgebra = "1" LogDensityProblems = "2" Random = "1" -Turing = "0.40" +Turing = "0.41" julia = "1.10" [extras] diff --git a/docs/Project.toml b/docs/Project.toml index 3d162a2..83fff61 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -28,5 +28,5 @@ Random = "1" SliceSampling = "0.7.1" StableRNGs = "1" Statistics = "1" -Turing = "0.37, 0.38, 0.39, 0.40" +Turing = "0.41" julia = "1.10" diff --git a/docs/src/gibbs_polar.md b/docs/src/gibbs_polar.md index 8372708..191ae17 100644 --- a/docs/src/gibbs_polar.md +++ b/docs/src/gibbs_polar.md @@ -63,8 +63,9 @@ end model = demo() n_samples = 1000 -latent_chain = sample(model, externalsampler(LatentSlice(10)), n_samples; initial_params=ones(10)) -polar_chain = sample(model, externalsampler(GibbsPolarSlice(10)), n_samples; initial_params=ones(10)) +initial_params = InitFromParams((x = ones(10),)) +latent_chain = sample(model, externalsampler(LatentSlice(10)), n_samples; initial_params=initial_params) +polar_chain = sample(model, externalsampler(GibbsPolarSlice(10)), n_samples; initial_params=initial_params) l = @layout [a; b] p1 = Plots.plot(1:n_samples, latent_chain[:,1,:], ylims=[-10,10], label="LSS") diff --git a/ext/SliceSamplingTuringExt.jl b/ext/SliceSamplingTuringExt.jl index d20fc71..3eba4bd 100644 --- a/ext/SliceSamplingTuringExt.jl +++ b/ext/SliceSamplingTuringExt.jl @@ -38,11 +38,11 @@ end function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction) n_max_attempts = 1000 - model = ℓ.model - vi = Turing.DynamicPPL.VarInfo(rng, model, Turing.SampleFromUniform()) - vi_spl = last(Turing.DynamicPPL.evaluate_and_sample!!(rng, model, vi, Turing.SampleFromUniform())) - θ = vi_spl[:] - ℓp = LogDensityProblems.logdensity(ℓ, θ) + model, vi = ℓ.model, ℓ.varinfo + vi_spl = last( + Turing.DynamicPPL.init!!(rng, model, vi, Turing.DynamicPPL.InitFromUniform()) + ) + ℓp = ℓ.getlogdensity(vi_spl) init_attempt_count = 1 for attempts in 1:n_max_attempts @@ -50,14 +50,10 @@ function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDe @warn "Failed to find valid initial parameters after $(init_attempt_count) attempts; consider providing explicit initial parameters using the `initial_params` keyword" end - # NOTE: This will sample in the unconstrained space. - vi_spl = last( - Turing.DynamicPPL.evaluate_and_sample!!( - rng, model, vi, Turing.SampleFromUniform() - ), - ) + # NOTE: This will sample in the unconstrained space if ℓ.varinfo is linked + vi_spl = last(Turing.DynamicPPL.init!!(rng, model, vi, Turing.InitFromUniform())) + ℓp = ℓ.getlogdensity(vi_spl) θ = vi_spl[:] - ℓp = LogDensityProblems.logdensity(ℓ, θ) if all(isfinite.(θ)) && isfinite(ℓp) return θ @@ -65,6 +61,7 @@ function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDe end @error "Failed to find valid initial parameters after $(n_max_attempts) attempts; consider providing explicit initial parameters using the `initial_params` keyword" + θ = vi_spl[:] return θ end diff --git a/src/multivariate/gibbspolar.jl b/src/multivariate/gibbspolar.jl index f5be0d3..7872ee4 100644 --- a/src/multivariate/gibbspolar.jl +++ b/src/multivariate/gibbspolar.jl @@ -18,10 +18,7 @@ Gibbsian polar slice sampling algorithm by P. Schär, M. Habeck, and D. Rudolf [ The initial window size `w` must be set at least an order of magnitude larger than what is sensible for other slice samplers. Otherwise, a large number of rejections might be experienced. !!! warning - When initializing the chain (*e.g.* the `initial_params` keyword arguments in `AbstractMCMC.sample`), it is necessary to inialize from a point \$\$x_0\$\$ that has a sensible norm \$\$\\lVert x_0 \\rVert > 0\$\$, otherwise, the chain will start from a pathologic point in polar coordinates. This might even result in the sampler getting stuck in an infinite loop. (This can be prevented by setting `max_proposals`.) If \$\$\\lVert x_0 \\rVert \\leq 10^{-5}\$\$, the current implementation will display a warning. - -!!! info - For Turing users: `Turing` might change `initial_params` to match the support of the posterior. This might lead to \$\$\\lVert x_0 \\rVert\$\$ being small, even though the vector you passed to`initial_params` has a sufficiently large norm. If this is suspected, simply try a different initialization value. + When initializing the chain (*e.g.* the `initial_params` keyword arguments in `AbstractMCMC.sample`), it is necessary to initialize from a point \$\$x_0\$\$ that has a sensible norm \$\$\\lVert x_0 \\rVert > 0\$\$, otherwise, the chain will start from a pathological point in polar coordinates. This might even result in the sampler getting stuck in an infinite loop. (This can be prevented by setting `max_proposals`.) If \$\$\\lVert x_0 \\rVert \\leq 10^{-5}\$\$, the current implementation will display a warning. """ struct GibbsPolarSlice{W<:Real} <: AbstractMultivariateSliceSampling w::W diff --git a/test/Project.toml b/test/Project.toml index 2c43400..b6ae6f6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,5 +18,5 @@ MCMCTesting = "0.3" Random = "1" StableRNGs = "1" Test = "1" -Turing = "0.37, 0.38, 0.39, 0.40" +Turing = "0.41" julia = "1.10" diff --git a/test/turing.jl b/test/turing.jl index b0e1886..db02909 100644 --- a/test/turing.jl +++ b/test/turing.jl @@ -49,7 +49,7 @@ model, externalsampler(sampler), n_samples; - initial_params=[1.0, 0.1], + initial_params=InitFromParams((s=1.0, m=0.1)), progress=false, )