Skip to content

Commit c0921ae

Browse files
committed
Shuffle context code around and remove dead code
1 parent e37a6a6 commit c0921ae

File tree

11 files changed

+755
-719
lines changed

11 files changed

+755
-719
lines changed

src/DynamicPPL.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,11 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
175175
include("utils.jl")
176176
include("chains.jl")
177177
include("contexts.jl")
178+
include("contexts/default.jl")
178179
include("contexts/init.jl")
180+
include("contexts/transformation.jl")
181+
include("contexts/prefix.jl")
182+
include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl
179183
include("model.jl")
180184
include("sampler.jl")
181185
include("varname.jl")
@@ -188,10 +192,8 @@ include("abstract_varinfo.jl")
188192
include("threadsafe.jl")
189193
include("varinfo.jl")
190194
include("simple_varinfo.jl")
191-
include("context_implementations.jl")
192195
include("compiler.jl")
193196
include("pointwise_logdensities.jl")
194-
include("transforming.jl")
195197
include("logdensityfunction.jl")
196198
include("model_utils.jl")
197199
include("extract_priors.jl")

src/abstract_varinfo.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,27 @@ end
827827
function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
828828
return link!!(default_transformation(model, vi), vi, vns, model)
829829
end
830+
function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
831+
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
832+
# has a dedicated implementation
833+
ctx = DynamicTransformationContext{false}()
834+
model = contextualize(model, setleafcontext(model.context, ctx))
835+
vi = last(evaluate!!(model, vi))
836+
return settrans!!(vi, t)
837+
end
838+
function link!!(
839+
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
840+
)
841+
b = inverse(t.bijector)
842+
x = vi[:]
843+
y, logjac = with_logabsdet_jacobian(b, x)
844+
# Set parameters and add the logjac term.
845+
vi = unflatten(vi, y)
846+
if hasacc(vi, Val(:LogJacobian))
847+
vi = acclogjac!!(vi, logjac)
848+
end
849+
return settrans!!(vi, t)
850+
end
830851

831852
"""
832853
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
@@ -846,6 +867,9 @@ end
846867
function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
847868
return link(default_transformation(model, vi), vi, vns, model)
848869
end
870+
function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
871+
return link!!(t, deepcopy(vi), model)
872+
end
849873

850874
"""
851875
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
@@ -866,23 +890,14 @@ end
866890
function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
867891
return invlink!!(default_transformation(model, vi), vi, vns, model)
868892
end
869-
870-
# Vector-based ones.
871-
function link!!(
872-
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
873-
)
874-
b = inverse(t.bijector)
875-
x = vi[:]
876-
y, logjac = with_logabsdet_jacobian(b, x)
877-
878-
# Set parameters and add the logjac term.
879-
vi = unflatten(vi, y)
880-
if hasacc(vi, Val(:LogJacobian))
881-
vi = acclogjac!!(vi, logjac)
882-
end
883-
return settrans!!(vi, t)
893+
function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
894+
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
895+
# has a dedicated implementation
896+
ctx = DynamicTransformationContext{true}()
897+
model = contextualize(model, setleafcontext(model.context, ctx))
898+
vi = last(evaluate!!(model, vi))
899+
return settrans!!(vi, NoTransformation())
884900
end
885-
886901
function invlink!!(
887902
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
888903
)
@@ -919,6 +934,9 @@ end
919934
function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
920935
return invlink(default_transformation(model, vi), vi, vns, model)
921936
end
937+
function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
938+
return invlink!!(t, deepcopy(vi), model)
939+
end
922940

923941
"""
924942
maybe_invlink_before_eval!!([t::Transformation,] vi, model)

src/context_implementations.jl

Lines changed: 0 additions & 128 deletions
This file was deleted.

0 commit comments

Comments
 (0)