Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
/docs/Manifest.toml
/docs/build/
**/*~
Maniest.toml
test/Manifest.toml
Manifest.toml
test/Manifest.toml
20 changes: 4 additions & 16 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,32 +1,20 @@
name = "SliceSampling"
uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf"
version = "0.7.10"
version = "0.7.11"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[extensions]
SliceSamplingTuringExt = ["Turing"]

[compat]
AbstractMCMC = "4, 5"
AbstractMCMC = "5.9"
Accessors = "0.1"
Distributions = "0.25"
LinearAlgebra = "1"
LogDensityProblems = "2"
Random = "1"
Turing = "0.41"
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[targets]
test = ["Test"]
4 changes: 2 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AbstractMCMC = "5"
AbstractMCMC = "5.9"
Distributions = "0.25"
Documenter = "1"
FillArrays = "1"
Expand All @@ -28,5 +28,5 @@ Random = "1"
SliceSampling = "0.7.1"
StableRNGs = "1"
Statistics = "1"
Turing = "0.41, 0.42"
Turing = "0.42"
julia = "1.10"
4 changes: 2 additions & 2 deletions docs/src/general.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ sample([rng,] model, slice, N; initial_params)
- `model`: A model implementing the `LogDensityProblems` interface.
- `N`: The number of samples

The output is a `SliceSampling.Transition` object, which contains the following:
The output is a vector of `SliceSampling.Transition`s, which contains the following:
```@docs
SliceSampling.Transition
```

For the keyword arguments, `SliceSampling` allows:
- `initial_params`: The intial state of the Markov chain (default: `nothing`).
- `initial_params`: The initial state of the Markov chain (default: `nothing`).

If `initial_params` is `nothing`, the following function can be implemented to provide an initialization:
```@docs
Expand Down
68 changes: 0 additions & 68 deletions ext/SliceSamplingTuringExt.jl

This file was deleted.

13 changes: 13 additions & 0 deletions src/SliceSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
module SliceSampling

using AbstractMCMC
using Accessors: Accessors
using Distributions
using LinearAlgebra
using LogDensityProblems
Expand Down Expand Up @@ -37,6 +38,18 @@ struct Transition{P,L<:Real,I<:NamedTuple}
info::I
end

# Base type for MCMC states that contain a `Transition` stored in the `transition` field.
abstract type AbstractStateWithTransition end
AbstractMCMC.getparams(state::AbstractStateWithTransition) = state.transition.params
AbstractMCMC.getstats(state::AbstractStateWithTransition) = state.transition.info
function AbstractMCMC.setparams!!(
model::AbstractMCMC.LogDensityModel, state::AbstractStateWithTransition, params
)
new_lp = LogDensityProblems.logdensity(model.logdensity, params)
new_transition = Transition(params, new_lp, NamedTuple())
return Accessors.@set state.transition = new_transition
end
Comment on lines +42 to +51
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the definitions of these functions are the same for all states in this package, I thought it would be cleaner to just define the behaviour on an abstract type. It does necessitate an extra dep on Accessors, but that's fairly lightweight.


"""
initial_sample(rng, model)

Expand Down
3 changes: 2 additions & 1 deletion src/multivariate/gibbspolar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ function GibbsPolarSlice(w::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
return GibbsPolarSlice(w, max_proposals)
end

struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector}
struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector} <:
AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T

Expand Down
9 changes: 1 addition & 8 deletions src/multivariate/hitandrun.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,11 @@ struct HitAndRun{S<:AbstractUnivariateSliceSampling} <: AbstractMultivariateSlic
unislice::S
end

struct HitAndRunState{T<:Transition}
struct HitAndRunState{T<:Transition} <: AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T
end

function AbstractMCMC.setparams!!(
model::AbstractMCMC.LogDensityModel, state::HitAndRunState, params
)
lp = LogDensityProblems.logdensity(model.logdensity, params)
return HitAndRunState(Transition(params, lp, NamedTuple()))
end

struct HitAndRunTarget{Model,Vec<:AbstractVector}
model :: Model
direction :: Vec
Expand Down
2 changes: 1 addition & 1 deletion src/multivariate/latent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function LatentSlice(beta::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
return LatentSlice(beta, max_proposals)
end

struct LatentSliceState{T<:Transition,S<:AbstractVector}
struct LatentSliceState{T<:Transition,S<:AbstractVector} <: AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T

Expand Down
7 changes: 1 addition & 6 deletions src/multivariate/randpermgibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,11 @@ struct RandPermGibbs{
unislice::S
end

struct GibbsState{T<:Transition}
struct GibbsState{T<:Transition} <: AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T
end

function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, ::GibbsState, params)
lp = LogDensityProblems.logdensity(model.logdensity, params)
return GibbsState(Transition(params, lp, NamedTuple()))
end

struct GibbsTarget{Model,Idx<:Integer,Vec<:AbstractVector}
model :: Model
idx :: Idx
Expand Down
9 changes: 1 addition & 8 deletions src/univariate/univariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,11 @@ function slice_sampling_univariate(
return exceeded_max_prop(max_prop)
end

struct UnivariateSliceState{T<:Transition}
struct UnivariateSliceState{T<:Transition} <: AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T
end

function AbstractMCMC.setparams!!(
model::AbstractMCMC.LogDensityModel, state::UnivariateSliceState, params
)
lp = LogDensityProblems.logdensity(model.logdensity, params)
return UnivariateSliceState(Transition(params, lp, NamedTuple()))
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
Expand Down
6 changes: 4 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCTesting = "9963b6a1-5d46-439c-8efc-3a487843c7fa"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -10,13 +11,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AbstractMCMC = "4, 5"
AbstractMCMC = "5.9"
Accessors = "0.1"
Distributions = "0.25"
DynamicPPL = "0.39.12"
LogDensityProblems = "2"
MCMCTesting = "0.3"
Random = "1"
StableRNGs = "1"
Test = "1"
Turing = "0.41, 0.42"
Turing = "0.42"
julia = "1.10"
23 changes: 2 additions & 21 deletions test/turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,11 @@
return nothing
end

@model function illbehavedmodel()
@addlogprob! -Inf
return nothing
end

@model function logp_check()
a ~ Normal()
return b ~ Normal()
end

rng = Random.default_rng()
@test begin
init = SliceSampling.initial_sample(rng, LogDensityFunction(demo()))
all(isfinite.(init))
end

@test_warn "Warning: Failed" SliceSampling.initial_sample(
rng, LogDensityFunction(illbehavedmodel())
)

@test_warn "Error: Failed" SliceSampling.initial_sample(
rng, LogDensityFunction(illbehavedmodel())
)

n_samples = 1000
model = demo()

Expand Down Expand Up @@ -61,7 +42,7 @@
@test isapprox(
logpdf.(Normal(), chain_logp_check[:a]) .+
logpdf.(Normal(), chain_logp_check[:b]),
chain_logp_check[:lp],
chain_logp_check[:logjoint],
)
end

Expand Down Expand Up @@ -89,7 +70,7 @@
@test isapprox(
logpdf.(Normal(), chain_logp_check[:a]) .+
logpdf.(Normal(), chain_logp_check[:b]),
chain_logp_check[:lp],
chain_logp_check[:logjoint],
)
end
end
Loading