Skip to content

Commit 6705d7b

Browse files
committed
Rename {Prior,Uniform,Params}Init -> InitFrom{Prior,Uniform,Params}
1 parent deed931 commit 6705d7b

File tree

5 files changed

+65
-59
lines changed

5 files changed

+65
-59
lines changed

docs/src/api.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,9 @@ To accomplish this, an initialisation _strategy_ is required, which defines how
474474
There are three concrete strategies provided in DynamicPPL:
475475

476476
```@docs
477-
PriorInit
478-
UniformInit
479-
ParamsInit
477+
InitFromPrior
478+
InitFromUniform
479+
InitFromParams
480480
```
481481

482482
If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.

src/DynamicPPL.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ export AbstractVarInfo,
111111
# Initialisation
112112
InitContext,
113113
AbstractInitStrategy,
114-
PriorInit,
115-
UniformInit,
116-
ParamsInit,
114+
InitFromPrior,
115+
InitFromUniform,
116+
InitFromParams,
117117
# Pseudo distributions
118118
NamedDist,
119119
NoDist,

src/contexts/init.jl

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,24 @@ Generate a new value for a random variable with the given distribution.
1717
!!! warning "Return values must be unlinked"
1818
The values returned by `init` must always be in the untransformed space, i.e.,
1919
they must be within the support of the original distribution. That means that,
20-
for example, `init(rng, dist, u::UniformInit)` will in general return values that
20+
for example, `init(rng, dist, u::InitFromUniform)` will in general return values that
2121
are outside the range [u.lower, u.upper].
2222
"""
2323
function init end
2424

2525
"""
26-
PriorInit()
26+
InitFromPrior()
2727
2828
Obtain new values by sampling from the prior distribution.
2929
"""
30-
struct PriorInit <: AbstractInitStrategy end
31-
init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand(rng, dist)
30+
struct InitFromPrior <: AbstractInitStrategy end
31+
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior)
32+
return rand(rng, dist)
33+
end
3234

3335
"""
34-
UniformInit()
35-
UniformInit(lower, upper)
36+
InitFromUniform()
37+
InitFromUniform(lower, upper)
3638
3739
Obtain new values by first transforming the distribution of the random variable
3840
to unconstrained space, then sampling a value uniformly between `lower` and
@@ -47,17 +49,17 @@ Requires that `lower <= upper`.
4749
4850
[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization)
4951
"""
50-
struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy
52+
struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy
5153
lower::T
5254
upper::T
53-
function UniformInit(lower::T, upper::T) where {T<:AbstractFloat}
55+
function InitFromUniform(lower::T, upper::T) where {T<:AbstractFloat}
5456
lower > upper &&
5557
throw(ArgumentError("`lower` must be less than or equal to `upper`"))
5658
return new{T}(lower, upper)
5759
end
58-
UniformInit() = UniformInit(-2.0, 2.0)
60+
InitFromUniform() = InitFromUniform(-2.0, 2.0)
5961
end
60-
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit)
62+
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform)
6163
b = Bijectors.bijector(dist)
6264
sz = Bijectors.output_size(b, size(dist))
6365
y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...))
@@ -71,9 +73,9 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::Uniform
7173
end
7274

7375
"""
74-
ParamsInit(
76+
InitFromParams(
7577
params::Union{AbstractDict{<:VarName},NamedTuple},
76-
fallback::Union{AbstractInitStrategy,Nothing}=PriorInit()
78+
fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
7779
)
7880
7981
Obtain new values by extracting them from the given dictionary or NamedTuple.
@@ -82,28 +84,30 @@ The parameter `fallback` specifies how new values are to be obtained if they
8284
cannot be found in `params`, or they are specified as `missing`. `fallback`
8385
can either be an initialisation strategy itself, in which case it will be
8486
used to obtain new values, or it can be `nothing`, in which case an error
85-
will be thrown. The default for `fallback` is `PriorInit()`.
87+
will be thrown. The default for `fallback` is `InitFromPrior()`.
8688
8789
!!! note
8890
The values in `params` must be provided in the space of the untransformed
89-
distribution.
91+
distribution.
9092
"""
91-
struct ParamsInit{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy
93+
struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy
9294
params::P
9395
fallback::S
94-
function ParamsInit(
96+
function InitFromParams(
9597
params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing}
9698
)
9799
return new{typeof(params),typeof(fallback)}(params, fallback)
98100
end
99-
ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit())
100-
function ParamsInit(
101-
params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=PriorInit()
101+
function InitFromParams(params::AbstractDict{<:VarName})
102+
return InitFromParams(params, InitFromPrior())
103+
end
104+
function InitFromParams(
105+
params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
102106
)
103-
return ParamsInit(to_varname_dict(params), fallback)
107+
return InitFromParams(to_varname_dict(params), fallback)
104108
end
105109
end
106-
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit)
110+
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams)
107111
# TODO(penelopeysm): It would be nice to do a check to make sure that all
108112
# of the parameters in `p.params` were actually used, and either warn or
109113
# error if they aren't. This is actually quite non-trivial though because
@@ -128,7 +132,7 @@ end
128132
"""
129133
InitContext(
130134
[rng::Random.AbstractRNG=Random.default_rng()],
131-
[strategy::AbstractInitStrategy=PriorInit()],
135+
[strategy::AbstractInitStrategy=InitFromPrior()],
132136
)
133137
134138
A leaf context that indicates that new values for random variables are
@@ -140,11 +144,11 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon
140144
rng::R
141145
strategy::S
142146
function InitContext(
143-
rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit()
147+
rng::Random.AbstractRNG, strategy::AbstractInitStrategy=InitFromPrior()
144148
)
145149
return new{typeof(rng),typeof(strategy)}(rng, strategy)
146150
end
147-
function InitContext(strategy::AbstractInitStrategy=PriorInit())
151+
function InitContext(strategy::AbstractInitStrategy=InitFromPrior())
148152
return InitContext(Random.default_rng(), strategy)
149153
end
150154
end

src/model.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -894,30 +894,32 @@ end
894894
[rng::Random.AbstractRNG,]
895895
model::Model,
896896
varinfo::AbstractVarInfo,
897-
[init_strategy::AbstractInitStrategy=PriorInit()]
897+
[init_strategy::AbstractInitStrategy=InitFromPrior()]
898898
)
899899
900900
Evaluate the `model` and replace the values of the model's random variables in
901901
the given `varinfo` with new values using a specified initialisation strategy.
902902
If the values in `varinfo` are not already present, they will be added using
903903
that same strategy.
904904
905-
If `init_strategy` is not provided, defaults to PriorInit().
905+
If `init_strategy` is not provided, defaults to InitFromPrior().
906906
907907
Returns a tuple of the model's return value, plus the updated `varinfo` object.
908908
"""
909909
function init!!(
910910
rng::Random.AbstractRNG,
911911
model::Model,
912912
varinfo::AbstractVarInfo,
913-
init_strategy::AbstractInitStrategy=PriorInit(),
913+
init_strategy::AbstractInitStrategy=InitFromPrior(),
914914
)
915915
new_context = setleafcontext(model.context, InitContext(rng, init_strategy))
916916
new_model = contextualize(model, new_context)
917917
return evaluate!!(new_model, varinfo)
918918
end
919919
function init!!(
920-
model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit()
920+
model::Model,
921+
varinfo::AbstractVarInfo,
922+
init_strategy::AbstractInitStrategy=InitFromPrior(),
921923
)
922924
return init!!(Random.default_rng(), model, varinfo, init_strategy)
923925
end

test/contexts.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -537,36 +537,36 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
537537
end
538538
end
539539

540-
@testset "PriorInit" begin
541-
test_generating_new_values(PriorInit())
542-
test_replacing_values(PriorInit())
543-
test_rng_respected(PriorInit())
544-
test_link_status_respected(PriorInit())
540+
@testset "InitFromPrior" begin
541+
test_generating_new_values(InitFromPrior())
542+
test_replacing_values(InitFromPrior())
543+
test_rng_respected(InitFromPrior())
544+
test_link_status_respected(InitFromPrior())
545545

546546
@testset "check that values are within support" begin
547547
# Not many other sensible checks we can do for priors.
548548
@model just_unif() = x ~ Uniform(0.0, 1e-7)
549549
for _ in 1:100
550-
_, vi = DynamicPPL.init!!(just_unif(), VarInfo(), PriorInit())
550+
_, vi = DynamicPPL.init!!(just_unif(), VarInfo(), InitFromPrior())
551551
@test vi[@varname(x)] isa Real
552552
@test 0.0 <= vi[@varname(x)] <= 1e-7
553553
end
554554
end
555555
end
556556

557-
@testset "UniformInit" begin
558-
test_generating_new_values(UniformInit())
559-
test_replacing_values(UniformInit())
560-
test_rng_respected(UniformInit())
561-
test_link_status_respected(UniformInit())
557+
@testset "InitFromUniform" begin
558+
test_generating_new_values(InitFromUniform())
559+
test_replacing_values(InitFromUniform())
560+
test_rng_respected(InitFromUniform())
561+
test_link_status_respected(InitFromUniform())
562562

563563
@testset "check that bounds are respected" begin
564564
@testset "unconstrained" begin
565565
umin, umax = -1.0, 1.0
566566
@model just_norm() = x ~ Normal()
567567
for _ in 1:100
568568
_, vi = DynamicPPL.init!!(
569-
just_norm(), VarInfo(), UniformInit(umin, umax)
569+
just_norm(), VarInfo(), InitFromUniform(umin, umax)
570570
)
571571
@test vi[@varname(x)] isa Real
572572
@test umin <= vi[@varname(x)] <= umax
@@ -579,7 +579,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
579579
tmin, tmax = inv_bijector(umin), inv_bijector(umax)
580580
for _ in 1:100
581581
_, vi = DynamicPPL.init!!(
582-
just_beta(), VarInfo(), UniformInit(umin, umax)
582+
just_beta(), VarInfo(), InitFromUniform(umin, umax)
583583
)
584584
@test vi[@varname(x)] isa Real
585585
@test tmin <= vi[@varname(x)] <= tmax
@@ -588,9 +588,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
588588
end
589589
end
590590

591-
@testset "ParamsInit" begin
592-
test_link_status_respected(ParamsInit((; a=1.0)))
593-
test_link_status_respected(ParamsInit(Dict(@varname(a) => 1.0)))
591+
@testset "InitFromParams" begin
592+
test_link_status_respected(InitFromParams((; a=1.0)))
593+
test_link_status_respected(InitFromParams(Dict(@varname(a) => 1.0)))
594594

595595
@testset "given full set of parameters" begin
596596
# test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
@@ -600,13 +600,13 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
600600
model = test_init_model()
601601
@testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos
602602
_, vi = DynamicPPL.init!!(
603-
model, deepcopy(empty_vi), ParamsInit(params_nt)
603+
model, deepcopy(empty_vi), InitFromParams(params_nt)
604604
)
605605
@test vi[@varname(x)] == my_x
606606
@test vi[@varname(y)] == my_y
607607
logp_nt = getlogp(vi)
608608
_, vi = DynamicPPL.init!!(
609-
model, deepcopy(empty_vi), ParamsInit(params_dict)
609+
model, deepcopy(empty_vi), InitFromParams(params_dict)
610610
)
611611
@test vi[@varname(x)] == my_x
612612
@test vi[@varname(y)] == my_y
@@ -621,12 +621,12 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
621621
params_dict = Dict(@varname(x) => my_x)
622622
model = test_init_model()
623623
@testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos
624-
@testset "with PriorInit fallback" begin
624+
@testset "with InitFromPrior fallback" begin
625625
_, vi = DynamicPPL.init!!(
626626
Xoshiro(468),
627627
model,
628628
deepcopy(empty_vi),
629-
ParamsInit(params_nt, PriorInit()),
629+
InitFromParams(params_nt, InitFromPrior()),
630630
)
631631
@test vi[@varname(x)] == my_x
632632
nt_y = vi[@varname(y)]
@@ -636,7 +636,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
636636
Xoshiro(469),
637637
model,
638638
deepcopy(empty_vi),
639-
ParamsInit(params_dict, PriorInit()),
639+
InitFromParams(params_dict, InitFromPrior()),
640640
)
641641
@test vi[@varname(x)] == my_x
642642
dict_y = vi[@varname(y)]
@@ -649,10 +649,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
649649
@testset "with no fallback" begin
650650
# These just don't have an entry for `y`.
651651
@test_throws ErrorException DynamicPPL.init!!(
652-
model, deepcopy(empty_vi), ParamsInit(params_nt, nothing)
652+
model, deepcopy(empty_vi), InitFromParams(params_nt, nothing)
653653
)
654654
@test_throws ErrorException DynamicPPL.init!!(
655-
model, deepcopy(empty_vi), ParamsInit(params_dict, nothing)
655+
model, deepcopy(empty_vi), InitFromParams(params_dict, nothing)
656656
)
657657
# We also explicitly test the case where `y = missing`.
658658
params_nt_missing = (; x=my_x, y=missing)
@@ -662,12 +662,12 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
662662
@test_throws ErrorException DynamicPPL.init!!(
663663
model,
664664
deepcopy(empty_vi),
665-
ParamsInit(params_nt_missing, nothing),
665+
InitFromParams(params_nt_missing, nothing),
666666
)
667667
@test_throws ErrorException DynamicPPL.init!!(
668668
model,
669669
deepcopy(empty_vi),
670-
ParamsInit(params_dict_missing, nothing),
670+
InitFromParams(params_dict_missing, nothing),
671671
)
672672
end
673673
end

0 commit comments

Comments
 (0)