Skip to content

Commit c63d9ab

Browse files
committed
Add a test to check that init!! doesn't change linking
1 parent efef53e commit c63d9ab

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

src/contexts/init.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ function tilde_assume(
165165
# necessary.
166166
insert_transformed_value && settrans!!(vi, true, vn)
167167
# `accumulate_assume!!` wants untransformed values as the second argument.
168-
vi = accumulate_assume!!(vi, x, -logjac, vn, dist)
168+
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
169169
# We always return the untransformed value here, as that will determine
170170
# what the lhs of the tilde-statement is set to.
171171
return x, vi

test/contexts.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,11 +508,33 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
508508
end
509509
end
510510
end
511+
function test_link_status_respected(strategy::AbstractInitStrategy)
512+
@testset "check that varinfo linking is preserved: $(typeof(strategy))" begin
513+
@model logn() = a ~ LogNormal()
514+
model = logn()
515+
vi = VarInfo(model)
516+
linked_vi = DynamicPPL.link!!(vi, model)
517+
_, new_vi = DynamicPPL.init!!(model, linked_vi, strategy)
518+
@test DynamicPPL.istrans(new_vi)
519+
# this is the unlinked value, since it uses `getindex`
520+
a = new_vi[@varname(a)]
521+
# internal logjoint should correspond to the transformed value
522+
@test isapprox(
523+
DynamicPPL.getlogjoint_internal(new_vi), logpdf(Normal(), log(a))
524+
)
525+
# user logjoint should correspond to the transformed value
526+
@test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(LogNormal(), a))
527+
@test isapprox(
528+
only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a)
529+
)
530+
end
531+
end
511532

512533
@testset "PriorInit" begin
513534
test_generating_new_values(PriorInit())
514535
test_replacing_values(PriorInit())
515536
test_rng_respected(PriorInit())
537+
test_link_status_respected(PriorInit())
516538

517539
@testset "check that values are within support" begin
518540
# Not many other sensible checks we can do for priors.
@@ -529,6 +551,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
529551
test_generating_new_values(UniformInit())
530552
test_replacing_values(UniformInit())
531553
test_rng_respected(UniformInit())
554+
test_link_status_respected(UniformInit())
532555

533556
@testset "check that bounds are respected" begin
534557
@testset "unconstrained" begin
@@ -559,6 +582,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
559582
end
560583

561584
@testset "ParamsInit" begin
585+
test_link_status_respected(ParamsInit((; a=1.0)))
586+
test_link_status_respected(ParamsInit(Dict(@varname(a) => 1.0)))
587+
562588
@testset "given full set of parameters" begin
563589
# test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
564590
my_x, my_y = 1.0, [2.0, 3.0]

0 commit comments

Comments
 (0)