Skip to content

Commit 08212a2

Browse files
authored
Fixes for Turing 0.41 (#1057)
* setleafcontext(model, ctx) and various other fixes * fix a bug * Add warning for `initial_parameters=...`
1 parent 11d0b69 commit 08212a2

File tree

9 files changed

+58
-41
lines changed

9 files changed

+58
-41
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo
5656

5757
**Other changes**
5858

59+
### `setleafcontext(model, context)`
60+
61+
This convenience method has been added to quickly modify the leaf context of a model.
62+
5963
### Reimplementation of functions using `InitContext`
6064

6165
A number of functions have been reimplemented and unified with the help of `InitContext`.

ext/DynamicPPLJETExt.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
2424
varinfo = DynamicPPL.typed_varinfo(model)
2525

2626
# Check type stability of evaluation (i.e. DefaultContext)
27-
model = DynamicPPL.contextualize(
28-
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext())
29-
)
27+
model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext())
3028
eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo(
3129
model, varinfo; only_ddpl
3230
)
@@ -36,9 +34,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
3634
end
3735

3836
# Check type stability of initialisation (i.e. InitContext)
39-
model = DynamicPPL.contextualize(
40-
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
41-
)
37+
model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
4238
init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo(
4339
model, varinfo; only_ddpl
4440
)

src/abstract_varinfo.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -830,8 +830,7 @@ end
830830
function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
831831
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
832832
# has a dedicated implementation
833-
ctx = DynamicTransformationContext{false}()
834-
model = contextualize(model, setleafcontext(model.context, ctx))
833+
model = setleafcontext(model, DynamicTransformationContext{false}())
835834
vi = last(evaluate!!(model, vi))
836835
return settrans!!(vi, t)
837836
end
@@ -893,8 +892,7 @@ end
893892
function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
894893
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
895894
# has a dedicated implementation
896-
ctx = DynamicTransformationContext{true}()
897-
model = contextualize(model, setleafcontext(model.context, ctx))
895+
model = setleafcontext(model, DynamicTransformationContext{true}())
898896
vi = last(evaluate!!(model, vi))
899897
return settrans!!(vi, NoTransformation())
900898
end

src/contexts.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,17 @@ DynamicTransformationContext{true}()
5858
setchildcontext
5959

6060
"""
61-
leafcontext(context)
61+
leafcontext(context::AbstractContext)
6262
6363
Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`.
6464
"""
65-
leafcontext(context) = leafcontext(NodeTrait(leafcontext, context), context)
66-
leafcontext(::IsLeaf, context) = context
67-
leafcontext(::IsParent, context) = leafcontext(childcontext(context))
65+
leafcontext(context::AbstractContext) =
66+
leafcontext(NodeTrait(leafcontext, context), context)
67+
leafcontext(::IsLeaf, context::AbstractContext) = context
68+
leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context))
6869

6970
"""
70-
setleafcontext(left, right)
71+
setleafcontext(left::AbstractContext, right::AbstractContext)
7172
7273
Return `left` but now with its leaf context replaced by `right`.
7374
@@ -103,19 +104,21 @@ julia> # Append another parent context.
103104
ParentContext(ParentContext(ParentContext(DefaultContext())))
104105
```
105106
"""
106-
function setleafcontext(left, right)
107+
function setleafcontext(left::AbstractContext, right::AbstractContext)
107108
return setleafcontext(
108109
NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right
109110
)
110111
end
111-
function setleafcontext(::IsParent, ::IsParent, left, right)
112+
function setleafcontext(
113+
::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext
114+
)
112115
return setchildcontext(left, setleafcontext(childcontext(left), right))
113116
end
114-
function setleafcontext(::IsParent, ::IsLeaf, left, right)
117+
function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext)
115118
return setchildcontext(left, setleafcontext(childcontext(left), right))
116119
end
117-
setleafcontext(::IsLeaf, ::IsParent, left, right) = right
118-
setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right
120+
setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right
121+
setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right
119122

