Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
include("utils.jl")
include("chains.jl")
include("contexts.jl")
include("contexts/default.jl")
include("contexts/init.jl")
include("contexts/transformation.jl")
include("contexts/prefix.jl")
include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl
Comment on lines 176 to +181
Copy link
Member Author

@penelopeysm penelopeysm Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TLDR:

contexts.jl used to contain all the parent contexts, those have been moved to their own files.

  • Now contexts.jl includes only the interface methods for AbstractContext. In a sense it's more like abstract_context.jl.
  • It's still not fully documented exactly what the interface for AbstractContext is. I'd like to make that a separate PR though because I actually think NodeTrait can be removed which would simplify matters there.

context_implementations.jl used to contain stray methods for DefaultContext, PrefixContext, and submodels. Those have been sent to their respective files.

transforming.jl basically existed to define DynamicTransformationContext so I moved it into contexts/ too. That file also defined some methods for linking AbstractVarInfo so those were moved to abstract_varinfo.jl too since they form part of the interface methods for it.

include("model.jl")
include("sampler.jl")
include("varname.jl")
Expand All @@ -187,10 +191,8 @@ include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("pointwise_logdensities.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
Expand Down
50 changes: 34 additions & 16 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,27 @@ end
function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link!!(default_transformation(model, vi), vi, vns, model)
end
function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
# has a dedicated implementation
ctx = DynamicTransformationContext{false}()
model = contextualize(model, setleafcontext(model.context, ctx))
vi = last(evaluate!!(model, vi))
return settrans!!(vi, t)
end
function link!!(
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
b = inverse(t.bijector)
x = vi[:]
y, logjac = with_logabsdet_jacobian(b, x)
# Set parameters and add the logjac term.
vi = unflatten(vi, y)
if hasacc(vi, Val(:LogJacobian))
vi = acclogjac!!(vi, logjac)
end
return settrans!!(vi, t)
end

"""
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
Expand All @@ -846,6 +867,9 @@ end
function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link(default_transformation(model, vi), vi, vns, model)
end
function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
return link!!(t, deepcopy(vi), model)
end

"""
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
Expand All @@ -866,23 +890,14 @@ end
function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink!!(default_transformation(model, vi), vi, vns, model)
end

# Vector-based ones.
function link!!(
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
b = inverse(t.bijector)
x = vi[:]
y, logjac = with_logabsdet_jacobian(b, x)

# Set parameters and add the logjac term.
vi = unflatten(vi, y)
if hasacc(vi, Val(:LogJacobian))
vi = acclogjac!!(vi, logjac)
end
return settrans!!(vi, t)
function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
# has a dedicated implementation
ctx = DynamicTransformationContext{true}()
model = contextualize(model, setleafcontext(model.context, ctx))
vi = last(evaluate!!(model, vi))
return settrans!!(vi, NoTransformation())
end

function invlink!!(
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
Expand Down Expand Up @@ -919,6 +934,9 @@ end
function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink(default_transformation(model, vi), vi, vns, model)
end
function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
return invlink!!(t, deepcopy(vi), model)
end

"""
maybe_invlink_before_eval!!([t::Transformation,] vi, model)
Expand Down
128 changes: 0 additions & 128 deletions src/context_implementations.jl

This file was deleted.

Loading
Loading