Skip to content

Commit 5242b79

Browse files
committed
apply formatter
1 parent 7b261ae commit 5242b79

16 files changed

+394
-421
lines changed

docs/make.jl

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
using SliceSampling
22
using Documenter
33

4-
DocMeta.setdocmeta!(SliceSampling, :DocTestSetup, :(using SliceSampling); recursive=true)
4+
DocMeta.setdocmeta!(SliceSampling, :DocTestSetup, :(using SliceSampling); recursive = true)
55

66
makedocs(;
7-
modules=[SliceSampling],
8-
authors="Kyurae Kim <[email protected]> and contributors",
9-
repo="https://github.com/TuringLang/SliceSampling.jl/blob/{commit}{path}#{line}",
10-
sitename="SliceSampling.jl",
11-
format=Documenter.HTML(;
12-
prettyurls=get(ENV, "CI", "false") == "true",
13-
canonical="https://TuringLang.org/SliceSampling.jl",
14-
edit_link="main",
15-
assets=String[],
7+
modules = [SliceSampling],
8+
authors = "Kyurae Kim <[email protected]> and contributors",
9+
repo = "https://github.com/TuringLang/SliceSampling.jl/blob/{commit}{path}#{line}",
10+
sitename = "SliceSampling.jl",
11+
format = Documenter.HTML(;
12+
prettyurls = get(ENV, "CI", "false") == "true",
13+
canonical = "https://TuringLang.org/SliceSampling.jl",
14+
edit_link = "main",
15+
assets = String[],
1616
),
17-
pages=[
18-
"Home" => "index.md",
19-
"General Usage" => "general.md",
20-
"Univariate Slice Sampling" => "univariate_slice.md",
21-
"Meta Multivariate Samplers" => "meta_multivariate.md",
22-
"Latent Slice Sampling" => "latent_slice.md",
23-
"Gibbsian Polar Slice Sampling" => "gibbs_polar.md"
17+
pages = [
18+
"Home" => "index.md",
19+
"General Usage" => "general.md",
20+
"Univariate Slice Sampling" => "univariate_slice.md",
21+
"Meta Multivariate Samplers" => "meta_multivariate.md",
22+
"Latent Slice Sampling" => "latent_slice.md",
23+
"Gibbsian Polar Slice Sampling" => "gibbs_polar.md",
2424
],
2525
)
2626

27-
deploydocs(;
28-
repo="github.com/TuringLang/SliceSampling.jl",
29-
push_preview=true
30-
)
27+
deploydocs(; repo = "github.com/TuringLang/SliceSampling.jl", push_preview = true)

ext/SliceSamplingTuringExt.jl

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ if isdefined(Base, :get_extension)
66
using Random
77
using SliceSampling
88
using Turing
9-
# using Turing: Turing, Experimental
9+
# using Turing: Turing, Experimental
1010
else
1111
using ..LogDensityProblemsAD
1212
using ..Random
@@ -17,60 +17,49 @@ end
1717

1818
# Required for using the slice samplers as `externalsampler`s in Turing
1919
# begin
20-
Turing.Inference.getparams(
21-
::Turing.DynamicPPL.Model,
22-
sample::SliceSampling.Transition
23-
) = sample.params
20+
Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSampling.Transition) =
21+
sample.params
2422
# end
2523

2624
# Required for using the slice samplers as `Experimental.Gibbs` samplers in Turing
2725
# begin
2826
Turing.Inference.getparams(
29-
::Turing.DynamicPPL.Model,
30-
state::SliceSampling.UnivariateSliceState
27+
::Turing.DynamicPPL.Model,
28+
state::SliceSampling.UnivariateSliceState,
3129
) = state.transition.params
3230

33-
Turing.Inference.getparams(
34-
::Turing.DynamicPPL.Model,
35-
state::SliceSampling.GibbsState
36-
) = state.transition.params
31+
Turing.Inference.getparams(::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState) =
32+
state.transition.params
3733

38-
Turing.Inference.getparams(
39-
::Turing.DynamicPPL.Model,
40-
state::SliceSampling.HitAndRunState
41-
) = state.transition.params
34+
Turing.Inference.getparams(::Turing.DynamicPPL.Model, state::SliceSampling.HitAndRunState) =
35+
state.transition.params
4236

4337
Turing.Experimental.gibbs_requires_recompute_logprob(
4438
model_dst,
4539
::Turing.DynamicPPL.Sampler{
46-
<: Turing.Inference.ExternalSampler{
47-
<: SliceSampling.AbstractSliceSampling, A, U
48-
}
40+
<:Turing.Inference.ExternalSampler{<:SliceSampling.AbstractSliceSampling,A,U},
4941
},
5042
sampler_src,
5143
state_dst,
52-
state_src
44+
state_src,
5345
) where {A,U} = false
5446
# end
5547

56-
function SliceSampling.initial_sample(
57-
rng::Random.AbstractRNG,
58-
::Turing.LogDensityFunction
59-
)
48+
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
6049
model =.model
61-
spl = Turing.SampleFromUniform()
62-
vi = Turing.VarInfo(rng, model, spl)
63-
θ = vi[spl]
50+
spl = Turing.SampleFromUniform()
51+
vi = Turing.VarInfo(rng, model, spl)
52+
θ = vi[spl]
6453

6554
init_attempt_count = 1
6655
while !isfinite(θ)
6756
if init_attempt_count == 10
6857
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
6958
end
70-
59+
7160
# NOTE: This will sample in the unconstrained space.
7261
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
73-
θ = vi[spl]
62+
θ = vi[spl]
7463

7564
init_attempt_count += 1
7665
end

src/SliceSampling.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Struct containing the results of the transition.
3232
- `lp::Real`: Log-target density of the samples.
3333
- `info::NamedTuple`: Named tuple containing information about the transition.
3434
"""
35-
struct Transition{P, L <: Real, I <: NamedTuple}
35+
struct Transition{P,L<:Real,I<:NamedTuple}
3636
"current state of the slice sampling chain"
3737
params::P
3838

@@ -55,34 +55,33 @@ Return the initial sample for the `model` using the random number generator `rng
5555
function initial_sample(::Random.AbstractRNG, ::Any)
5656
error(
5757
"`initial_sample` is not implemented but an initialization wasn't provided. ",
58-
"Consider supplying an initialization to `initial_params`."
58+
"Consider supplying an initialization to `initial_params`.",
5959
)
6060
end
6161

6262
# If target is from `LogDensityProblemsAD`, unwrap target before calling `initial_sample`.
6363
# This is necessary since Turing wraps `DynamicPPL.Model`s when passed to an `externalsampler`.
64-
initial_sample(
65-
rng::Random.AbstractRNG,
66-
wrap::LogDensityProblemsAD.ADGradientWrapper
67-
) = initial_sample(rng, parent(wrap))
64+
initial_sample(rng::Random.AbstractRNG, wrap::LogDensityProblemsAD.ADGradientWrapper) =
65+
initial_sample(rng, parent(wrap))
6866

6967
function exceeded_max_prop(max_prop::Int)
70-
error("Exceeded maximum number of proposal $(max_prop), ",
71-
"which indicates an acceptance rate less than $(1/max_prop*100)%. ",
72-
"A quick fix is to increase `max_prop`, ",
73-
"but an acceptance rate that is too low often indicates that there is a problem. ",
74-
"Here are some possible causes:\n",
75-
"- The model might be broken or degenerate (most likely cause).\n",
76-
"- The tunable parameters of the sampler are suboptimal.\n",
77-
"- The initialization is pathologic. (try supplying a (different) `initial_params`)\n",
78-
"- There might be a bug in the sampler. (if this is suspected, file an issue to `SliceSampling`)\n"
79-
)
68+
error(
69+
"Exceeded maximum number of proposal $(max_prop), ",
70+
"which indicates an acceptance rate less than $(1/max_prop*100)%. ",
71+
"A quick fix is to increase `max_prop`, ",
72+
"but an acceptance rate that is too low often indicates that there is a problem. ",
73+
"Here are some possible causes:\n",
74+
"- The model might be broken or degenerate (most likely cause).\n",
75+
"- The tunable parameters of the sampler are suboptimal.\n",
76+
"- The initialization is pathologic. (try supplying a (different) `initial_params`)\n",
77+
"- There might be a bug in the sampler. (if this is suspected, file an issue to `SliceSampling`)\n",
78+
)
8079
end
8180

8281
## Univariate Slice Sampling Algorithms
8382
export Slice, SliceSteppingOut, SliceDoublingOut
8483

85-
abstract type AbstractUnivariateSliceSampling <: AbstractSliceSampling end
84+
abstract type AbstractUnivariateSliceSampling <: AbstractSliceSampling end
8685

8786
accept_slice_proposal(
8887
::AbstractSliceSampling,
@@ -103,7 +102,7 @@ include("univariate/steppingout.jl")
103102
include("univariate/doublingout.jl")
104103

105104
## Multivariate slice sampling algorithms
106-
abstract type AbstractMultivariateSliceSampling <: AbstractSliceSampling end
105+
abstract type AbstractMultivariateSliceSampling <: AbstractSliceSampling end
107106

108107
# Meta Multivariate Samplers
109108
export RandPermGibbs, HitAndRun
@@ -128,7 +127,7 @@ end
128127
@static if !isdefined(Base, :get_extension)
129128
function __init__()
130129
@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" include(
131-
"../ext/SliceSamplingTuringExt.jl"
130+
"../ext/SliceSamplingTuringExt.jl",
132131
)
133132
end
134133
end

0 commit comments

Comments
 (0)