@@ -508,11 +508,33 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
508
508
end
509
509
end
510
510
end
511
+ function test_link_status_respected (strategy:: AbstractInitStrategy )
512
+ @testset " check that varinfo linking is preserved: $(typeof (strategy)) " begin
513
+ @model logn () = a ~ LogNormal ()
514
+ model = logn ()
515
+ vi = VarInfo (model)
516
+ linked_vi = DynamicPPL. link!! (vi, model)
517
+ _, new_vi = DynamicPPL. init!! (model, linked_vi, strategy)
518
+ @test DynamicPPL. istrans (new_vi)
519
+ # this is the unlinked value, since it uses `getindex`
520
+ a = new_vi[@varname (a)]
521
+ # internal logjoint should correspond to the transformed value
522
+ @test isapprox (
523
+ DynamicPPL. getlogjoint_internal (new_vi), logpdf (Normal (), log (a))
524
+ )
525
+ # user logjoint should correspond to the transformed value
526
+ @test isapprox (DynamicPPL. getlogjoint (new_vi), logpdf (LogNormal (), a))
527
+ @test isapprox (
528
+ only (DynamicPPL. getindex_internal (new_vi, @varname (a))), log (a)
529
+ )
530
+ end
531
+ end
511
532
512
533
@testset " PriorInit" begin
513
534
test_generating_new_values (PriorInit ())
514
535
test_replacing_values (PriorInit ())
515
536
test_rng_respected (PriorInit ())
537
+ test_link_status_respected (PriorInit ())
516
538
517
539
@testset " check that values are within support" begin
518
540
# Not many other sensible checks we can do for priors.
@@ -529,6 +551,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
529
551
test_generating_new_values (UniformInit ())
530
552
test_replacing_values (UniformInit ())
531
553
test_rng_respected (UniformInit ())
554
+ test_link_status_respected (UniformInit ())
532
555
533
556
@testset " check that bounds are respected" begin
534
557
@testset " unconstrained" begin
@@ -559,6 +582,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
559
582
end
560
583
561
584
@testset " ParamsInit" begin
585
+ test_link_status_respected (ParamsInit ((; a= 1.0 )))
586
+ test_link_status_respected (ParamsInit (Dict (@varname (a) => 1.0 )))
587
+
562
588
@testset " given full set of parameters" begin
563
589
# test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
564
590
my_x, my_y = 1.0 , [2.0 , 3.0 ]
0 commit comments