Skip to content

Commit e561651

Browse files
committed
fix apply custom formatter
1 parent d14df00 commit e561651

15 files changed

+156
-186
lines changed

docs/make.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
using SliceSampling
22
using Documenter
33

4-
DocMeta.setdocmeta!(SliceSampling, :DocTestSetup, :(using SliceSampling); recursive = true)
4+
DocMeta.setdocmeta!(SliceSampling, :DocTestSetup, :(using SliceSampling); recursive=true)
55

66
makedocs(;
7-
modules = [SliceSampling],
8-
authors = "Kyurae Kim <[email protected]> and contributors",
9-
repo = "https://github.com/TuringLang/SliceSampling.jl/blob/{commit}{path}#{line}",
10-
sitename = "SliceSampling.jl",
11-
format = Documenter.HTML(;
12-
prettyurls = get(ENV, "CI", "false") == "true",
13-
canonical = "https://TuringLang.org/SliceSampling.jl",
14-
edit_link = "main",
15-
assets = String[],
7+
modules=[SliceSampling],
8+
authors="Kyurae Kim <[email protected]> and contributors",
9+
repo="https://github.com/TuringLang/SliceSampling.jl/blob/{commit}{path}#{line}",
10+
sitename="SliceSampling.jl",
11+
format=Documenter.HTML(;
12+
prettyurls=get(ENV, "CI", "false") == "true",
13+
canonical="https://TuringLang.org/SliceSampling.jl",
14+
edit_link="main",
15+
assets=String[],
1616
),
17-
pages = [
17+
pages=[
1818
"Home" => "index.md",
1919
"General Usage" => "general.md",
2020
"Univariate Slice Sampling" => "univariate_slice.md",
@@ -24,4 +24,4 @@ makedocs(;
2424
],
2525
)
2626

27-
deploydocs(; repo = "github.com/TuringLang/SliceSampling.jl", push_preview = true)
27+
deploydocs(; repo="github.com/TuringLang/SliceSampling.jl", push_preview=true)

ext/SliceSamplingTuringExt.jl

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,44 @@ end
1717

1818
# Required for using the slice samplers as `externalsampler`s in Turing
1919
# begin
20-
Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSampling.Transition) =
21-
sample.params
20+
function Turing.Inference.getparams(
21+
::Turing.DynamicPPL.Model, sample::SliceSampling.Transition
22+
)
23+
return sample.params
24+
end
2225
# end
2326

2427
# Required for using the slice samplers as `Experimental.Gibbs` samplers in Turing
2528
# begin
26-
Turing.Inference.getparams(
27-
::Turing.DynamicPPL.Model,
28-
state::SliceSampling.UnivariateSliceState,
29-
) = state.transition.params
29+
function Turing.Inference.getparams(
30+
::Turing.DynamicPPL.Model, state::SliceSampling.UnivariateSliceState
31+
)
32+
return state.transition.params
33+
end
3034

31-
Turing.Inference.getparams(::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState) =
32-
state.transition.params
35+
function Turing.Inference.getparams(
36+
::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState
37+
)
38+
return state.transition.params
39+
end
3340

34-
Turing.Inference.getparams(::Turing.DynamicPPL.Model, state::SliceSampling.HitAndRunState) =
35-
state.transition.params
41+
function Turing.Inference.getparams(
42+
::Turing.DynamicPPL.Model, state::SliceSampling.HitAndRunState
43+
)
44+
return state.transition.params
45+
end
3646

37-
Turing.Experimental.gibbs_requires_recompute_logprob(
47+
function Turing.Experimental.gibbs_requires_recompute_logprob(
3848
model_dst,
3949
::Turing.DynamicPPL.Sampler{
40-
<:Turing.Inference.ExternalSampler{<:SliceSampling.AbstractSliceSampling,A,U},
50+
<:Turing.Inference.ExternalSampler{<:SliceSampling.AbstractSliceSampling,A,U}
4151
},
4252
sampler_src,
4353
state_dst,
4454
state_src,
45-
) where {A,U} = false
55+
) where {A,U}
56+
return false
57+
end
4658
# end
4759

4860
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
@@ -63,7 +75,7 @@ function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDe
6375

6476
init_attempt_count += 1
6577
end
66-
θ
78+
return θ
6779
end
6880

6981
end

src/SliceSampling.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,22 @@ Return the initial sample for the `model` using the random number generator `rng
5353
- `model`: The target `LogDensityProblem`.
5454
"""
5555
function initial_sample(::Random.AbstractRNG, ::Any)
56-
error(
56+
return error(
5757
"`initial_sample` is not implemented but an initialization wasn't provided. ",
5858
"Consider supplying an initialization to `initial_params`.",
5959
)
6060
end
6161

6262
# If target is from `LogDensityProblemsAD`, unwrap target before calling `initial_sample`.
6363
# This is necessary since Turing wraps `DynamicPPL.Model`s when passed to an `externalsampler`.
64-
initial_sample(rng::Random.AbstractRNG, wrap::LogDensityProblemsAD.ADGradientWrapper) =
65-
initial_sample(rng, parent(wrap))
64+
function initial_sample(
65+
rng::Random.AbstractRNG, wrap::LogDensityProblemsAD.ADGradientWrapper
66+
)
67+
return initial_sample(rng, parent(wrap))
68+
end
6669

6770
function exceeded_max_prop(max_prop::Int)
68-
error(
71+
return error(
6972
"Exceeded maximum number of proposal $(max_prop), ",
7073
"which indicates an acceptance rate less than $(1/max_prop*100)%. ",
7174
"A quick fix is to increase `max_prop`, ",
@@ -83,16 +86,11 @@ export Slice, SliceSteppingOut, SliceDoublingOut
8386

8487
abstract type AbstractUnivariateSliceSampling <: AbstractSliceSampling end
8588

86-
accept_slice_proposal(
87-
::AbstractSliceSampling,
88-
::Any,
89-
::Real,
90-
::Real,
91-
::Real,
92-
::Real,
93-
::Real,
94-
::Real,
95-
) = true
89+
function accept_slice_proposal(
90+
::AbstractSliceSampling, ::Any, ::Real, ::Real, ::Real, ::Real, ::Real, ::Real
91+
)
92+
return true
93+
end
9694

9795
function find_interval end
9896

@@ -127,7 +125,7 @@ end
127125
@static if !isdefined(Base, :get_extension)
128126
function __init__()
129127
@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" include(
130-
"../ext/SliceSamplingTuringExt.jl",
128+
"../ext/SliceSamplingTuringExt.jl"
131129
)
132130
end
133131
end

src/multivariate/gibbspolar.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ struct GibbsPolarSlice{W<:Real} <: AbstractMultivariateSliceSampling
2828
max_proposals::Int
2929
end
3030

31-
GibbsPolarSlice(w::Real; max_proposals::Int = DEFAULT_MAX_PROPOSALS) =
32-
GibbsPolarSlice(w, max_proposals)
31+
function GibbsPolarSlice(w::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
32+
return GibbsPolarSlice(w, max_proposals)
33+
end
3334

3435
struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector}
3536
"Current [`Transition`](@ref)."
@@ -48,14 +49,14 @@ end
4849

4950
function logdensity(target::GibbsPolarSliceTarget, x)
5051
d = length(x)
51-
(d - 1) * log(norm(x)) + LogDensityProblems.logdensity(target.model, x)
52+
return (d - 1) * log(norm(x)) + LogDensityProblems.logdensity(target.model, x)
5253
end
5354

5455
function AbstractMCMC.step(
5556
rng::Random.AbstractRNG,
5657
model::AbstractMCMC.LogDensityModel,
5758
sampler::GibbsPolarSlice;
58-
initial_params = nothing,
59+
initial_params=nothing,
5960
kwargs...,
6061
)
6162
logdensitymodel = model.logdensity
@@ -76,7 +77,7 @@ function rand_subsphere(rng::Random.AbstractRNG, θ::AbstractVector)
7677
d = length(θ)
7778
V1 = randn(rng, eltype(θ), d)
7879
V2 = V1 - dot(θ, V1) * θ
79-
V2 / max(norm(V2), eps(eltype(θ)))
80+
return V2 / max(norm(V2), eps(eltype(θ)))
8081
end
8182

8283
function geodesic_shrinkage(
@@ -91,7 +92,7 @@ function geodesic_shrinkage(
9192
ω_max = convert(F, 2π) * rand(rng, F)
9293
ω_min = ω_max - convert(F, 2π)
9394

94-
for n_props = 1:max_prop
95+
for n_props in 1:max_prop
9596
# `Uniform` had a type instability issue:
9697
# https://github.com/JuliaStats/Distributions.jl/pull/1860
9798
# ω = rand(rng, Uniform(ω_min, ω_max))
@@ -108,7 +109,7 @@ function geodesic_shrinkage(
108109
ω_max = ω
109110
end
110111
end
111-
exceeded_max_prop(max_prop)
112+
return exceeded_max_prop(max_prop)
112113
end
113114

114115
function radius_shrinkage(
@@ -148,7 +149,7 @@ function radius_shrinkage(
148149
end
149150
n_props_total += n_props
150151

151-
for n_props = 1:max_prop
152+
for n_props in 1:max_prop
152153
# `Uniform` had a type instability issue:
153154
# https://github.com/JuliaStats/Distributions.jl/pull/1860
154155
#r′ = rand(rng, Uniform{F}(r_min, r_max))
@@ -165,7 +166,7 @@ function radius_shrinkage(
165166
r_max = r′
166167
end
167168
end
168-
exceeded_max_prop(max_prop)
169+
return exceeded_max_prop(max_prop)
169170
end
170171

171172
function AbstractMCMC.step(
@@ -194,9 +195,7 @@ function AbstractMCMC.step(
194195

195196
ℓp = LogDensityProblems.logdensity(logdensitymodel, x)
196197
t = Transition(
197-
x,
198-
ℓp,
199-
(num_radius_proposals = n_props_r, num_direction_proposals = n_props_θ),
198+
x, ℓp, (num_radius_proposals=n_props_r, num_direction_proposals=n_props_θ)
200199
)
201-
t, GibbsPolarSliceState(t, θ, r)
200+
return t, GibbsPolarSliceState(t, θ, r)
202201
end

src/multivariate/hitandrun.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ end
2525

2626
function LogDensityProblems.logdensity(target::HitAndRunTarget, λ)
2727
(; model, reference, direction) = target
28-
LogDensityProblems.logdensity(model, reference + λ * direction)
28+
return LogDensityProblems.logdensity(model, reference + λ * direction)
2929
end
3030

3131
function AbstractMCMC.step(
3232
rng::Random.AbstractRNG,
3333
model::AbstractMCMC.LogDensityModel,
3434
sampler::HitAndRun;
35-
initial_params = nothing,
35+
initial_params=nothing,
3636
kwargs...,
3737
)
3838
logdensitymodel = model.logdensity
@@ -46,7 +46,7 @@ end
4646

4747
function rand_uniform_unit_sphere(rng::Random.AbstractRNG, type::Type, d::Int)
4848
x = randn(rng, type, d)
49-
x / norm(x)
49+
return x / norm(x)
5050
end
5151

5252
function AbstractMCMC.step(
@@ -67,6 +67,6 @@ function AbstractMCMC.step(
6767
λ = zero(eltype(θ))
6868
λ, ℓp, props = slice_sampling_univariate(rng, unislice, hnrtarget, ℓp, λ)
6969
θ′ = θ + direction * λ
70-
t = Transition(θ′, ℓp, (num_proposals = props,))
71-
t, HitAndRunState(t)
70+
t = Transition(θ′, ℓp, (num_proposals=props,))
71+
return t, HitAndRunState(t)
7272
end

src/multivariate/latent.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ struct LatentSlice{B<:Real} <: AbstractMultivariateSliceSampling
1515
max_proposals::Int
1616
end
1717

18-
function LatentSlice(beta::Real; max_proposals::Int = DEFAULT_MAX_PROPOSALS)
18+
function LatentSlice(beta::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
1919
@assert beta > 0 "Beta must be strictly positive"
20-
LatentSlice(beta, max_proposals)
20+
return LatentSlice(beta, max_proposals)
2121
end
2222

2323
struct LatentSliceState{T<:Transition,S<:AbstractVector}
@@ -32,7 +32,7 @@ function AbstractMCMC.step(
3232
rng::Random.AbstractRNG,
3333
model::AbstractMCMC.LogDensityModel,
3434
sampler::LatentSlice;
35-
initial_params = nothing,
35+
initial_params=nothing,
3636
kwargs...,
3737
)
3838
logdensitymodel = model.logdensity
@@ -85,7 +85,7 @@ function AbstractMCMC.step(
8585
exceeded_max_prop(max_proposals)
8686
end
8787

88-
@inbounds for i = 1:d
88+
@inbounds for i in 1:d
8989
if ystar[i] < y[i]
9090
a[i] = ystar[i]
9191
else
@@ -94,6 +94,6 @@ function AbstractMCMC.step(
9494
end
9595
end
9696
s = β * randexp(rng, eltype(y), d) + 2 * abs.(l - y)
97-
t = Transition(y, ℓp, (num_proposals = props,))
98-
t, LatentSliceState(t, s)
97+
t = Transition(y, ℓp, (num_proposals=props,))
98+
return t, LatentSliceState(t, s)
9999
end

src/multivariate/randpermgibbs.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@ end
3333

3434
function LogDensityProblems.logdensity(gibbs::GibbsTarget, θi)
3535
(; model, idx, θ) = gibbs
36-
LogDensityProblems.logdensity(model, (@set θ[idx] = θi))
36+
return LogDensityProblems.logdensity(model, (@set θ[idx] = θi))
3737
end
3838

3939
function AbstractMCMC.step(
4040
rng::Random.AbstractRNG,
4141
model::AbstractMCMC.LogDensityModel,
4242
sampler::RandPermGibbs;
43-
initial_params = nothing,
43+
initial_params=nothing,
4444
kwargs...,
4545
)
4646
logdensitymodel = model.logdensity
@@ -75,11 +75,12 @@ function AbstractMCMC.step(
7575
for i in shuffle(rng, 1:d)
7676
model_gibbs = GibbsTarget(logdensitymodel, i, θ)
7777
unislice = unislices[i]
78-
θ′_coord, ℓp, props_coord =
79-
slice_sampling_univariate(rng, unislice, model_gibbs, ℓp, θ[i])
78+
θ′_coord, ℓp, props_coord = slice_sampling_univariate(
79+
rng, unislice, model_gibbs, ℓp, θ[i]
80+
)
8081
props[i] = props_coord
8182
θ[i] = θ′_coord
8283
end
83-
t = Transition(θ, ℓp, (num_proposals = props,))
84-
t, GibbsState(t)
84+
t = Transition(θ, ℓp, (num_proposals=props,))
85+
return t, GibbsState(t)
8586
end

0 commit comments

Comments
 (0)