Skip to content

Commit f349e8c

Browse files
committed
update Turing compat by implementing mising AbstractMCMC interfaces
1 parent 6a6070d commit f349e8c

File tree

5 files changed

+39
-20
lines changed

5 files changed

+39
-20
lines changed

ext/SliceSamplingTuringExt.jl

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@ if isdefined(Base, :get_extension)
66
using Random
77
using SliceSampling
88
using Turing
9-
# using Turing: Turing, Experimental
109
else
1110
using ..LogDensityProblemsAD
1211
using ..Random
1312
using ..SliceSampling
1413
using ..Turing
15-
#using ..Turing: Turing, Experimental
1614
end
1715

1816
# Required for using the slice samplers as `externalsampler`s in Turing
@@ -24,12 +22,18 @@ function Turing.Inference.getparams(
2422
end
2523
# end
2624

27-
# Required for using the slice samplers as `Experimental.Gibbs` samplers in Turing
25+
# Required for using the slice samplers as `Gibbs` samplers in Turing
2826
# begin
27+
Turing.Inference.isgibbscomponent(::SliceSampling.RandPermGibbs) = true
28+
Turing.Inference.isgibbscomponent(::SliceSampling.HitAndRun) = true
29+
Turing.Inference.isgibbscomponent(::SliceSampling.Slice) = true
30+
Turing.Inference.isgibbscomponent(::SliceSampling.SliceSteppingOut) = true
31+
Turing.Inference.isgibbscomponent(::SliceSampling.SliceDoublingOut) = true
32+
2933
function Turing.Inference.getparams(
30-
::Turing.DynamicPPL.Model, state::SliceSampling.UnivariateSliceState
34+
::Turing.DynamicPPL.Model, sample::SliceSampling.UnivariateSliceState
3135
)
32-
return state.transition.params
36+
return sample.transition.params
3337
end
3438

3539
function Turing.Inference.getparams(
@@ -43,18 +47,6 @@ function Turing.Inference.getparams(
4347
)
4448
return state.transition.params
4549
end
46-
47-
function Turing.Experimental.gibbs_requires_recompute_logprob(
48-
model_dst,
49-
::Turing.DynamicPPL.Sampler{
50-
<:Turing.Inference.ExternalSampler{<:SliceSampling.AbstractSliceSampling,A,U}
51-
},
52-
sampler_src,
53-
state_dst,
54-
state_src,
55-
) where {A,U}
56-
return false
57-
end
5850
# end
5951

6052
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)

src/multivariate/hitandrun.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ struct HitAndRunState{T<:Transition}
1717
transition::T
1818
end
1919

20+
function AbstractMCMC.setparams!!(
21+
model::AbstractMCMC.LogDensityModel,
22+
state::HitAndRunState,
23+
params
24+
)
25+
lp = LogDensityProblems.logdensity(model.logdensity, params)
26+
return HitAndRunState(Transition(params, lp, NamedTuple()))
27+
end
28+
2029
struct HitAndRunTarget{Model,Vec<:AbstractVector}
2130
model :: Model
2231
direction :: Vec

src/multivariate/randpermgibbs.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ struct GibbsState{T<:Transition}
2525
transition::T
2626
end
2727

28+
function AbstractMCMC.setparams!!(
29+
model::AbstractMCMC.LogDensityModel,
30+
state::GibbsState,
31+
params
32+
)
33+
lp = LogDensityProblems.logdensity(model.logdensity, params)
34+
return GibbsState(Transition(params, lp, NamedTuple()))
35+
end
36+
2837
struct GibbsTarget{Model,Idx<:Integer,Vec<:AbstractVector}
2938
model :: Model
3039
idx :: Idx

src/univariate/univariate.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ struct UnivariateSliceState{T<:Transition}
2929
transition::T
3030
end
3131

32+
function AbstractMCMC.setparams!!(
33+
model::AbstractMCMC.LogDensityModel,
34+
state::UnivariateSliceState,
35+
params
36+
)
37+
lp = LogDensityProblems.logdensity(model.logdensity, params)
38+
return UnivariateSliceState(Transition(params, lp, NamedTuple()))
39+
end
40+
3241
function AbstractMCMC.step(
3342
rng::Random.AbstractRNG,
3443
model::AbstractMCMC.LogDensityModel,

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using StableRNGs
1111

1212
using SliceSampling
1313

14-
include("univariate.jl")
15-
include("multivariate.jl")
16-
include("maxprops.jl")
14+
#include("univariate.jl")
15+
#include("multivariate.jl")
16+
#include("maxprops.jl")
1717
include("turing.jl")

0 commit comments

Comments
 (0)