Skip to content

Commit ba16e3b

Browse files
torfjeldeyebai
andauthored
Immutable versions of link and invlink (#525)
* added immutable versions of link and invlink * added explicit invlink implementation for VarInfo * remove false debug statement * fixed default impls of invlink for AbstractVarInfo * formatting * use x to refer to the constrained space in invlink impl * added immuatable link implementation for VarInfo * added threadsafe versions of link and invlink * added default implementations of link and invlink for DynamicTransformation * formatting * added tests for immutable link and invlink * export link and invlink * added link and invlink to docs * fixed setall! for UntypedVarInfo * added testing model demo_one_variable_multiple_constraints * fixed BangBang.setindex!! for setting vector in vector * added tests with unflatten + linking * fixed reference to logabsdetjac in TestUtils * improoved tests for unflatten + linking * improved testing of unflatten + linking a bit * added demo_lkjchol model to TestUtils * formatting * fixed impl of link for AbstractVarInfo * epxanded comment on BangBang hack * Apply suggestions from code review Co-authored-by: Hong Ge <[email protected]> * added references to BangBang issues and PRs talking about the `possible` overloads --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 866eb6f commit ba16e3b

File tree

10 files changed

+434
-11
lines changed

10 files changed

+434
-11
lines changed

docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ DynamicPPL.StaticTransformation
243243
DynamicPPL.istrans
244244
DynamicPPL.settrans!!
245245
DynamicPPL.transformation
246+
DynamicPPL.link
247+
DynamicPPL.invlink
246248
DynamicPPL.link!!
247249
DynamicPPL.invlink!!
248250
DynamicPPL.default_transformation

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ export AbstractVarInfo,
6060
updategid!,
6161
setorder!,
6262
istrans,
63+
link,
6364
link!,
6465
link!!,
66+
invlink,
6567
invlink!,
6668
invlink!!,
6769
tonamedtuple,

src/abstract_varinfo.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,8 @@ function settrans!! end
368368
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
369369
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
370370
371-
Transforms the variables in `vi` to their linked space, using the transformation `t`.
371+
Transform the variables in `vi` to their linked space, using the transformation `t`,
372+
mutating `vi` if possible.
372373
373374
If `t` is not provided, `default_transformation(model, vi)` will be used.
374375
@@ -383,12 +384,31 @@ function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
383384
return link!!(default_transformation(model, vi), vi, spl, model)
384385
end
385386

387+
"""
388+
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
389+
link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
390+
391+
Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.
392+
393+
If `t` is not provided, `default_transformation(model, vi)` will be used.
394+
395+
See also: [`default_transformation`](@ref), [`invlink`](@ref).
396+
"""
397+
link(vi::AbstractVarInfo, model::Model) = link(deepcopy(vi), SampleFromPrior(), model)
398+
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
399+
return link(t, deepcopy(vi), SampleFromPrior(), model)
400+
end
401+
function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
402+
# Use `default_transformation` to decide which transformation to use if none is specified.
403+
return link(default_transformation(model, vi), deepcopy(vi), spl, model)
404+
end
405+
386406
"""
387407
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
388408
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
389409
390410
Transform the variables in `vi` to their constrained space, using the (inverse of)
391-
transformation `t`.
411+
transformation `t`, mutating `vi` if possible.
392412
393413
If `t` is not provided, `default_transformation(model, vi)` will be used.
394414
@@ -434,6 +454,25 @@ function invlink!!(
434454
return settrans!!(vi_new, NoTransformation())
435455
end
436456

457+
"""
458+
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
459+
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
460+
461+
Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of)
462+
transformation `t`.
463+
464+
If `t` is not provided, `default_transformation(model, vi)` will be used.
465+
466+
See also: [`default_transformation`](@ref), [`link`](@ref).
467+
"""
468+
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
469+
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
470+
return invlink(t, vi, SampleFromPrior(), model)
471+
end
472+
function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
473+
return invlink(transformation(vi), vi, spl, model)
474+
end
475+
437476
"""
438477
maybe_invlink_before_eval!!([t::Transformation,] vi, context, model)
439478

