Skip to content

Commit 5ac886d

Browse files
torfjeldeyebaigithub-actions[bot]CompatHelper Julia
authored
Removed deprecation of init_params + bump minor version (#355)
* removed deprecation of init_params and only allow the new initial_params * bump minor version since change is breaking * Fix some tests. (#356) * CompatHelper: add new compat entry for Statistics at version 1, (keep existing compat) (#354) Co-authored-by: CompatHelper Julia <[email protected]> * Update constructors.jl * Update constructors.jl * Update test/constructors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * Update abstractmcmc.jl --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: CompatHelper Julia <[email protected]> * Fix docs CI (#357) * Update make.jl * Update trajectory.jl * Update integrator.jl * Update metric.jl * Update constructors.jl * Update api.md * Update constructors.jl * Update Adaptation.jl * Update abstractmcmc.jl * Update Adaptation.jl * Update Adaptation.jl * Update api.md * Update constructors.jl * Update api.md * Update make.jl * Update docs/make.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update docs/make.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: CompatHelper Julia <[email protected]>
1 parent 362a053 commit 5ac886d

File tree

10 files changed

+78
-42
lines changed

10 files changed

+78
-42
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.5.6"
3+
version = "0.6.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -43,7 +43,7 @@ ProgressMeter = "1"
4343
Requires = "0.5, 1"
4444
Setfield = "0.7, 0.8, 1"
4545
SimpleUnPack = "1.1"
46-
Statistics = "1"
46+
Statistics = "1.6"
4747
StatsBase = "0.31, 0.32, 0.33, 0.34"
4848
StatsFuns = "0.8, 0.9, 1"
4949
julia = "1.6"

docs/make.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ using AdvancedHMC
55

66
# cp(joinpath(@__DIR__, "../README.md"), joinpath(@__DIR__, "src/index.md"))
77

8-
makedocs(sitename = "AdvancedHMC", format = Documenter.HTML(), modules = [AdvancedHMC])
8+
makedocs(
9+
sitename = "AdvancedHMC",
10+
format = Documenter.HTML(),
11+
warnonly = [:cross_references],
12+
)
913

1014
deploydocs(
1115
repo = "github.com/TuringLang/AdvancedHMC.jl.git",

docs/src/api.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Documentation for AdvancedHMC.jl
55
```@contents
66
```
77

8-
## Structs
8+
## Types
99
```@docs
1010
ClassicNoUTurn
1111
HMCSampler
@@ -18,4 +18,11 @@ HMCDA
1818

1919
```@docs
2020
sample
21-
```
21+
```
22+
23+
## More types
24+
25+
```@autodocs; canonical=false
26+
Modules = [AdvancedHMC, AdvancedHMC.Adaptation]
27+
Order = [:type]
28+
```

src/abstractmcmc.jl

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction
3838

3939
function AbstractMCMC.sample(
4040
rng::Random.AbstractRNG,
41-
model::LogDensityModel,
41+
model::AbstractMCMC.LogDensityModel,
4242
sampler::AbstractHMCSampler,
4343
N::Integer;
4444
n_adapts::Int = min(div(N, 10), 1_000),
@@ -67,7 +67,7 @@ end
6767

6868
function AbstractMCMC.sample(
6969
rng::Random.AbstractRNG,
70-
model::LogDensityModel,
70+
model::AbstractMCMC.LogDensityModel,
7171
sampler::AbstractHMCSampler,
7272
parallel::AbstractMCMC.AbstractMCMCEnsemble,
7373
N::Integer,
@@ -101,17 +101,11 @@ end
101101

102102
function AbstractMCMC.step(
103103
rng::AbstractRNG,
104-
model::LogDensityModel,
104+
model::AbstractMCMC.LogDensityModel,
105105
spl::AbstractHMCSampler;
106106
initial_params = nothing,
107-
init_params = initial_params,
108107
kwargs...,
109108
)
110-
if init_params !== initial_params
111-
Base.depwarn("`init_params` is deprecated, use `initial_params` instead", :step)
112-
initial_params = init_params
113-
end
114-
115109
# Unpack model
116110
logdensity = model.logdensity
117111

@@ -123,7 +117,7 @@ function AbstractMCMC.step(
123117

124118
# Define integration algorithm
125119
# Find good eps if not provided one
126-
initial_params = make_init_params(rng, spl, logdensity, initial_params)
120+
initial_params = make_initial_params(rng, spl, logdensity, initial_params)
127121
ϵ = make_step_size(rng, spl, hamiltonian, initial_params)
128122
integrator = make_integrator(spl, ϵ)
129123

@@ -144,7 +138,7 @@ end
144138

145139
function AbstractMCMC.step(
146140
rng::AbstractRNG,
147-
model::LogDensityModel,
141+
model::AbstractMCMC.LogDensityModel,
148142
spl::AbstractHMCSampler,
149143
state::HMCState;
150144
kwargs...,
@@ -257,18 +251,18 @@ end
257251
#############
258252
### Utils ###
259253
#############
260-
function make_init_params(
254+
function make_initial_params(
261255
rng::AbstractRNG,
262256
spl::AbstractHMCSampler,
263257
logdensity,
264-
init_params,
258+
initial_params,
265259
)
266260
T = sampler_eltype(spl)
267-
if init_params == nothing
261+
if initial_params == nothing
268262
d = LogDensityProblems.dimension(logdensity)
269-
init_params = randn(rng, d)
263+
initial_params = randn(rng, d)
270264
end
271-
return T.(init_params)
265+
return T.(initial_params)
272266
end
273267

274268
#########
@@ -277,21 +271,21 @@ function make_step_size(
277271
rng::Random.AbstractRNG,
278272
spl::HMCSampler,
279273
hamiltonian::Hamiltonian,
280-
init_params,
274+
initial_params,
281275
)
282276
T = typeof(spl.κ.τ.integrator.ϵ)
283-
ϵ = make_step_size(rng, spl.κ.τ.integrator, T, hamiltonian, init_params)
277+
ϵ = make_step_size(rng, spl.κ.τ.integrator, T, hamiltonian, initial_params)
284278
return ϵ
285279
end
286280

287281
function make_step_size(
288282
rng::Random.AbstractRNG,
289283
spl::AbstractHMCSampler,
290284
hamiltonian::Hamiltonian,
291-
init_params,
285+
initial_params,
292286
)
293287
T = sampler_eltype(spl)
294-
return make_step_size(rng, spl.integrator, T, hamiltonian, init_params)
288+
return make_step_size(rng, spl.integrator, T, hamiltonian, initial_params)
295289

296290
end
297291

@@ -300,12 +294,12 @@ function make_step_size(
300294
integrator::AbstractIntegrator,
301295
T::Type,
302296
hamiltonian::Hamiltonian,
303-
init_params,
297+
initial_params,
304298
)
305299
if integrator.ϵ > 0
306300
ϵ = integrator.ϵ
307301
else
308-
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
302+
ϵ = find_good_stepsize(rng, hamiltonian, initial_params)
309303
@info string("Found initial step size ", ϵ)
310304
end
311305
return T(ϵ)
@@ -316,9 +310,9 @@ function make_step_size(
316310
integrator::Symbol,
317311
T::Type,
318312
hamiltonian::Hamiltonian,
319-
init_params,
313+
initial_params,
320314
)
321-
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
315+
ϵ = find_good_stepsize(rng, hamiltonian, initial_params)
322316
@info string("Found initial step size ", ϵ)
323317
return T(ϵ)
324318
end

src/adaptation/Adaptation.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@ using Statistics: Statistics
66
using SimpleUnPack: @unpack, @pack!
77

88
using ..AdvancedHMC: DEBUG, AbstractScalarOrVec
9+
using DocStringExtensions
910

11+
"""
12+
$(TYPEDEF)
13+
14+
Abstract type for HMC adaptors.
15+
"""
1016
abstract type AbstractAdaptor end
1117
function getM⁻¹ end
1218
function getϵ end

src/constructors.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ in favour of determined types from the other arguments.
88
"""
99
determine_sampler_eltype(xs...) = float(_determine_sampler_eltype(xs...))
1010
# NOTE: We want to defer conversion to `float` until the very "end" of the
11-
# process, so as to allow `promote_type` to do it's job properly.
11+
# process to allow `promote_type` to do its job properly.
1212
# For example, in the scenario `determine_sampler_eltype(::Int64, ::Float32)`
1313
# we want to return `Float32`, not `Float64`. The latter would occur
1414
# if we did `float(eltype(x))` instead of just `eltype(x)`.
@@ -49,15 +49,15 @@ Note that all the fields have the prefix `initial_` to indicate
4949
that these will not necessarily correspond to the `kernel`, `metric`,
5050
and `adaptor` after sampling.
5151
52-
To access the updated fields use the resulting [`HMCState`](@ref).
52+
To access the updated fields, use the resulting [`HMCState`](@ref).
5353
"""
5454
struct HMCSampler{K<:AbstractMCMCKernel,M<:AbstractMetric,A<:AbstractAdaptor} <:
5555
AbstractHMCSampler
5656
"[`AbstractMCMCKernel`](@ref)."
5757
κ::K
5858
"Choice of initial metric [`AbstractMetric`](@ref). The metric type will be preserved during adaption."
5959
metric::M
60-
"[`AbstractAdaptor`](@ref)."
60+
"[`AdvancedHMC.Adaptation.AbstractAdaptor`](@ref)."
6161
adaptor::A
6262
end
6363

src/integrator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# TODO: The type `<:Tuple{Integer,Bool}` is introduced to address
66
# https://github.com/TuringLang/Turing.jl/pull/941#issuecomment-549191813
77
# We might want to simplify it to `Tuple{Int,Bool}` when we figured out
8-
# why the it behaves unexpected on Windos 32.
8+
# why the it behaves unexpected on Windows 32.
99

1010
"""
1111
$(TYPEDEF)

src/metric.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
$(TYPEDEF)
3+
4+
Abstract type for preconditioning metrics.
5+
"""
16
abstract type AbstractMetric end
27

38
_string_M⁻¹(mat::AbstractMatrix, n_chars::Int = 32) = _string_M⁻¹(diag(mat), n_chars)

src/trajectory.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,28 @@ end
2525
"Returns the statistics for transition `t`."
2626
stat(t::Transition) = t.stat
2727

28+
"""
29+
$(TYPEDEF)
30+
Abstract type for HMC kernels.
31+
"""
2832
abstract type AbstractMCMCKernel end
2933

34+
"""
35+
$(TYPEDEF)
36+
Abstract type for termination criteria for Hamiltonian trajectories, e.g. no-U-turn and fixed number of leapfrog integration steps.
37+
"""
3038
abstract type AbstractTerminationCriterion end
3139

40+
"""
41+
$(TYPEDEF)
42+
Abstract type for a fixed number of leapfrog integration steps.
43+
"""
3244
abstract type StaticTerminationCriterion <: AbstractTerminationCriterion end
3345

46+
"""
47+
$(TYPEDEF)
48+
Abstract type for dynamic Hamiltonian trajectory termination criteria.
49+
"""
3450
abstract type DynamicTerminationCriterion <: AbstractTerminationCriterion end
3551

3652
"""

test/constructors.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,13 @@ get_kernel_hyperparamsT(spl::NUTS, state) = typeof(state.κ.τ.termination_crite
131131
@test AdvancedHMC.sampler_eltype(sampler) == T
132132

133133
# Step.
134-
transition, state =
135-
AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init)
136-
134+
transition, state = AbstractMCMC.step(
135+
rng,
136+
model,
137+
sampler;
138+
n_adapts = 0,
139+
initial_params = θ_init,
140+
)
137141
# Verify that the types are preserved in the transition.
138142
@test eltype(transition.z.θ) == T
139143
@test eltype(transition.z.r) == T
@@ -159,7 +163,7 @@ get_kernel_hyperparamsT(spl::NUTS, state) = typeof(state.κ.τ.termination_crite
159163
end
160164

161165
@testset "Utils" begin
162-
@testset "init_params" begin
166+
@testset "initial_params" begin
163167
d = 2
164168
θ_init = randn(d)
165169
rng = Random.default_rng()
@@ -171,10 +175,10 @@ end
171175
metric = AdvancedHMC.make_metric(spl, logdensity)
172176
hamiltonian = Hamiltonian(metric, model)
173177

174-
init_params1 = AdvancedHMC.make_init_params(rng, spl, logdensity, nothing)
175-
@test typeof(init_params1) == Vector{T}
176-
@test length(init_params1) == d
177-
init_params2 = AdvancedHMC.make_init_params(rng, spl, logdensity, θ_init)
178-
@test init_params2 == θ_init
178+
initial_params1 = AdvancedHMC.make_initial_params(rng, spl, logdensity, nothing)
179+
@test typeof(initial_params1) == Vector{T}
180+
@test length(initial_params1) == d
181+
initial_params2 = AdvancedHMC.make_initial_params(rng, spl, logdensity, θ_init)
182+
@test initial_params2 == θ_init
179183
end
180184
end

0 commit comments

Comments
 (0)