Skip to content

Commit b02ac86

Browse files
committed
add tunable parameter for StereographicSlice
1 parent 9edc5ff commit b02ac86

File tree

3 files changed

+34
-25
lines changed

3 files changed

+34
-25
lines changed

src/multivariate/stereographic.jl

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11

22
"""
3-
StereographicSlice(; max_proposals)
3+
StereographicSlice(invscale; max_proposals)
44
55
Stereographic slice sampling algorithm by Bell, Latuszynski, and Roberts[^BLR2024].
66
77
# Keyword Arguments
8-
- `max_proposals::Int`: Maximum number of proposals allowed until throwing an error (default: `$(DEFAULT_MAX_PROPOSALS)`).
8+
- `invscale::Real`: Inverse scale of the stereographic projection.
9+
- `max_proposals::Int`: Maximum number of proposals allowed until throwing an error (default: `$(100*DEFAULT_MAX_PROPOSALS)`).
910
"""
10-
@kwdef struct StereographicSlice <: AbstractMultivariateSliceSampling
11-
max_proposals::Int = DEFAULT_MAX_PROPOSALS
11+
struct StereographicSlice{S<:Real} <: AbstractMultivariateSliceSampling
12+
invscale::S
13+
max_proposals::Int
14+
end
15+
16+
function StereographicSlice(invscale::Real; max_proposals::Int = 100*DEFAULT_MAX_PROPOSALS)
17+
@assert invscale > 0
18+
StereographicSlice{typeof(invscale)}(invscale, max_proposals)
1219
end
1320

