Skip to content

Commit 5da6d85

Browse files
committed
Add a test to check that init!! doesn't change linking
1 parent 044cb24 commit 5da6d85

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

test/contexts.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,11 +508,29 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
508508
end
509509
end
510510
end
511+
function test_link_status_respected(strategy::AbstractInitStrategy)
512+
@testset "check that varinfo linking is preserved: $(typeof(strategy))" begin
513+
@model logn() = a ~ LogNormal()
514+
model = logn()
515+
vi = VarInfo(model)
516+
linked_vi = DynamicPPL.link!!(vi, model)
517+
_, new_vi = DynamicPPL.init!!(model, linked_vi, strategy)
518+
@test DynamicPPL.istrans(new_vi)
519+
# this is the unlinked value, since it uses `getindex`
520+
a = new_vi[@varname(a)]
521+
# logp should correspond to the transformed value
522+
@test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(Normal(), log(a)))
523+
@test isapprox(
524+
only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a)
525+
)
526+
end
527+
end
511528

512529
@testset "PriorInit" begin
513530
test_generating_new_values(PriorInit())
514531
test_replacing_values(PriorInit())
515532
test_rng_respected(PriorInit())
533+
test_link_status_respected(PriorInit())
516534

517535
@testset "check that values are within support" begin
518536
# Not many other sensible checks we can do for priors.
@@ -529,6 +547,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
529547
test_generating_new_values(UniformInit())
530548
test_replacing_values(UniformInit())
531549
test_rng_respected(UniformInit())
550+
test_link_status_respected(UniformInit())
532551

533552
@testset "check that bounds are respected" begin
534553
@testset "unconstrained" begin
@@ -559,6 +578,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
559578
end
560579

561580
@testset "ParamsInit" begin
581+
test_link_status_respected(ParamsInit((; a=1.0)))
582+
test_link_status_respected(ParamsInit(Dict(@varname(a) => 1.0)))
583+
562584
@testset "given full set of parameters" begin
563585
# test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
564586
my_x, my_y = 1.0, [2.0, 3.0]

0 commit comments

Comments
 (0)