Skip to content

Commit d9292ad

Browse files
committed
Rename FooInit -> InitFromFoo
1 parent 2041927 commit d9292ad

File tree

7 files changed

+56
-38
lines changed

7 files changed

+56
-38
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ function DynamicPPL.predict(
130130
rng,
131131
model,
132132
varinfo,
133-
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
133+
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
134134
)
135135
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
136136
varname_vals = mapreduce(
@@ -268,7 +268,9 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
268268
# Resample any variables that are not present in `values_dict`, and
269269
# return the model's retval.
270270
retval, _ = DynamicPPL.init!!(
271-
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
271+
model,
272+
varinfo,
273+
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
272274
)
273275
retval
274276
end

src/sampler.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function AbstractMCMC.step(
5858
kwargs...,
5959
)
6060
vi = VarInfo()
61-
strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit()
61+
strategy = sampler isa SampleFromPrior ? InitFromPrior() : InitFromUniform()
6262
_, new_vi = DynamicPPL.init!!(rng, model, vi, strategy)
6363
return new_vi, nothing
6464
end
@@ -103,9 +103,9 @@ end
103103
init_strategy(sampler)
104104
105105
Define the initialisation strategy used for generating initial values when
106-
sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
106+
sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden.
107107
"""
108-
init_strategy(::Sampler) = PriorInit()
108+
init_strategy(::Sampler) = InitFromPrior()
109109

110110
function AbstractMCMC.step(
111111
rng::Random.AbstractRNG,
@@ -118,7 +118,7 @@ function AbstractMCMC.step(
118118
# with NamedTuple of Metadata).
119119
vi = default_varinfo(rng, model, spl)
120120

121-
# Fill it with initial parameters. Note that, if `ParamsInit` is used, the
121+
# Fill it with initial parameters. Note that, if `InitFromParams` is used, the
122122
# parameters provided must be in unlinked space (when inserted into the
123123
# varinfo, they will be adjusted to match the linking status of the
124124
# varinfo).

src/simple_varinfo.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,22 +232,26 @@ end
232232

233233
# Constructor from `Model`.
234234
function SimpleVarInfo{T}(
235-
rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit()
235+
rng::Random.AbstractRNG,
236+
model::Model,
237+
init_strategy::AbstractInitStrategy=InitFromPrior(),
236238
) where {T<:Real}
237239
return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy))
238240
end
239241
function SimpleVarInfo{T}(
240-
model::Model, init_strategy::AbstractInitStrategy=PriorInit()
242+
model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
241243
) where {T<:Real}
242244
return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy)
243245
end
244246
# Constructors without type param
245247
function SimpleVarInfo(
246-
rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit()
248+
rng::Random.AbstractRNG,
249+
model::Model,
250+
init_strategy::AbstractInitStrategy=InitFromPrior(),
247251
)
248252
return SimpleVarInfo{LogProbType}(rng, model, init_strategy)
249253
end
250-
function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit())
254+
function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
251255
return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy)
252256
end
253257

src/varinfo.jl

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,13 @@ given `rng` and `init_strategy`.
133133
instead.
134134
"""
135135
function VarInfo(
136-
rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit()
136+
rng::Random.AbstractRNG,
137+
model::Model,
138+
init_strategy::AbstractInitStrategy=InitFromPrior(),
137139
)
138140
return typed_varinfo(rng, model, init_strategy)
139141
end
140-
function VarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit())
142+
function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
141143
return VarInfo(Random.default_rng(), model, init_strategy)
142144
end
143145

