Skip to content

Commit 7045b5f

Browse files
committed
Add setleafcontext(::Model, ::AbstractContext)
1 parent d3d32e4 commit 7045b5f

File tree

7 files changed

+22
-24
lines changed

7 files changed

+22
-24
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ The separation of these functions was primarily implemented to avoid performing
5252

5353
**Other changes**
5454

55+
### `setleafcontext(model, context)`
56+
57+
This convenience method has been added to quickly modify the leaf context of a model.
58+
5559
### Reimplementation of functions using `InitContext`
5660

5761
A number of functions have been reimplemented and unified with the help of `InitContext`.

ext/DynamicPPLJETExt.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
2424
varinfo = DynamicPPL.typed_varinfo(model)
2525

2626
# Check type stability of evaluation (i.e. DefaultContext)
27-
model = DynamicPPL.contextualize(
28-
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext())
29-
)
27+
model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext())
3028
eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo(
3129
model, varinfo; only_ddpl
3230
)
@@ -36,9 +34,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
3634
end
3735

3836
# Check type stability of initialisation (i.e. InitContext)
39-
model = DynamicPPL.contextualize(
40-
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
41-
)
37+
model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
4238
init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo(
4339
model, varinfo; only_ddpl
4440
)

src/model.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,15 @@ with its underlying context set to `context`.
9494
function contextualize(model::Model, context::AbstractContext)
9595
return Model(model.f, model.args, model.defaults, context)
9696
end
97+
"""
98+
setleafcontext(model::Model, context::AbstractContext)
99+
100+
Return a new `Model` with its leaf context set to `context`. This is a convenience
101+
shortcut for `contextualize(model, setleafcontext(model.context, context)`).
102+
"""
103+
function setleafcontext(model::Model, context::AbstractContext)
104+
return contextualize(model, setleafcontext(model.context, context))
105+
end
97106

98107
"""
99108
model | (x = 1.0, ...)
@@ -886,8 +895,7 @@ function init!!(
886895
varinfo::AbstractVarInfo,
887896
init_strategy::AbstractInitStrategy=InitFromPrior(),
888897
)
889-
new_context = setleafcontext(model.context, InitContext(rng, init_strategy))
890-
new_model = contextualize(model, new_context)
898+
new_model = setleafcontext(model, InitContext(rng, init_strategy))
891899
return evaluate!!(new_model, varinfo)
892900
end
893901
function init!!(

src/test_utils/contexts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP
4747
_, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo())
4848
typed_vi = DynamicPPL.typed_varinfo(untyped_vi)
4949
# Set the test context as the new leaf context
50-
new_model = contextualize(model, DynamicPPL.setleafcontext(model.context, context))
50+
new_model = DynamicPPL.setleafcontext(model, context)
5151
# Check that evaluation works
5252
for vi in [untyped_vi, typed_vi]
5353
_, vi = DynamicPPL.evaluate!!(new_model, vi)

src/threadsafe.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,12 @@ end
103103
# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates
104104
# to define `getacc(vi)`.
105105
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
106-
model = contextualize(
107-
model, setleafcontext(model.context, DynamicTransformationContext{false}())
108-
)
106+
model = setleafcontext(model, DynamicTransformationContext{false}())
109107
return settrans!!(last(evaluate!!(model, vi)), t)
110108
end
111109

112110
function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
113-
model = contextualize(
114-
model, setleafcontext(model.context, DynamicTransformationContext{true}())
115-
)
111+
model = setleafcontext(model, DynamicTransformationContext{true}())
116112
return settrans!!(last(evaluate!!(model, vi)), NoTransformation())
117113
end
118114

src/transforming.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function _transform!!(
5959
model::Model,
6060
)
6161
# To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context:
62-
model = contextualize(model, setleafcontext(model.context, ctx))
62+
model = setleafcontext(model, ctx)
6363
vi = settrans!!(last(evaluate!!(model, vi)), t)
6464
return vi
6565
end

test/ext/DynamicPPLJETExt.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,14 @@
8181
typed_vi = DynamicPPL.typed_varinfo(model)
8282

8383
@info "Evaluating with DefaultContext:"
84-
model = DynamicPPL.contextualize(
85-
model,
86-
DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()),
87-
)
84+
model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext())
8885
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
8986
model, varinfo
9087
)
9188
JET.test_call(f, argtypes)
9289

9390
@info "Initialising with InitContext:"
94-
model = DynamicPPL.contextualize(
95-
model,
96-
DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()),
97-
)
91+
model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
9892
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
9993
model, varinfo
10094
)

0 commit comments

Comments
 (0)