Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
7488238
fixed incorrect implementation of `dot_tilde_assume` for `PrefixContext`
torfjelde Nov 1, 2024
1d211c5
removed `vars` field from `PriorContext` and `LikelihoodContext` as
torfjelde Nov 1, 2024
fcccfd2
Merge branch 'master' into torfjelde/context-cleanup
penelopeysm Nov 2, 2024
f978287
replaced `NoDist` with `nodist`
torfjelde Nov 5, 2024
ae84112
fixed method ambiguity issue
torfjelde Nov 5, 2024
881e28d
added missing `Distributions.rand!` definition for `NoDist`
torfjelde Nov 5, 2024
1f6fc62
added more elaborate testing of evaluation of contexts
torfjelde Nov 5, 2024
55e24e9
added `DynamicPPL.TestUtils.test_context` for testing contexts and re…
torfjelde Nov 5, 2024
e5b7e44
added proper testing for PrefixContext of all demo models
torfjelde Nov 5, 2024
e1c8fd1
formatting
torfjelde Nov 5, 2024
23bc877
Merge branch 'master' into torfjelde/context-cleanup
torfjelde Nov 6, 2024
decaa78
Merge branch 'master' into torfjelde/context-cleanup
torfjelde Nov 8, 2024
b67f51b
added some dropped tests
torfjelde Nov 11, 2024
1da4c6e
Update src/test_utils.jl
torfjelde Nov 11, 2024
b551356
Update src/test_utils.jl
torfjelde Nov 11, 2024
1b91a72
Merge branch 'master' into torfjelde/context-cleanup
penelopeysm Nov 11, 2024
247f754
Update test/debug_utils.jl
torfjelde Nov 12, 2024
ebbda53
bump patch version
torfjelde Nov 12, 2024
8d41d7b
Merge branch 'master' into torfjelde/context-cleanup
torfjelde Nov 25, 2024
527ef0c
Merge remote-tracking branch 'origin/master' into torfjelde/context-c…
penelopeysm Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 2 additions & 103 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,44 +77,6 @@
return tilde_assume(rng, childcontext(context), args...)
end

function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
vn,
vi,
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
function tilde_assume(::LikelihoodContext, right, vn, vi)
return assume(NoDist(right), vn, vi)
end
Expand Down Expand Up @@ -328,37 +290,6 @@
end

# `LikelihoodContext`
function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
end
end

function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
return dot_assume(nodist(right), left, vn, vi)
end
Expand All @@ -368,46 +299,14 @@
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
end

# `PriorContext`
function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::PriorContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
end
end

# `PrefixContext`
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi)
return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi)

Check warning on line 304 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L304

Added line #L304 was not covered by tests
end

function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi)
return dot_tilde_assume(
rng, context.context, sampler, right, prefix.(Ref(context), vn), vi
rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi
)
end

Expand Down
29 changes: 8 additions & 21 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ DefaultContext()
julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior

julia> DynamicPPL.childcontext(ctx_prior)
PriorContext{Nothing}(nothing)
PriorContext()
```
"""
setchildcontext
Expand Down Expand Up @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext()))

julia> # Replace the leaf context with another leaf.
leafcontext(setleafcontext(ctx, PriorContext()))
PriorContext{Nothing}(nothing)
PriorContext()

julia> # Append another parent context.
setleafcontext(ctx, ParentContext(DefaultContext()))
Expand Down Expand Up @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end
NodeTrait(context::DefaultContext) = IsLeaf()

"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext <: AbstractContext

The `PriorContext` enables the computation of the log prior of the parameters `vars` when
running the model.
A leaf context resulting in the exclusion of likelihood terms when running the model.
"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext() = PriorContext(nothing)
struct PriorContext <: AbstractContext end
NodeTrait(context::PriorContext) = IsLeaf()

"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext <: AbstractContext

The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
A leaf context resulting in the exclusion of prior terms when running the model.
"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext() = LikelihoodContext(nothing)
struct LikelihoodContext <: AbstractContext end
NodeTrait(context::LikelihoodContext) = IsLeaf()

"""
Expand Down
Loading