Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,54 @@
# DynamicPPL Changelog

## 0.36.0

**Breaking changes**

### VarName prefixing behaviour

The way in which VarNames in submodels are prefixed has been changed.
This is best explained through an example.
Consider this model and submodel:

```julia
using DynamicPPL, Distributions
@model inner() = x ~ Normal()
@model outer() = a ~ to_submodel(inner())
```

In previous versions, the inner variable `x` would be saved as `a.x`.
However, this was represented as a single symbol `Symbol("a.x")`:

```julia
julia> dump(keys(VarInfo(outer()))[1])
VarName{Symbol("a.x"), typeof(identity)}
optic: identity (function of type typeof(identity))
```

Now, the inner variable is stored as a field `x` on the VarName `a`:

```julia
julia> dump(keys(VarInfo(outer()))[1])
VarName{:a, Accessors.PropertyLens{:x}}
optic: Accessors.PropertyLens{:x} (@o _.x)
```

In practice, this means that if you are trying to condition a variable in the submodel, you now need to use

```julia
outer() | (@varname(a.x) => 1.0,)
```

instead of either of these (which would have worked previously)

```julia
outer() | (@varname(var"a.x") => 1.0,)
outer() | (a.x=1.0,)
```

If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.)

## 0.35.5

Several internal methods have been removed:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.10.1"
AbstractPPL = "0.11"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
Expand Down
5 changes: 3 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ using DocStringExtensions

using Random: Random

# For extending
import AbstractPPL: predict, prefix

# TODO: Remove these when it's possible.
import Bijectors: link, invlink

Expand All @@ -39,8 +42,6 @@ import Base:
keys,
haskey

import AbstractPPL: predict

# VarInfo
export AbstractVarInfo,
VarInfo,
Expand Down
4 changes: 2 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function contextual_isassumption(context::ConditionContext, vn)
return contextual_isassumption(childcontext(context), vn)
end
function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(childcontext(context), prefix(context, vn))
return contextual_isassumption(childcontext(context), prefix_with_context(context, vn))
end

isfixed(expr, vn) = false
Expand All @@ -132,7 +132,7 @@ function contextual_isfixed(context::AbstractContext, vn)
return contextual_isfixed(NodeTrait(context), context, vn)
end
function contextual_isfixed(context::PrefixContext, vn)
return contextual_isfixed(childcontext(context), prefix(context, vn))
return contextual_isfixed(childcontext(context), prefix_with_context(context, vn))
end
function contextual_isfixed(context::FixedContext, vn)
if hasfixed(context, vn)
Expand Down
6 changes: 4 additions & 2 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,14 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig
end

function tilde_assume(context::PrefixContext, right, vn, vi)
return tilde_assume(context.context, right, prefix(context, vn), vi)
return tilde_assume(context.context, right, prefix_with_context(context, vn), vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
)
return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi)
return tilde_assume(
rng, context.context, sampler, right, prefix_with_context(context, vn), vi
)
end

