@@ -39,7 +39,7 @@ julia> rng = StableRNG(42);
39
39
julia> # In the `NamedTuple` version we need to provide the place-holder values for
40
40
# the variables which are using "containers", e.g. `Array`.
41
41
# In this case, this means that we need to specify `x` but not `m`.
42
- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo((x = ones(2), )));
42
+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo((x = ones(2), )));
43
43
44
44
julia> # (✓) Vroom, vroom! FAST!!!
45
45
vi[@varname(x[1])]
@@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])]
57
57
1.3736306979834252
58
58
59
59
julia> # (×) If we don't provide the container...
60
- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo()); vi
60
+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo()); vi
61
61
ERROR: type NamedTuple has no field x
62
62
[...]
63
63
64
64
julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
65
- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
65
+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
66
66
67
67
julia> # (✓) Sort of fast, but only possible at runtime.
68
68
vi[@varname(x[1])]
@@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods)
91
91
92
92
julia> m = demo_constrained();
93
93
94
- julia> _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo());
94
+ julia> _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo());
95
95
96
96
julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
97
97
1.8632965762164932
98
98
99
- julia> _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
99
+ julia> _, vi = DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
100
100
101
101
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
102
102
-0.21080155351918753
103
103
104
- julia> xs = [last(DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
104
+ julia> xs = [last(DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
105
105
106
106
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107
107
true
108
108
109
109
julia> # And with `OrderedDict` of course!
110
- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
110
+ _, vi = DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111
111
112
112
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
113
113
0.6225185067787314
114
114
115
- julia> xs = [last(DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
115
+ julia> xs = [last(DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
116
116
117
117
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
118
118
true
@@ -232,24 +232,25 @@ end
232
232
233
233
# Constructor from `Model`.
234
234
function SimpleVarInfo {T} (
235
- rng:: Random.AbstractRNG , model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
235
+ rng:: Random.AbstractRNG , model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
236
236
) where {T<: Real }
237
- new_model = contextualize (model, SamplingContext (rng, sampler, model. context))
237
+ new_context = setleafcontext (model. context, InitContext (rng, init_strategy))
238
+ new_model = contextualize (model, new_context)
238
239
return last (evaluate!! (new_model, SimpleVarInfo {T} ()))
239
240
end
240
241
function SimpleVarInfo {T} (
241
- model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
242
+ model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
242
243
) where {T<: Real }
243
- return SimpleVarInfo {T} (Random. default_rng (), model, sampler )
244
+ return SimpleVarInfo {T} (Random. default_rng (), model, init_strategy )
244
245
end
245
246
# Constructors without type param
246
247
function SimpleVarInfo (
247
- rng:: Random.AbstractRNG , model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
248
+ rng:: Random.AbstractRNG , model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
248
249
)
249
- return SimpleVarInfo {LogProbType} (rng, model, sampler )
250
+ return SimpleVarInfo {LogProbType} (rng, model, init_strategy )
250
251
end
251
- function SimpleVarInfo (model:: Model , sampler :: AbstractSampler = SampleFromPrior ())
252
- return SimpleVarInfo {LogProbType} (Random. default_rng (), model, sampler )
252
+ function SimpleVarInfo (model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ())
253
+ return SimpleVarInfo {LogProbType} (Random. default_rng (), model, init_strategy )
253
254
end
254
255
255
256
# Constructor from `VarInfo`.
@@ -265,12 +266,12 @@ end
265
266
266
267
function untyped_simple_varinfo (model:: Model )
267
268
varinfo = SimpleVarInfo (OrderedDict {VarName,Any} ())
268
- return last (evaluate_and_sample !! (model, varinfo))
269
+ return last (init !! (model, varinfo))
269
270
end
270
271
271
272
function typed_simple_varinfo (model:: Model )
272
273
varinfo = SimpleVarInfo {Float64} ()
273
- return last (evaluate_and_sample !! (model, varinfo))
274
+ return last (init !! (model, varinfo))
274
275
end
275
276
276
277
function unflatten (svi:: SimpleVarInfo , x:: AbstractVector )
@@ -480,7 +481,6 @@ function assume(
480
481
return value, vi
481
482
end
482
483
483
- # NOTE: We don't implement `settrans!!(vi, trans, vn)`.
484
484
function settrans!! (vi:: SimpleVarInfo , trans)
485
485
return settrans!! (vi, trans ? DynamicTransformation () : NoTransformation ())
486
486
end
490
490
function settrans!! (vi:: ThreadSafeVarInfo{<:SimpleVarInfo} , trans)
491
491
return Accessors. @set vi. varinfo = settrans!! (vi. varinfo, trans)
492
492
end
493
+ function settrans!! (vi:: SimpleOrThreadSafeSimple , trans:: Bool , :: VarName )
494
+ # We keep this method around just to obey the AbstractVarInfo interface; however,
495
+ # this is only a valid operation if it would be a no-op.
496
+ if trans != istrans (vi)
497
+ error (
498
+ " Individual variables in SimpleVarInfo cannot have different `settrans` statuses." ,
499
+ )
500
+ end
501
+ end
493
502
494
503
istrans (vi:: SimpleVarInfo ) = ! (vi. transformation isa NoTransformation)
495
504
istrans (vi:: SimpleVarInfo , :: VarName ) = istrans (vi)
0 commit comments