|
1 | 1 |
|
2 | 2 | module SliceSamplingTuringExt
|
3 | 3 |
|
| 4 | +using LogDensityProblems |
4 | 5 | using Random
|
5 | 6 | using SliceSampling
|
6 | 7 | using Turing
|
@@ -32,32 +33,38 @@ const SliceSamplingStates = Union{
|
32 | 33 | function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates)
|
33 | 34 | return sample.transition.params
|
34 | 35 | end
|
35 |
| - |
36 |
| -function Turing.Inference.getlogp_external( |
37 |
| - ::Turing.DynamicPPL.Model, t::SliceSampling.Transition, state |
38 |
| -) |
39 |
| - return t.lp |
40 |
| -end |
41 | 36 | # end
|
42 | 37 |
|
43 | 38 | function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
|
| 39 | + n_max_attempts = 1000 |
| 40 | + |
44 | 41 | model = ℓ.model
|
45 | 42 | vi = Turing.DynamicPPL.VarInfo(rng, model, Turing.SampleFromUniform())
|
46 |
| - vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform())) |
| 43 | + vi_spl = last(Turing.DynamicPPL.evaluate_and_sample!!(rng, model, vi, Turing.SampleFromUniform())) |
47 | 44 | θ = vi_spl[:]
|
| 45 | + ℓp = LogDensityProblems.logdensity(ℓ, θ) |
48 | 46 |
|
49 | 47 | init_attempt_count = 1
|
50 |
| - while !all(isfinite.(θ)) |
51 |
| - if init_attempt_count == 10 |
52 |
| - @warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword" |
| 48 | + for attempts in 1:n_max_attempts |
| 49 | + if attempts == 10 |
| 50 | + @warn "Failed to find valid initial parameters after $(init_attempt_count) attempts; consider providing explicit initial parameters using the `initial_params` keyword" |
53 | 51 | end
|
54 | 52 |
|
55 | 53 | # NOTE: This will sample in the unconstrained space.
|
56 |
| - vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform())) |
57 |
| - θ = vi_spl[:] |
| 54 | + vi_spl = last( |
| 55 | + Turing.DynamicPPL.evaluate_and_sample!!( |
| 56 | + rng, model, vi, Turing.SampleFromUniform() |
| 57 | + ), |
| 58 | + ) |
| 59 | + θ = vi_spl[:] |
| 60 | + ℓp = LogDensityProblems.logdensity(ℓ, θ) |
58 | 61 |
|
59 |
| - init_attempt_count += 1 |
| 62 | + if all(isfinite.(θ)) && isfinite(ℓp) |
| 63 | + return θ |
| 64 | + end |
60 | 65 | end
|
| 66 | + |
| 67 | + @error "Failed to find valid initial parameters after $(n_max_attempts) attempts; consider providing explicit initial parameters using the `initial_params` keyword" |
61 | 68 | return θ
|
62 | 69 | end
|
63 | 70 |
|
|
0 commit comments