@@ -207,14 +209,16 @@ Construct a VarInfo object for the given `model`, which has just a single
207209
# Arguments
208210
- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
209211
- `model::Model`: The model for which to create the varinfo object
210-
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`.
212+
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
211213
"""
212214
function untyped_varinfo(
213-
rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit()
215+
rng::Random.AbstractRNG,
216+
model::Model,
217+
init_strategy::AbstractInitStrategy=InitFromPrior(),
214218
)
215219
return last(init!!(rng, model, VarInfo(Metadata()), init_strategy))
216220
end
217-
function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit())
221+
function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
218222
return untyped_varinfo(Random.default_rng(), model, init_strategy)
219223
end
220224

@@ -282,14 +286,16 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of
282286
# Arguments
283287
- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
284288
- `model::Model`: The model for which to create the varinfo object
285-
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`.
289+
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
286290
"""
287291
function typed_varinfo(
288-
rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit()
292+
rng::Random.AbstractRNG,
293+
model::Model,
294+
init_strategy::AbstractInitStrategy=InitFromPrior(),
289295
)
290296
return typed_varinfo(untyped_varinfo(rng, model, init_strategy))
291297
end
292-
function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit())
298+
function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
293299
return typed_varinfo(Random.default_rng(), model, init_strategy)
294300
end
295301

@@ -302,19 +308,21 @@ Return a VarInfo object for the given `model`, which has just a single
302308
# Arguments
303309
- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
304310
- `model::Model`: The model for which to create the varinfo object
305-
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`.
311+
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
306312
"""
307313
function untyped_vector_varinfo(vi::UntypedVarInfo)
308314
md = metadata_to_varnamedvector(vi.metadata)
309315
return VarInfo(md, copy(vi.accs))
310316
end
311317
function untyped_vector_varinfo(
312-
rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit()
318+
rng::Random.AbstractRNG,
319+
model::Model,
320+
init_strategy::AbstractInitStrategy=InitFromPrior(),
313321
)
314322
return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy))
315323
end
316324
function untyped_vector_varinfo(
317-
model::Model, init_strategy::AbstractInitStrategy=PriorInit()
325+
model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
318326
)
319327
return untyped_vector_varinfo(Random.default_rng(), model, init_strategy)
320328
end
@@ -328,7 +336,7 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of
328336
# Arguments
329337
- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
330338
- `model::Model`: The model for which to create the varinfo object
331-
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`.
339+
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
332340
"""
333341
function typed_vector_varinfo(vi::NTVarInfo)
334342
md = map(metadata_to_varnamedvector, vi.metadata)
@@ -340,11 +348,15 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo)
340348
return VarInfo(nt, copy(vi.accs))
341349
end
342350
function typed_vector_varinfo(
343-
rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit()
351+
rng::Random.AbstractRNG,
352+
model::Model,
353+
init_strategy::AbstractInitStrategy=InitFromPrior(),
344354
)
345355
return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy))
346356
end
347-
function typed_vector_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit())
357+
function typed_vector_varinfo(
358+
model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
359+
)
348360
return typed_vector_varinfo(Random.default_rng(), model, init_strategy)
349361
end
350362

test/ext/DynamicPPLJETExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
@test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa
4343
DynamicPPL.NTVarInfo
4444
init_model = DynamicPPL.contextualize(
45-
demo4(), DynamicPPL.InitContext(DynamicPPL.PriorInit())
45+
demo4(), DynamicPPL.InitContext(DynamicPPL.InitFromPrior())
4646
)
4747
@test DynamicPPL.Experimental.determine_suitable_varinfo(init_model) isa
4848
DynamicPPL.UntypedVarInfo

test/sampler.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@
6969
end
7070

7171
# initial samplers
72-
DynamicPPL.init_strategy(::Sampler{OnlyInitAlgUniform}) = UniformInit()
73-
@test DynamicPPL.init_strategy(Sampler(OnlyInitAlgDefault())) == PriorInit()
72+
DynamicPPL.init_strategy(::Sampler{OnlyInitAlgUniform}) = InitFromUniform()
73+
@test DynamicPPL.init_strategy(Sampler(OnlyInitAlgDefault())) == InitFromPrior()
7474

7575
for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform())
7676
# model with one variable: initialization p = 0.2
@@ -81,7 +81,7 @@
8181
model = coinflip()
8282
sampler = Sampler(alg)
8383
lptrue = logpdf(Binomial(25, 0.2), 10)
84-
let inits = ParamsInit((; p=0.2))
84+
let inits = InitFromParams((; p=0.2))
8585
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
8686
@test chain[1].metadata.p.vals == [0.2]
8787
@test getlogjoint(chain[1]) == lptrue
@@ -109,7 +109,7 @@
109109
end
110110
model = twovars()
111111
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
112-
let inits = ParamsInit((; s=4, m=-1))
112+
let inits = InitFromParams((; s=4, m=-1))
113113
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
114114
@test chain[1].metadata.s.vals == [4]
115115
@test chain[1].metadata.m.vals == [-1]
@@ -133,7 +133,7 @@
133133
end
134134

135135
# set only m = -1
136-
for inits in (ParamsInit((; s=missing, m=-1)), ParamsInit((; m=-1)))
136+
for inits in (InitFromParams((; s=missing, m=-1)), InitFromParams((; m=-1)))
137137
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
138138
@test !ismissing(chain[1].metadata.s.vals[1])
139139
@test chain[1].metadata.m.vals == [-1]

test/varinfo.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ end
4242
end
4343
model = gdemo(1.0, 2.0)
4444

45-
_, vi = DynamicPPL.init!!(model, VarInfo(), UniformInit())
45+
_, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform())
4646
tvi = DynamicPPL.typed_varinfo(vi)
4747

4848
meta = vi.metadata
@@ -479,18 +479,18 @@ end
479479
end
480480
model = gdemo([1.0, 1.5], [2.0, 2.5])
481481

482-
# Check that instantiating the model using UniformInit does not
482+
# Check that instantiating the model using InitFromUniform does not
483483
# perform linking
484-
# Note (penelopeysm): The purpose of using UniformInit specifically in
484+
# Note (penelopeysm): The purpose of using InitFromUniform specifically in
485485
# this test is because it samples from the linked distribution i.e. in
486486
# unconstrained space. However, it does this not by linking the varinfo
487487
# but by transforming the distributions on the fly. That's why it's
488488
# worth specifically checking that it can do this without having to
489489
# change the VarInfo object.
490-
# TODO(penelopeysm): Move this to UniformInit tests rather than here.
490+
# TODO(penelopeysm): Move this to InitFromUniform tests rather than here.
491491
vi = VarInfo()
492492
meta = vi.metadata
493-
_, vi = DynamicPPL.init!!(model, vi, UniformInit())
493+
_, vi = DynamicPPL.init!!(model, vi, InitFromUniform())
494494
@test all(x -> !istrans(vi, x), meta.vns)
495495

496496
# Check that linking and invlinking set the `trans` flag accordingly
@@ -554,7 +554,7 @@ end
554554

555555
function test_linked_varinfo(model, vi)
556556
# vn and dist are taken from the containing scope
557-
vi = last(DynamicPPL.init!!(model, vi, PriorInit()))
557+
vi = last(DynamicPPL.init!!(model, vi, InitFromPrior()))
558558
f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
559559
x = f(DynamicPPL.getindex_internal(vi, vn))
560560
@test istrans(vi, vn)
@@ -972,7 +972,7 @@ end
972972
varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1)
973973
# Calling init!! should preserve the fact that the variables are linked.
974974
model2 = demo(2)
975-
varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit()))
975+
varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), InitFromPrior()))
976976
for vn in [@varname(x[1]), @varname(x[2])]
977977
@test DynamicPPL.istrans(varinfo2, vn)
978978
end
@@ -990,7 +990,7 @@ end
990990
varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1)
991991
# Calling init!! should preserve the fact that the variables are linked.
992992
model2 = demo_dot(2)
993-
varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit()))
993+
varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), InitFromPrior()))
994994
for vn in [@varname(x), @varname(y[1])]
995995
@test DynamicPPL.istrans(varinfo2, vn)
996996
end

0 commit comments

Comments
 (0)