120123
"""
121124
DynamicPPL.tilde_assume!!(

src/model.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@ function contextualize(model::Model, context::AbstractContext)
9595
return Model(model.f, model.args, model.defaults, context)
9696
end
9797

98+
"""
99+
setleafcontext(model::Model, context::AbstractContext)
100+
101+
Return a new `Model` with its leaf context set to `context`. This is a convenience shortcut
102+
for `contextualize(model, setleafcontext(model.context, context)`).
103+
"""
104+
function setleafcontext(model::Model, context::AbstractContext)
105+
return contextualize(model, setleafcontext(model.context, context))
106+
end
107+
98108
"""
99109
model | (x = 1.0, ...)
100110
@@ -886,8 +896,7 @@ function init!!(
886896
varinfo::AbstractVarInfo,
887897
init_strategy::AbstractInitStrategy=InitFromPrior(),
888898
)
889-
new_context = setleafcontext(model.context, InitContext(rng, init_strategy))
890-
new_model = contextualize(model, new_context)
899+
new_model = setleafcontext(model, InitContext(rng, init_strategy))
891900
return evaluate!!(new_model, varinfo)
892901
end
893902
function init!!(

src/sampler.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ function default_varinfo(rng::Random.AbstractRNG, model::Model, ::AbstractSample
4646
end
4747

4848
"""
49-
init_strategy(sampler)
49+
init_strategy(sampler::AbstractSampler)
5050
5151
Define the initialisation strategy used for generating initial values when
5252
sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden.
5353
"""
54-
init_strategy(::Sampler) = InitFromPrior()
54+
init_strategy(::AbstractSampler) = InitFromPrior()
5555

5656
function AbstractMCMC.sample(
5757
rng::Random.AbstractRNG,
@@ -60,11 +60,15 @@ function AbstractMCMC.sample(
6060
N::Integer;
6161
chain_type=default_chain_type(sampler),
6262
resume_from=nothing,
63+
initial_params=init_strategy(sampler),
6364
initial_state=loadstate(resume_from),
6465
kwargs...,
6566
)
67+
if hasproperty(kwargs, :initial_parameters)
68+
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
69+
end
6670
return AbstractMCMC.mcmcsample(
67-
rng, model, sampler, N; chain_type, initial_state, kwargs...
71+
rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs...
6872
)
6973
end
7074

@@ -76,12 +80,25 @@ function AbstractMCMC.sample(
7680
N::Integer,
7781
nchains::Integer;
7882
chain_type=default_chain_type(sampler),
83+
initial_params=fill(init_strategy(sampler), nchains),
7984
resume_from=nothing,
8085
initial_state=loadstate(resume_from),
8186
kwargs...,
8287
)
88+
if hasproperty(kwargs, :initial_parameters)
89+
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
90+
end
8391
return AbstractMCMC.mcmcsample(
84-
rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs...
92+
rng,
93+
model,
94+
sampler,
95+
parallel,
96+
N,
97+
nchains;
98+
chain_type,
99+
initial_params,
100+
initial_state,
101+
kwargs...,
85102
)
86103
end
87104

src/test_utils/contexts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP
4747
_, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo())
4848
typed_vi = DynamicPPL.typed_varinfo(untyped_vi)
4949
# Set the test context as the new leaf context
50-
new_model = contextualize(model, DynamicPPL.setleafcontext(model.context, context))
50+
new_model = DynamicPPL.setleafcontext(model, context)
5151
# Check that evaluation works
5252
for vi in [untyped_vi, typed_vi]
5353
_, vi = DynamicPPL.evaluate!!(new_model, vi)

src/threadsafe.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,12 @@ end
103103
# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates
104104
# to define `getacc(vi)`.
105105
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
106-
model = contextualize(
107-
model, setleafcontext(model.context, DynamicTransformationContext{false}())
108-
)
106+
model = setleafcontext(model, DynamicTransformationContext{false}())
109107
return settrans!!(last(evaluate!!(model, vi)), t)
110108
end
111109

112110
function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
113-
model = contextualize(
114-
model, setleafcontext(model.context, DynamicTransformationContext{true}())
115-
)
111+
model = setleafcontext(model, DynamicTransformationContext{true}())
116112
return settrans!!(last(evaluate!!(model, vi)), NoTransformation())
117113
end
118114

test/ext/DynamicPPLJETExt.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,14 @@
8181
typed_vi = DynamicPPL.typed_varinfo(model)
8282

8383
@info "Evaluating with DefaultContext:"
84-
model = DynamicPPL.contextualize(
85-
model,
86-
DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()),
87-
)
84+
model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext())
8885
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
8986
model, varinfo
9087
)
9188
JET.test_call(f, argtypes)
9289

9390
@info "Initialising with InitContext:"
94-
model = DynamicPPL.contextualize(
95-
model,
96-
DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()),
97-
)
91+
model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
9892
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
9993
model, varinfo
10094
)

0 commit comments

Comments
 (0)