1421
function AbstractMCMC.setparams!!(
@@ -29,17 +36,18 @@ function rand_uniform_sphere_orthogonal_subspace(
2936
return v_orth / norm(v_orth)
3037
end
3138

32-
function stereographic_projection(z::AbstractVector)
39+
function stereographic_projection(z::AbstractVector{T}, R::T) where {T<:Real}
3340
d = length(z) - 1
34-
return z[1:d] ./ (1 - z[d + 1])
41+
return R * z[1:d] ./ (1 - z[d + 1])
3542
end
3643

37-
function stereographic_inverse_projection(x::AbstractVector{T}) where {T<:Real}
44+
function stereographic_inverse_projection(x::AbstractVector{T}, R::T) where {T<:Real}
3845
d = length(x)
46+
R2 = R*R
3947
z = zeros(T, d + 1)
4048
x_norm2 = sum(abs2, x)
41-
z[1:d] = 2 * x / (x_norm2 + 1)
42-
z[d + 1] = (x_norm2 - 1) / (x_norm2 + 1)
49+
z[1:d] = 2 * R * x / (x_norm2 + R2)
50+
z[d + 1] = (x_norm2 - R2) / (x_norm2 + R2)
4351
return z
4452
end
4553

@@ -57,9 +65,9 @@ function AbstractMCMC.step(
5765
return t, t
5866
end
5967

60-
function logdensity_sphere(ℓπ::Real, x::AbstractVector)
68+
function logdensity_sphere(ℓπ::Real, x::AbstractVector{T}, R::T) where {T<:Real}
6169
d = length(x)
62-
return ℓπ + d * log(1 + sum(abs2, x))
70+
return ℓπ + d * log(R*R + sum(abs2, x))
6371
end
6472

6573
function AbstractMCMC.step(
@@ -74,22 +82,23 @@ function AbstractMCMC.step(
7482

7583
ℓp = state.lp
7684
x = state.params
77-
z = stereographic_inverse_projection(x)
85+
R = convert(eltype(x), sampler.invscale)
86+
z = stereographic_inverse_projection(x, R)
7887
v = rand_uniform_sphere_orthogonal_subspace(rng, z)
79-
ℓp_sphere = logdensity_sphere(ℓp, x)
88+
ℓp_sphere = logdensity_sphere(ℓp, x, R)
8089
ℓw = ℓp_sphere - Random.randexp(rng, eltype(x))
8190

82-
θ = convert(eltype(x), 2π) * rand(eltype(x), rng)
91+
θ = convert(eltype(x), 2π) * rand(rng, eltype(x))
8392
θ_max = θ
8493
θ_min = θ - convert(eltype(x), 2π)
8594

8695
props = 0
8796
while true
8897
props += 1
8998

90-
x_prop = stereographic_projection(z * cos(θ) + v * sin(θ))
99+
x_prop = stereographic_projection(z * cos(θ) + v * sin(θ), R)
91100
ℓp_prop = LogDensityProblems.logdensity(logdensitymodel, x_prop)
92-
ℓp_sphere_prop = logdensity_sphere(ℓp_prop, x_prop)
101+
ℓp_sphere_prop = logdensity_sphere(ℓp_prop, x_prop, R)
93102

94103
if ℓw < ℓp_sphere_prop
95104
ℓp = ℓp_prop
@@ -98,6 +107,7 @@ function AbstractMCMC.step(
98107
end
99108

100109
if props > max_proposals
110+
println(logdensitymodel)
101111
exceeded_max_prop(max_proposals)
102112
end
103113

@@ -106,8 +116,7 @@ function AbstractMCMC.step(
106116
else
107117
θ_max = θ
108118
end
109-
110-
θ = (θ_max - θ_min) * rand(rng)
119+
θ = (θ_max - θ_min) * rand(rng, eltype(x))
111120
end
112121
t = Transition(x, ℓp, (num_proposals=props,))
113122
return t, t

test/multivariate.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function LogDensityProblems.dimension(model::MultiModel)
5959
end
6060

6161
@testset "multivariate samplers" begin
62-
model = MultiModel(1.0, 1.0, [0.0])
62+
model = MultiModel(3.0, 3.0, [0.0])
6363
@testset for sampler in [
6464
# Vector-valued windows
6565
RandPermGibbs(Slice.(fill(1, LogDensityProblems.dimension(model)))),
@@ -77,10 +77,10 @@ end
7777
# Multivariate slice samplers
7878
LatentSlice(5),
7979
GibbsPolarSlice(100),
80-
StereographicSlice(),
80+
StereographicSlice(1),
8181
]
8282
@testset "initial_params" begin
83-
model = MultiModel(1.0, 1.0, [0.0])
83+
model = MultiModel(3.0, 3.0, [0.0])
8484
θ, y = MCMCTesting.sample_joint(Random.default_rng(), model)
8585
model′ = AbstractMCMC.LogDensityModel(@set model.y = y)
8686

@@ -91,7 +91,7 @@ end
9191

9292
@testset "initial_sample" begin
9393
rng = StableRNG(1)
94-
model = MultiModel(1.0, 1.0, [0.0])
94+
model = MultiModel(3.0, 3.0, [0.0])
9595
θ0 = SliceSampling.initial_sample(rng, model)
9696

9797
rng = StableRNG(1)
@@ -100,7 +100,7 @@ end
100100
end
101101

102102
@testset "determinism" begin
103-
model = MultiModel(1.0, 1.0, [0.0])
103+
model = MultiModel(3.0, 3.0, [0.0])
104104
θ, y = MCMCTesting.sample_joint(Random.default_rng(), model)
105105
model′ = AbstractMCMC.LogDensityModel(@set model.y = y)
106106

@@ -139,7 +139,7 @@ end
139139
n_mcmc_thin = 10
140140
test = ExactRankTest(n_samples, n_mcmc_steps, n_mcmc_thin)
141141

142-
model = MultiModel(1.0, 1.0, [0.0])
142+
model = MultiModel(3.0, 3.0, [0.0])
143143
subject = TestSubject(model, sampler)
144144
@test seqmcmctest(test, subject, 0.001, n_pvalue_samples; show_progress=false)
145145
end

test/turing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
HitAndRun(SliceDoublingOut(1)),
2121
LatentSlice(5),
2222
GibbsPolarSlice(5),
23-
StereographicSlice(),
23+
StereographicSlice(10),
2424
]
2525
chain = sample(
2626
model,

0 commit comments

Comments
 (0)