|
1 | 1 |
|
2 | 2 | module SliceSamplingTuringExt |
3 | 3 |
|
| 4 | +using LogDensityProblems |
4 | 5 | using Random |
5 | 6 | using SliceSampling |
6 | 7 | using Turing |
@@ -32,32 +33,38 @@ const SliceSamplingStates = Union{ |
32 | 33 | function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates) |
33 | 34 | return sample.transition.params |
34 | 35 | end |
35 | | - |
36 | | -function Turing.Inference.getlogp_external( |
37 | | - ::Turing.DynamicPPL.Model, t::SliceSampling.Transition, state |
38 | | -) |
39 | | - return t.lp |
40 | | -end |
41 | 36 | # end |
42 | 37 |
|
43 | 38 | function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction) |
| 39 | + n_max_attempts = 1000 |
| 40 | + |
44 | 41 | model = ℓ.model |
45 | 42 | vi = Turing.DynamicPPL.VarInfo(rng, model, Turing.SampleFromUniform()) |
46 | | - vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform())) |
| 43 | + vi_spl = last(Turing.DynamicPPL.evaluate_and_sample!!(rng, model, vi, Turing.SampleFromUniform())) |
47 | 44 | θ = vi_spl[:] |
| 45 | + ℓp = LogDensityProblems.logdensity(ℓ, θ) |
48 | 46 |
|
49 | 47 | init_attempt_count = 1 |
50 | | - while !all(isfinite.(θ)) |
51 | | - if init_attempt_count == 10 |
52 | | - @warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword" |
| 48 | + for attempts in 1:n_max_attempts |
| 49 | + if attempts == 10 |
| 50 | + @warn "Failed to find valid initial parameters after $(init_attempt_count) attempts; consider providing explicit initial parameters using the `initial_params` keyword" |
53 | 51 | end |
54 | 52 |
|
55 | 53 | # NOTE: This will sample in the unconstrained space. |
56 | | - vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform())) |
57 | | - θ = vi_spl[:] |
| 54 | + vi_spl = last( |
| 55 | + Turing.DynamicPPL.evaluate_and_sample!!( |
| 56 | + rng, model, vi, Turing.SampleFromUniform() |
| 57 | + ), |
| 58 | + ) |
| 59 | + θ = vi_spl[:] |
| 60 | + ℓp = LogDensityProblems.logdensity(ℓ, θ) |
58 | 61 |
|
59 | | - init_attempt_count += 1 |
| 62 | + if all(isfinite.(θ)) && isfinite(ℓp) |
| 63 | + return θ |
| 64 | + end |
60 | 65 | end |
| 66 | + |
| 67 | + @error "Failed to find valid initial parameters after $(n_max_attempts) attempts; consider providing explicit initial parameters using the `initial_params` keyword" |
61 | 68 | return θ |
62 | 69 | end |
63 | 70 |
|
|
0 commit comments