Skip to content

Commit 97f6e4a

Browse files
committed
add stereographic slice sampling
1 parent 233eeaa commit 97f6e4a

File tree

3 files changed

+119
-0
lines changed

3 files changed

+119
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1212
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
13+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1415
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1516

@@ -27,6 +28,7 @@ FillArrays = "1"
2728
LinearAlgebra = "1"
2829
LogDensityProblems = "2"
2930
LogDensityProblemsAD = "1"
31+
LogExpFunctions = "0.3"
3032
Random = "1"
3133
Requires = "1"
3234
Turing = "0.36"

src/SliceSampling.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Distributions
77
using FillArrays
88
using LinearAlgebra
99
using LogDensityProblems
10+
using LogExpFunctions
1011
using Random
1112

1213
# The following is necessary because Turing wraps all models with
@@ -116,6 +117,10 @@ include("multivariate/latent.jl")
116117
export GibbsPolarSlice
117118
include("multivariate/gibbspolar.jl")
118119

120+
# Stereographic Slice Sampling
121+
export StereographicSlice
122+
include("multivariate/stereographic.jl")
123+
119124
# Turing Compatibility
120125

121126
if !isdefined(Base, :get_extension)

src/multivariate/stereographic.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
2+
"""
3+
StereographicSlice(; max_proposals)
4+
5+
Stereographic slice sampling algorithm by Bell, Latuszynski, and Roberts[^BLR].
6+
7+
# Keyword Arguments
8+
- `max_proposals::Int`: Maximum number of proposals allowed until throwing an error (default: `$(DEFAULT_MAX_PROPOSALS)`).
9+
"""
10+
@kwdef struct StereographicSlice{RType <: Real} <: AbstractMultivariateSliceSampling
11+
max_proposals :: Int = DEFAULT_MAX_PROPOSALS
12+
end
13+
14+
struct StereographicSliceState{T<:Transition}
15+
"Current [`Transition`](@ref)."
16+
transition::T
17+
end
18+
19+
function rand_uniform_sphere_orthogonal_subspace(
20+
rng::Random.AbstractRNG, subspace_vector::AbstractVector
21+
)
22+
z = subspace_vector
23+
d = length(subspace_vector)
24+
v = randn(rng, d)
25+
v_proj = dot(z, v)/sum(abs2, z)*z
26+
v_orth = v - v_proj
27+
v_orth / norm(v_orth)
28+
end
29+
30+
function stereographic_projection(z::AbstractVector)
31+
d = length(z) - 1
32+
return z[1:d] ./ (1 - z[d+1])
33+
end
34+
35+
function stereographic_inverse_projection(x::AbstractVector)
36+
d = length(x)
37+
z = zeros(d + 1)
38+
x_norm2 = sum(abs2, x)
39+
z[1:d] = 2*x / (x_norm2 + 1)
40+
z[d+1] = (x_norm2 - 1)/(x_norm2 + 1)
41+
z
42+
end
43+
44+
function AbstractMCMC.step(
45+
rng::Random.AbstractRNG,
46+
model::AbstractMCMC.LogDensityModel,
47+
sampler::StereographicSlice;
48+
initial_params=nothing,
49+
kwargs...,
50+
)
51+
logdensitymodel = model.logdensity
52+
x = initial_params === nothing ? initial_sample(rng, logdensitymodel) : initial_params
53+
lp = LogDensityProblems.logdensity(logdensitymodel, x)
54+
t = Transition(x, lp, NamedTuple())
55+
return t, t
56+
end
57+
58+
function logdensity_sphere(ℓπ::Real, x::AbstractVector)
59+
d = length(x)
60+
return ℓπ + d*log(1 + sum(abs2, x))
61+
end
62+
63+
function AbstractMCMC.step(
64+
rng::Random.AbstractRNG,
65+
model::AbstractMCMC.LogDensityModel,
66+
sampler::StereographicSlice,
67+
state::Transition;
68+
kwargs...,
69+
)
70+
logdensitymodel = model.logdensity
71+
max_proposals = sampler.max_proposals
72+
73+
ℓp = state.lp
74+
x = state.params
75+
z = stereographic_inverse_projection(x)
76+
v = rand_uniform_sphere_orthogonal_subspace(rng, z)
77+
ℓp_sphere = logdensity_sphere(ℓp, x)
78+
ℓw = ℓp_sphere - Random.randexp(rng, eltype(x))
79+
80+
θ = rand(rng, Uniform(0, 2π))
81+
θ_max = θ
82+
θ_min = θ - 2π
83+
84+
props = 0
85+
while true
86+
props += 1
87+
88+
x_prop = stereographic_projection(z*cos(θ) + v*sin(θ))
89+
ℓp_prop = LogDensityProblems.logdensity(logdensitymodel, x_prop)
90+
ℓp_sphere_prop = logdensity_sphere(ℓp_prop, x_prop)
91+
92+
if ℓw < ℓp_sphere_prop
93+
ℓp = ℓp_prop
94+
x = x_prop
95+
break
96+
end
97+
98+
if props > max_proposals
99+
exceeded_max_prop(max_proposals)
100+
end
101+
102+
if θ < 0
103+
θ_min = θ
104+
else
105+
θ_max = θ
106+
end
107+
108+
θ = rand(rng, Uniform(θ_min, θ_max))
109+
end
110+
t = Transition(x, ℓp, (num_proposals=props,))
111+
return t, t
112+
end

0 commit comments

Comments
 (0)