Skip to content

Commit ef2eb15

Browse files
committed
Compat for Turing v0.41
1 parent 837fe07 commit ef2eb15

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Distributions = "0.25"
2121
LinearAlgebra = "1"
2222
LogDensityProblems = "2"
2323
Random = "1"
24-
Turing = "0.40"
24+
Turing = "0.41"
2525
julia = "1.10"
2626

2727
[extras]

ext/SliceSamplingTuringExt.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,9 @@ end
3838
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
3939
n_max_attempts = 1000
4040

41-
model =.model
42-
vi = Turing.DynamicPPL.VarInfo(rng, model, Turing.SampleFromUniform())
43-
vi_spl = last(Turing.DynamicPPL.evaluate_and_sample!!(rng, model, vi, Turing.SampleFromUniform()))
44-
θ = vi_spl[:]
45-
ℓp = LogDensityProblems.logdensity(ℓ, θ)
41+
model, vi =.model, ℓ.varinfo
42+
vi_spl = last(Turing.DynamicPPL.init!!(rng, model, vi, Turing.DynamicPPL.InitFromUniform()))
43+
ℓp = Turing.DynamicPPL.getlogjoint_internal(vi_spl)
4644

4745
init_attempt_count = 1
4846
for attempts in 1:n_max_attempts
@@ -52,19 +50,20 @@ function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDe
5250

5351
# NOTE: This will sample in the unconstrained space.
5452
vi_spl = last(
55-
Turing.DynamicPPL.evaluate_and_sample!!(
56-
rng, model, vi, Turing.SampleFromUniform()
53+
Turing.DynamicPPL.init!!(
54+
rng, model, vi_spl, Turing.InitFromUniform()
5755
),
5856
)
57+
ℓp = Turing.DynamicPPL.getlogjoint_internal(vi_spl)
5958
θ = vi_spl[:]
60-
ℓp = LogDensityProblems.logdensity(ℓ, θ)
6159

6260
if all(isfinite.(θ)) && isfinite(ℓp)
6361
return θ
6462
end
6563
end
6664

6765
@error "Failed to find valid initial parameters after $(n_max_attempts) attempts; consider providing explicit initial parameters using the `initial_params` keyword"
66+
θ = vi_spl[:]
6867
return θ
6968
end
7069

0 commit comments

Comments
 (0)