Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4dc2a72
Remove selector stuff from varinfo tests
mhauru Jan 16, 2025
9b492a3
Implement link and invlink for varnames rather than samplers
mhauru Jan 16, 2025
b508f08
Replace set_retained_vns_del_by_spl! with set_retained_vns_del!
mhauru Jan 16, 2025
b8880d1
Make linking tests more extensive
mhauru Jan 16, 2025
99a8490
Remove sampler indexing from link methods (but not invlink)
mhauru Jan 22, 2025
4a79b1f
Remove indexing by samplers from invlink
mhauru Jan 22, 2025
26a1901
Merge remote-tracking branch 'origin/master' into mhauru/remove-selec…
mhauru Jan 22, 2025
090608b
Work towards removing sampler indexing with StaticTransformation
mhauru Jan 22, 2025
4749853
Fix invlink/link for TypedVarInfo and StaticTransformation
mhauru Jan 23, 2025
e960679
Fix a test in models.jl
mhauru Jan 23, 2025
d507a53
Move some functions to utils.jl, add tests and docstrings
mhauru Jan 23, 2025
41150b5
Fix a docstring typo
mhauru Jan 23, 2025
836fb13
Merge branch 'release-0.35' into mhauru/remove-selectors-linking
mhauru Jan 23, 2025
45d1f13
Various simplification to link/invlink
mhauru Jan 23, 2025
98915c2
Improve a docstring
mhauru Jan 23, 2025
f05068d
Style improvements
mhauru Jan 23, 2025
bc4c420
Fix broken link/invlink dispatch cascade for VectorVarInfo
mhauru Jan 23, 2025
71980ba
Fix some more broken dispatch cascades
mhauru Jan 23, 2025
45562a9
Apply suggestions from code review
mhauru Jan 24, 2025
db5b835
Remove comments that messed with docstrings
mhauru Jan 24, 2025
f99effe
Apply suggestions from code review
mhauru Jan 28, 2025
56194cd
Fix issues surfaced in code review
mhauru Jan 28, 2025
c187c49
Simplify link/invlink arguments
mhauru Jan 28, 2025
86b25c5
Fix a bug in unflatten VarNamedVector
mhauru Jan 28, 2025
2a6c1bc
Rename VarNameCollection -> VarNameTuple
mhauru Jan 28, 2025
853f47e
Remove test of a removed varname_namedtuple method
mhauru Jan 28, 2025
ed80328
Apply suggestions from code review
mhauru Jan 29, 2025
d996d0c
Respond to review feedback
mhauru Jan 29, 2025
2083148
Remove _default_sampler and a dead argument of maybe_invlink_before_eval
mhauru Jan 29, 2025
39fa647
Fix a typo in a comment
mhauru Jan 29, 2025
9df364f
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
2c73de5
Add HISTORY entry, fix one set_retained_vns_del! method
mhauru Jan 30, 2025
49604e1
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ set_num_produce!
increment_num_produce!
reset_num_produce!
setorder!
set_retained_vns_del_by_spl!
set_retained_vns_del!
```

```@docs
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ export AbstractVarInfo,
set_num_produce!,
reset_num_produce!,
increment_num_produce!,
set_retained_vns_del_by_spl!,
set_retained_vns_del!,
is_flagged,
set_flag!,
unset_flag!,
Expand Down
108 changes: 56 additions & 52 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -537,113 +537,116 @@
"""
function settrans!! end

# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback
# method for the case when no `vns` is provided, that would get all the keys from the
# `VarInfo`. Hence each subtype of `AbstractVarInfo` needs to implement separately the case
# where `vns` is provided and the one where it is not. This is because having separate
# implementations is typically much more performant, and because not all AbstractVarInfo
# types support partial linking.

"""
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model)
Transform the variables in `vi` to their linked space, using the transformation `t`,
mutating `vi` if possible.
Transform variables in `vi` to their linked space, mutating `vi` if possible.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Either transform all variables, or only ones specified in `vns`.
Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided.
See also: [`default_transformation`](@ref), [`invlink!!`](@ref).
"""
link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model)
function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link!!(t, vi, SampleFromPrior(), model)
function link!!(vi::AbstractVarInfo, model::Model)
return link!!(default_transformation(model, vi), vi, model)
end
function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Use `default_transformation` to decide which transformation to use if none is specified.
return link!!(default_transformation(model, vi), vi, spl, model)
function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link!!(default_transformation(model, vi), vi, vns, model)
end

"""
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model)
Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.
Transform variables in `vi` to their linked space without mutating `vi`.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Either transform all variables, or only ones specified in `vns`.
Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided.
See also: [`default_transformation`](@ref), [`invlink`](@ref).
"""
link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model)
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link(t, deepcopy(vi), SampleFromPrior(), model)
function link(vi::AbstractVarInfo, model::Model)
return link(default_transformation(model, vi), vi, model)
end
function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Use `default_transformation` to decide which transformation to use if none is specified.
return link(default_transformation(model, vi), deepcopy(vi), spl, model)
function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link(default_transformation(model, vi), vi, vns, model)

Check warning on line 584 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L583-L584

Added lines #L583 - L584 were not covered by tests
end

"""
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model)
Transform variables in `vi` to their constrained space, mutating `vi` if possible.
Transform the variables in `vi` to their constrained space, using the (inverse of)
transformation `t`, mutating `vi` if possible.
Either transform all variables, or only ones specified in `vns`.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is
not provided.
See also: [`default_transformation`](@ref), [`link!!`](@ref).
"""
invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model)
function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink!!(t, vi, SampleFromPrior(), model)
function invlink!!(vi::AbstractVarInfo, model::Model)
return invlink!!(default_transformation(model, vi), vi, model)
end
function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Here we extract the `transformation` from `vi` rather than using the default one.
return invlink!!(transformation(vi), vi, spl, model)
function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink!!(default_transformation(model, vi), vi, vns, model)
end

# Vector-based ones.
function link!!(
t::StaticTransformation{<:Bijectors.Transform},
vi::AbstractVarInfo,
spl::AbstractSampler,
model::Model,
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
b = inverse(t.bijector)
x = vi[spl]
x = vi[:]
y, logjac = with_logabsdet_jacobian(b, x)

lp_new = getlogp(vi) - logjac
vi_new = setlogp!!(unflatten(vi, spl, y), lp_new)
vi_new = setlogp!!(unflatten(vi, y), lp_new)
return settrans!!(vi_new, t)
end

function invlink!!(
t::StaticTransformation{<:Bijectors.Transform},
vi::AbstractVarInfo,
spl::AbstractSampler,
model::Model,
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
b = t.bijector
y = vi[spl]
y = vi[:]
x, logjac = with_logabsdet_jacobian(b, y)

lp_new = getlogp(vi) + logjac
vi_new = setlogp!!(unflatten(vi, spl, x), lp_new)
vi_new = setlogp!!(unflatten(vi, x), lp_new)
return settrans!!(vi_new, NoTransformation())
end

"""
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model)
Transform variables in `vi` to their constrained space without mutating `vi`.
Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of)
transformation `t`.
Either transform all variables, or only ones specified in `vns`.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is
not provided.
See also: [`default_transformation`](@ref), [`link`](@ref).
"""
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink(t, vi, SampleFromPrior(), model)
function invlink(vi::AbstractVarInfo, model::Model)
return invlink(default_transformation(model, vi), vi, model)
end
function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
return invlink(transformation(vi), vi, spl, model)
function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink(default_transformation(model, vi), vi, vns, model)

Check warning on line 649 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L648-L649

Added lines #L648 - L649 were not covered by tests
end

"""
Expand Down Expand Up @@ -715,9 +718,10 @@
return vi
end
function maybe_invlink_before_eval!!(
t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
t::StaticTransformation, vi::AbstractVarInfo, ::AbstractContext, model::Model
)
return invlink!!(t, vi, _default_sampler(context), model)
# TODO(mhauru) Why does this function need the context argument?
return invlink!!(t, vi, model)
end