"""
Expand Down
40 changes: 20 additions & 20 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,25 +260,25 @@
return PrefixContext{Prefix}(child)
end

const PREFIX_SEPARATOR = Symbol(".")

@generated function PrefixContext{PrefixOuter}(
context::PrefixContext{PrefixInner}
) where {PrefixOuter,PrefixInner}
return :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}(
context.context
))
end
"""
prefix_with_context(ctx::AbstractContext, vn::VarName)
function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
vn_prefixed_inner = prefix(childcontext(ctx), vn)
return VarName{Symbol(Prefix, PREFIX_SEPARATOR, getsym(vn_prefixed_inner))}(
getoptic(vn_prefixed_inner)
Apply the prefixes in the context `ctx` to the variable name `vn`.
"""
function prefix_with_context(
ctx::PrefixContext{Prefix}, vn::VarName{Sym}
) where {Prefix,Sym}
return AbstractPPL.prefix(
prefix_with_context(childcontext(ctx), vn), VarName{Symbol(Prefix)}()
)
end
prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn)
prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn
prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn)
function prefix_with_context(ctx::AbstractContext, vn::VarName)
return prefix_with_context(NodeTrait(ctx), ctx, vn)
end
prefix_with_context(::IsLeaf, ::AbstractContext, vn::VarName) = vn
function prefix_with_context(::IsParent, ctx::AbstractContext, vn::VarName)
return prefix_with_context(childcontext(ctx), vn)
end
Copy link
Member Author

@penelopeysm penelopeysm Mar 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to change the name of this function because Aqua was complaining about method ambiguities. I'm not entirely sure why but I think it's because AbstractPPL defines prefix(::VarName, ::VarName) and if we also defined prefix(::AbstractContext, ::VarName) here, it's probably something to do with the fact that AbstractContext is an abstract type. I still don't entirely see why this is a problem but changing the name did fix it 😅

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need

import AbstractPPL: predict, prefix

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I am thinking that maybe we don't need to extend AbstractPPL.prefix? Otherwise, the name prefix_with_context is fine.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ehhh, I think that's fair. I think I tried that route as well, and the only thing I didn't entirely like is that it means that in DynamicPPL we have to be explicit every time we write prefix - we would have to write either DynamicPPL.prefix or AbstractPPL.prefix. Do you have a preference?

(Personally, I actually think it's better if they are separate - semantically they are different things. Also, it's a nightmare trying to pull up docstrings for the right function when packages just keep extending it - this is the case with sample for example. But historically we have done a lot of this overloading-functions-from-base-packages thing so extending it was kind of in line with that.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer they have separate names -- not a huge fun of name overloading, I think this is abusing multiple dispatch. So totally agree with you.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, we'll go with DynamicPPL.prefix :)


"""
prefix(model::Model, x)
Expand Down Expand Up @@ -392,7 +392,7 @@
return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn)
end
function hasconditioned_nested(context::PrefixContext, vn)
return hasconditioned_nested(childcontext(context), prefix(context, vn))
return hasconditioned_nested(childcontext(context), prefix_with_context(context, vn))
end

"""
Expand All @@ -410,7 +410,7 @@
return error("context $(context) does not contain value for $vn")
end
function getconditioned_nested(context::PrefixContext, vn)
return getconditioned_nested(childcontext(context), prefix(context, vn))
return getconditioned_nested(childcontext(context), prefix_with_context(context, vn))
end
function getconditioned_nested(::IsParent, context, vn)
return if hasconditioned(context, vn)
Expand Down Expand Up @@ -543,7 +543,7 @@
return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn)
end
function hasfixed_nested(context::PrefixContext, vn)
return hasfixed_nested(childcontext(context), prefix(context, vn))
return hasfixed_nested(childcontext(context), prefix_with_context(context, vn))

Check warning on line 546 in src/contexts.jl

View check run for this annotation

Codecov / codecov/patch

src/contexts.jl#L546

Added line #L546 was not covered by tests
end

"""
Expand All @@ -561,7 +561,7 @@
return error("context $(context) does not contain value for $vn")
end
function getfixed_nested(context::PrefixContext, vn)
return getfixed_nested(childcontext(context), prefix(context, vn))
return getfixed_nested(childcontext(context), prefix_with_context(context, vn))
end
function getfixed_nested(::IsParent, context, vn)
return if hasfixed(context, vn)
Expand Down
2 changes: 1 addition & 1 deletion src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ function DynamicPPL.setchildcontext(context::DebugContext, child)
end

function record_varname!(context::DebugContext, varname::VarName, dist)
prefixed_varname = prefix(context, varname)
prefixed_varname = DynamicPPL.prefix_with_context(context, varname)
if haskey(context.varnames_seen, prefixed_varname)
if context.error_on_failure
error("varname $prefixed_varname used multiple times in model")
Expand Down
58 changes: 17 additions & 41 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ julia> model() ≠ 1.0
true

julia> # To condition the variable inside `demo_inner` we need to refer to it as `inner.m`.
conditioned_model = model | (var"inner.m" = 1.0, );
conditioned_model = model | (@varname(inner.m) => 1.0, );

