Skip to content

Commit 7d312cd

Browse files
torfjeldeyebaidevmotiongithub-actions[bot]
authored
Add fix and unfix (#488)
* aded FixedContext and everything that goes with it * initial work on making fix compat with compiler * added support for dot tilde * exported and added testing for fix and unfix * added equivalent support to condition * fixed some docstrings * added lots of documentation plus some doctests for fixing * added docs on fix * bump patch version * renamd getvalue and hasvalue for contexts to more descriptive get_conditioned_value and has_conditioned_value * formatting * Update src/model.jl * fixeed typo in docstring * fixed docs * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Update src/model.jl * Apply suggestions from code review * Update Project.toml * Update docs/src/api.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: David Widmann <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent e8172f0 commit 7d312cd

File tree

7 files changed

+705
-48
lines changed

7 files changed

+705
-48
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.1"
3+
version = "0.23.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/src/api.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,34 @@ Similarly, one can specify with [`AbstractPPL.decondition`](@ref) that certain,
8282
decondition
8383
```
8484

85+
## Fixing and unfixing
86+
87+
We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`fix`](@ref).
88+
89+
This might seem quite similar to the aforementioned [`condition`](@ref) and its siblings,
90+
but they are indeed different operations:
91+
92+
- `condition`ed variables are considered to be _observations_, and are thus
93+
included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref),
94+
but not in [`logprior`](@ref).
95+
- `fix`ed variables are considered to be _constant_, and are thus not included
96+
in any log-probability computations.
97+
98+
The differences are more clearly spelled out in the docstring of [`fix`](@ref) below.
99+
100+
```@docs
101+
fix
102+
DynamicPPL.fixed
103+
```
104+
105+
The difference between [`fix`](@ref) and [`condition`](@ref) is described in the docstring of [`fix`](@ref) above.
106+
107+
Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original meaning:
108+
109+
```@docs
110+
unfix
111+
```
112+
85113
## Utilities
86114

87115
It is possible to manually increase (or decrease) the accumulated log density from within a model function.
@@ -327,4 +355,3 @@ dot_tilde_assume
327355
tilde_observe
328356
dot_tilde_observe
329357
```
330-

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ export AbstractVarInfo,
119119
pointwise_loglikelihoods,
120120
condition,
121121
decondition,
122+
fix,
123+
unfix,
122124
# Convenience macros
123125
@addlogprob!,
124126
@submodel

src/compiler.jl

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ function contextual_isassumption(context::AbstractContext, vn)
6666
return contextual_isassumption(NodeTrait(context), context, vn)
6767
end
6868
function contextual_isassumption(context::ConditionContext, vn)
69-
if hasvalue(context, vn)
70-
val = getvalue(context, vn)
69+
if hasconditioned(context, vn)
70+
val = getconditioned(context, vn)
7171
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
7272
if eltype(val) >: Missing && val === missing
7373
return true
@@ -76,14 +76,48 @@ function contextual_isassumption(context::ConditionContext, vn)
7676
end
7777
end
7878

79-
# We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`
79+
# We might have nested contexts, e.g. `ConditionContext{.., <:PrefixContext{..., <:ConditionContext}}`
8080
# so we defer to `childcontext` if we haven't concluded that anything yet.
8181
return contextual_isassumption(childcontext(context), vn)
8282
end
8383
function contextual_isassumption(context::PrefixContext, vn)
8484
return contextual_isassumption(childcontext(context), prefix(context, vn))
8585
end
8686

87+
isfixed(expr, vn) = false
88+
isfixed(::Union{Symbol,Expr}, vn) = :($(DynamicPPL.contextual_isfixed)(__context__, $vn))
89+
90+
"""
91+
contextual_isfixed(context, vn)
92+
93+
Return `true` if `vn` is considered fixed by `context`.
94+
"""
95+
contextual_isfixed(::IsLeaf, context, vn) = false
96+
function contextual_isfixed(::IsParent, context, vn)
97+
return contextual_isfixed(childcontext(context), vn)
98+
end
99+
function contextual_isfixed(context::AbstractContext, vn)
100+
return contextual_isfixed(NodeTrait(context), context, vn)
101+
end
102+
function contextual_isfixed(context::PrefixContext, vn)
103+
return contextual_isfixed(childcontext(context), prefix(context, vn))
104+
end
105+
function contextual_isfixed(context::FixedContext, vn)
106+
if hasfixed(context, vn)
107+
val = getfixed(context, vn)
108+
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
109+
if eltype(val) >: Missing && val === missing
110+
return false
111+
else
112+
return true
113+
end
114+
end
115+
116+
# We might have nested contexts, e.g. `FixedContext{.., <:PrefixContext{..., <:FixedContext}}`
117+
# so we defer to `childcontext` if we haven't concluded that anything yet.
118+
return contextual_isfixed(childcontext(context), vn)
119+
end
120+
87121
# If we're working with, say, a `Symbol`, then we're not going to `view`.
88122
maybe_view(x) = x
89123
maybe_view(x::Expr) = :(@views($x))
@@ -341,12 +375,14 @@ function generate_tilde(left, right)
341375
$(AbstractPPL.drop_escape(varname(left))), $dist
342376
)
343377
$isassumption = $(DynamicPPL.isassumption(left, vn))
344-
if $isassumption
378+
if $(DynamicPPL.isfixed(left, vn))
379+
$left = $(DynamicPPL.getfixed_nested)(__context__, $vn)
380+
elseif $isassumption
345381
$(generate_tilde_assume(left, dist, vn))
346382
else
347383
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
348384
if !$(DynamicPPL.inargnames)($vn, __model__)
349-
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
385+
$left = $(DynamicPPL.getconditioned_nested)(__context__, $vn)
350386
end
351387

352388
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
@@ -400,12 +436,14 @@ function generate_dot_tilde(left, right)
400436
$(AbstractPPL.drop_escape(varname(left))), $right
401437
)
402438
$isassumption = $(DynamicPPL.isassumption(left, vn))
403-
if $isassumption
439+
if $(DynamicPPL.isfixed(left, vn))
440+
$left .= $(DynamicPPL.getfixed_nested)(__context__, $vn)
441+
elseif $isassumption
404442
$(generate_dot_tilde_assume(left, right, vn))
405443
else
406444
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
407445
if !$(DynamicPPL.inargnames)($vn, __model__)
408-
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
446+
$left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn)
409447
end
410448

411449
$value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)(

0 commit comments

Comments
 (0)