src/test_utils.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,115 @@ function logprior_true_with_logabsdet_jacobian(
197197
return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp
198198
end
199199

200+
"""
201+
demo_one_variable_multiple_constraints()
202+
203+
A model with a single multivariate `x` whose components have multiple different constraints.
204+
205+
# Model
206+
```julia
207+
x[1] ~ Normal()
208+
x[2] ~ InverseGamma(2, 3)
209+
x[3] ~ truncated(Normal(), -5, 20)
210+
x[4:5] ~ Dirichlet([1.0, 2.0])
211+
```
212+
213+
"""
214+
@model function demo_one_variable_multiple_constraints(
215+
::Type{TV}=Vector{Float64}
216+
) where {TV}
217+
x = TV(undef, 5)
218+
x[1] ~ Normal()
219+
x[2] ~ InverseGamma(2, 3)
220+
x[3] ~ truncated(Normal(), -5, 20)
221+
x[4:5] ~ Dirichlet([1.0, 2.0])
222+
223+
return (x=x,)
224+
end
225+
226+
function logprior_true(model::Model{typeof(demo_one_variable_multiple_constraints)}, x)
227+
return (
228+
logpdf(Normal(), x[1]) +
229+
logpdf(InverseGamma(2, 3), x[2]) +
230+
logpdf(truncated(Normal(), -5, 20), x[3]) +
231+
logpdf(Dirichlet([1.0, 2.0]), x[4:5])
232+
)
233+
end
234+
function loglikelihood_true(model::Model{typeof(demo_one_variable_multiple_constraints)}, x)
235+
return zero(float(eltype(x)))
236+
end
237+
function varnames(model::Model{typeof(demo_one_variable_multiple_constraints)})
238+
return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4:5])]
239+
end
240+
function logprior_true_with_logabsdet_jacobian(
241+
model::Model{typeof(demo_one_variable_multiple_constraints)}, x
242+
)
243+
b_x2 = Bijectors.bijector(InverseGamma(2, 3))
244+
b_x3 = Bijectors.bijector(truncated(Normal(), -5, 20))
245+
b_x4 = Bijectors.bijector(Dirichlet([1.0, 2.0]))
246+
x_unconstrained = vcat(x[1], b_x2(x[2]), b_x3(x[3]), b_x4(x[4:5]))
247+
Δlogp = (
248+
Bijectors.logabsdetjac(b_x2, x[2]) +
249+
Bijectors.logabsdetjac(b_x3, x[3]) +
250+
Bijectors.logabsdetjac(b_x4, x[4:5])
251+
)
252+
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
253+
end
254+
255+
function Random.rand(
256+
rng::Random.AbstractRNG,
257+
::Type{NamedTuple},
258+
model::Model{typeof(demo_one_variable_multiple_constraints)},
259+
)
260+
x = Vector{Float64}(undef, 5)
261+
x[1] = rand(rng, Normal())
262+
x[2] = rand(rng, InverseGamma(2, 3))
263+
x[3] = rand(rng, truncated(Normal(), -5, 20))
264+
x[4:5] = rand(rng, Dirichlet([1.0, 2.0]))
265+
return (x=x,)
266+
end
267+
268+
"""
269+
demo_lkjchol(d=2)
270+
271+
A model with a single variable `x` with support on the Cholesky factor of a
272+
LKJ distribution.
273+
274+
# Model
275+
```julia
276+
x ~ LKJCholesky(d, 1.0)
277+
```
278+
"""
279+
@model function demo_lkjchol(d::Int=2)
280+
x ~ LKJCholesky(d, 1.0)
281+
return (x=x,)
282+
end
283+
284+
function logprior_true(model::Model{typeof(demo_lkjchol)}, x)
285+
return logpdf(LKJCholesky(model.args.d, 1.0), x)
286+
end
287+
288+
function loglikelihood_true(model::Model{typeof(demo_lkjchol)}, x)
289+
return zero(float(eltype(x)))
290+
end
291+
292+
function varnames(model::Model{typeof(demo_lkjchol)})
293+
return [@varname(x)]
294+
end
295+
296+
function logprior_true_with_logabsdet_jacobian(model::Model{typeof(demo_lkjchol)}, x)
297+
b_x = Bijectors.bijector(LKJCholesky(model.args.d, 1.0))
298+
x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x)
299+
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
300+
end
301+
302+
function Random.rand(
303+
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::Model{typeof(demo_lkjchol)}
304+
)
305+
x = rand(rng, LKJCholesky(model.args.d, 1.0))
306+
return (x=x,)
307+
end
308+
200309
# A collection of models for which the posterior should be "similar".
201310
# Some utility methods for these.
202311
function _demo_logprior_true_with_logabsdet_jacobian(model, s, m)

src/threadsafe.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ function invlink!!(
9393
return invlink!!(t, vi.varinfo, spl, model)
9494
end
9595

96+
function link(
97+
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
98+
)
99+
return link(t, vi.varinfo, spl, model)
100+
end
101+
102+
function invlink(
103+
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
104+
)
105+
return invlink(t, vi.varinfo, spl, model)
106+
end
107+
96108
function maybe_invlink_before_eval!!(
97109
vi::ThreadSafeVarInfo, context::AbstractContext, model::Model
98110
)

src/transforming.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,15 @@ function invlink!!(
9494
NoTransformation(),
9595
)
9696
end
97+
98+
function link(
99+
t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
100+
)
101+
return link!!(t, deepcopy(vi), spl, model)
102+
end
103+
104+
function invlink(
105+
t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
106+
)
107+
return invlink!!(t, deepcopy(vi), spl, model)
108+
end

src/utils.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,8 @@ function splitlens(condition, lens)
501501
return current_parent, current_child, condition(current_parent)
502502
end
503503

504+
# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233
505+
# and https://github.com/JuliaFolds/BangBang.jl/pull/238.
504506
# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`.
505507
function BangBang.possible(
506508
::typeof(BangBang._setindex!), ::C, ::T, ::Colon, ::Integer
@@ -514,6 +516,23 @@ function BangBang.possible(
514516
return BangBang.implements(setindex!, C) &&
515517
promote_type(eltype(C), eltype(T)) <: eltype(C)
516518
end
519+
# HACK: Makes it possible to use ranges, etc. for setting a vector.
520+
# For example, without this hack, BangBang.jl will consider
521+
#
522+
# x[1:2] = [1, 2]
523+
#
524+
# as NOT supported. This results is calling the immutable
525+
# `BangBang.setindex` instead, which also ends up expanding the
526+
# type of the containing array (`x` in the above scenario) to
527+
# have element type `Any`.
528+
# The below code just, correctly, marks this as possible and
529+
# thus we hit the mutable `setindex!` instead.
530+
function BangBang.possible(
531+
::typeof(BangBang._setindex!), ::C, ::T, ::AbstractVector{<:Integer}
532+
) where {C<:AbstractVector,T<:AbstractVector}
533+
return BangBang.implements(setindex!, C) &&
534+
promote_type(eltype(C), eltype(T)) <: eltype(C)
535+
end
517536

518537
# HACK(torfjelde): This makes it so it works on iterators, etc. by default.
519538
# TODO(torfjelde): Do better.

0 commit comments

Comments
 (0)