Skip to content

Commit cef1f7a

Browse files
committed
fix add tests for drawing initial parameter
1 parent fd30281 commit cef1f7a

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

ext/SliceSamplingTuringExt.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
module SliceSamplingTuringExt
33

4+
using LogDensityProblems
45
using Random
56
using SliceSampling
67
using Turing
@@ -35,15 +36,18 @@ end
3536
# end
3637

3738
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
39+
n_max_attempts = 1000
40+
3841
model =.model
3942
vi = Turing.DynamicPPL.VarInfo(rng, model, Turing.SampleFromUniform())
4043
vi_spl = last(Turing.DynamicPPL.evaluate_and_sample!!(rng, model, vi, Turing.SampleFromUniform()))
4144
θ = vi_spl[:]
45+
ℓp = LogDensityProblems.logdensity(ℓ, θ)
4246

4347
init_attempt_count = 1
44-
while !all(isfinite.(θ))
45-
if init_attempt_count == 10
46-
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
48+
for attempts in 1:n_max_attempts
49+
if attempts == 10
50+
@warn "Failed to find valid initial parameters after $(init_attempt_count) attempts; consider providing explicit initial parameters using the `initial_params` keyword"
4751
end
4852

4953
# NOTE: This will sample in the unconstrained space.
@@ -52,10 +56,15 @@ function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDe
5256
rng, model, vi, Turing.SampleFromUniform()
5357
),
5458
)
55-
θ = vi_spl[:]
59+
θ = vi_spl[:]
60+
ℓp = LogDensityProblems.logdensity(ℓ, θ)
5661

57-
init_attempt_count += 1
62+
if all(isfinite.(θ)) && isfinite(ℓp)
63+
return θ
64+
end
5865
end
66+
67+
@error "Failed to find valid initial parameters after $(n_max_attempts) attempts; consider providing explicit initial parameters using the `initial_params` keyword"
5968
return θ
6069
end
6170

test/turing.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,30 @@
88
return nothing
99
end
1010

11+
@model function illbehavedmodel()
12+
@addlogprob! -Inf
13+
return nothing
14+
end
15+
1116
@model function logp_check()
1217
a ~ Normal()
1318
return b ~ Normal()
1419
end
1520

21+
rng = Random.default_rng()
22+
@test begin
23+
init = SliceSampling.initial_sample(rng, LogDensityFunction(demo()))
24+
all(isfinite.(init))
25+
end
26+
27+
@test_warn "Warning: Failed" SliceSampling.initial_sample(
28+
rng, LogDensityFunction(illbehavedmodel())
29+
)
30+
31+
@test_warn "Error: Failed" SliceSampling.initial_sample(
32+
rng, LogDensityFunction(illbehavedmodel())
33+
)
34+
1635
n_samples = 1000
1736
model = demo()
1837

0 commit comments

Comments
 (0)