@@ -508,11 +508,33 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
508508 end
509509 end
510510 end
511+ function test_link_status_respected(strategy:: AbstractInitStrategy )
512+ @testset " check that varinfo linking is preserved: $(typeof(strategy)) " begin
513+ @model logn() = a ~ LogNormal()
514+ model = logn()
515+ vi = VarInfo(model)
516+ linked_vi = DynamicPPL. link!!(vi, model)
517+ _, new_vi = DynamicPPL. init!!(model, linked_vi, strategy)
518+ @test DynamicPPL. istrans(new_vi)
519+ # this is the unlinked value, since it uses `getindex`
520+ a = new_vi[@varname(a)]
521+ # internal logjoint should correspond to the transformed value
522+ @test isapprox(
523+ DynamicPPL. getlogjoint_internal(new_vi), logpdf(Normal(), log(a))
524+ )
525+ # user logjoint should correspond to the transformed value
526+ @test isapprox(DynamicPPL. getlogjoint(new_vi), logpdf(LogNormal(), a))
527+ @test isapprox(
528+ only(DynamicPPL. getindex_internal(new_vi, @varname(a))), log(a)
529+ )
530+ end
531+ end
511532
512533 @testset " PriorInit" begin
513534 test_generating_new_values(PriorInit())
514535 test_replacing_values(PriorInit())
515536 test_rng_respected(PriorInit())
537+ test_link_status_respected(PriorInit())
516538
517539 @testset " check that values are within support" begin
518540 # Not many other sensible checks we can do for priors.
@@ -529,6 +551,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
529551 test_generating_new_values(UniformInit())
530552 test_replacing_values(UniformInit())
531553 test_rng_respected(UniformInit())
554+ test_link_status_respected(UniformInit())
532555
533556 @testset " check that bounds are respected" begin
534557 @testset " unconstrained" begin
@@ -559,6 +582,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
559582 end
560583
561584 @testset " ParamsInit" begin
585+ test_link_status_respected(ParamsInit((; a= 1.0 )))
586+ test_link_status_respected(ParamsInit(Dict(@varname(a) => 1.0 )))
587+
562588 @testset " given full set of parameters" begin
563589 # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
564590 my_x, my_y = 1.0 , [2.0 , 3.0 ]
0 commit comments