julia> conditioned_model()
1.0
Expand All @@ -255,15 +255,6 @@ julia> conditioned_model_fail()
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
[...]
```

And similarly when using `Dict`:

```jldoctest condition
julia> conditioned_model_dict = model | (@varname(var"inner.m") => 1.0);

julia> conditioned_model_dict()
1.0
```
"""
function AbstractPPL.condition(model::Model, values...)
# Positional arguments - need to handle cases carefully
Expand Down Expand Up @@ -443,16 +434,16 @@ julia> conditioned(cm)
julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed,
# `a.m` is treated as a random variable.
keys(VarInfo(cm))
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}:
a.m

julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0);
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0);

julia> conditioned(cm).x
julia> conditioned(cm)[@varname(x)]
100.0

julia> conditioned(cm).var"a.m"
julia> conditioned(cm)[@varname(a.m)]
1.0

julia> keys(VarInfo(cm)) # No variables are sampled
Expand Down Expand Up @@ -583,7 +574,7 @@ julia> model = demo_outer();
julia> model() ≠ 1.0
true

julia> fixed_model = fix(model, var"inner.m" = 1.0, );
julia> fixed_model = fix(model, (@varname(inner.m) => 1.0, ));

julia> fixed_model()
1.0
Expand All @@ -599,24 +590,9 @@ julia> fixed_model()
2.0
```

And similarly when using `Dict`:

```jldoctest fix
julia> fixed_model_dict = fix(model, @varname(var"inner.m") => 1.0);

julia> fixed_model_dict()
1.0

julia> fixed_model_dict = fix(model, @varname(inner) => 2.0);

julia> fixed_model_dict()
2.0
```

## Difference from `condition`

A very similar functionality is also provided by [`condition`](@ref) which,
not surprisingly, _conditions_ variables instead of fixing them. The only
A very similar functionality is also provided by [`condition`](@ref). The only
difference between fixing and conditioning is as follows:
- `condition`ed variables are considered to be observations, and are thus
included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref),
Expand Down Expand Up @@ -798,16 +774,16 @@ julia> fixed(cm)
julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed,
# `a.m` is treated as a random variable.
keys(VarInfo(cm))
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}:
a.m

julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation.
cm = fix(contextualize(m, PrefixContext{:a}(fix(var"a.m"=1.0))), x=100.0);
cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0);

julia> fixed(cm).x
julia> fixed(cm)[@varname(x)]
100.0

julia> fixed(cm).var"a.m"
julia> fixed(cm)[@varname(a.m)]
1.0

julia> keys(VarInfo(cm)) # <= no variables are sampled
Expand Down Expand Up @@ -1365,7 +1341,7 @@ When we sample from the model `demo2(missing, 0.4)` random variable `x` will be
```jldoctest submodel-to_submodel
julia> vi = VarInfo(demo2(missing, 0.4));

julia> @varname(var\"a.x\") in keys(vi)
julia> @varname(a.x) in keys(vi)
true
```

Expand All @@ -1379,7 +1355,7 @@ false
We can check that the log joint probability of the model accumulated in `vi` is correct:

```jldoctest submodel-to_submodel
julia> x = vi[@varname(var\"a.x\")];
julia> x = vi[@varname(a.x)];

julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
true
Expand Down Expand Up @@ -1417,10 +1393,10 @@ julia> @model function demo2(x, y, z)

julia> vi = VarInfo(demo2(missing, missing, 0.4));

julia> @varname(var"sub1.x") in keys(vi)
julia> @varname(sub1.x) in keys(vi)
true

julia> @varname(var"sub2.x") in keys(vi)
julia> @varname(sub2.x) in keys(vi)
true
```

Expand All @@ -1437,9 +1413,9 @@ false
We can check that the log joint probability of the model accumulated in `vi` is correct:

```jldoctest submodel-to_submodel-prefix
julia> sub1_x = vi[@varname(var"sub1.x")];
julia> sub1_x = vi[@varname(sub1.x)];

julia> sub2_x = vi[@varname(var"sub2.x")];
julia> sub2_x = vi[@varname(sub2.x)];

julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);

Expand Down
Loading