function _default_sampler(context::AbstractContext)
Expand Down
6 changes: 2 additions & 4 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn
function link!!(
t::StaticTransformation{<:Bijectors.NamedTransform},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
::Model,
)
# TODO: Make sure that `spl` is respected.
b = inverse(t.bijector)
Expand All @@ -695,8 +694,7 @@ end
function invlink!!(
t::StaticTransformation{<:Bijectors.NamedTransform},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
::Model,
)
# TODO: Make sure that `spl` is respected.
b = t.bijector
Expand Down
48 changes: 16 additions & 32 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,59 +81,43 @@

islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl)

function link!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model)
function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...)

Check warning on line 85 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L84-L85

Added lines #L84 - L85 were not covered by tests
end

function invlink!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model)
function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...)

Check warning on line 89 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L88-L89

Added lines #L88 - L89 were not covered by tests
end

function link(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model)
function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...)

Check warning on line 93 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L92-L93

Added lines #L92 - L93 were not covered by tests
end

function invlink(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model)
function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...)

Check warning on line 97 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L96-L97

Added lines #L96 - L97 were not covered by tests
end

# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
# to define `getlogp(vi)`.
function link!!(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
end

function invlink!!(
::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return settrans!!(
last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
NoTransformation(),
)
end

function link(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return link!!(t, deepcopy(vi), spl, model)
function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return link!!(t, deepcopy(vi), model)
end

function invlink(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return invlink!!(t, deepcopy(vi), spl, model)
function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
return invlink!!(t, deepcopy(vi), model)
end

function maybe_invlink_before_eval!!(
Expand Down Expand Up @@ -182,8 +166,8 @@
return vector_getranges(vi.varinfo, vns)
end

function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
function set_retained_vns_del!(vi::ThreadSafeVarInfo, spl::Sampler)
return set_retained_vns_del!(vi.varinfo, spl)

Check warning on line 170 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L169-L170

Added lines #L169 - L170 were not covered by tests
end

isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)
Expand Down
20 changes: 6 additions & 14 deletions src/transforming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,29 +91,21 @@ function dot_tilde_assume(
return r, lp, vi
end

function link!!(
t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
end

function invlink!!(
::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
return settrans!!(
last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
NoTransformation(),
)
end

function link(
t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
return link!!(t, deepcopy(vi), spl, model)
function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
return link!!(t, deepcopy(vi), model)
end

function invlink(
t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
return invlink!!(t, deepcopy(vi), spl, model)
function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
return invlink!!(t, deepcopy(vi), model)
end
Loading
Loading