@@ -508,11 +508,29 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
508
508
end
509
509
end
510
510
end
511
+ function test_link_status_respected (strategy:: AbstractInitStrategy )
512
+ @testset " check that varinfo linking is preserved: $(typeof (strategy)) " begin
513
+ @model logn () = a ~ LogNormal ()
514
+ model = logn ()
515
+ vi = VarInfo (model)
516
+ linked_vi = DynamicPPL. link!! (vi, model)
517
+ _, new_vi = DynamicPPL. init!! (model, linked_vi, strategy)
518
+ @test DynamicPPL. istrans (new_vi)
519
+ # this is the unlinked value, since it uses `getindex`
520
+ a = new_vi[@varname (a)]
521
+ # logp should correspond to the transformed value
522
+ @test isapprox (DynamicPPL. getlogjoint (new_vi), logpdf (Normal (), log (a)))
523
+ @test isapprox (
524
+ only (DynamicPPL. getindex_internal (new_vi, @varname (a))), log (a)
525
+ )
526
+ end
527
+ end
511
528
512
529
@testset " PriorInit" begin
513
530
test_generating_new_values (PriorInit ())
514
531
test_replacing_values (PriorInit ())
515
532
test_rng_respected (PriorInit ())
533
+ test_link_status_respected (PriorInit ())
516
534
517
535
@testset " check that values are within support" begin
518
536
# Not many other sensible checks we can do for priors.
@@ -529,6 +547,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
529
547
test_generating_new_values (UniformInit ())
530
548
test_replacing_values (UniformInit ())
531
549
test_rng_respected (UniformInit ())
550
+ test_link_status_respected (UniformInit ())
532
551
533
552
@testset " check that bounds are respected" begin
534
553
@testset " unconstrained" begin
@@ -559,6 +578,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
559
578
end
560
579
561
580
@testset " ParamsInit" begin
581
+ test_link_status_respected (ParamsInit ((; a= 1.0 )))
582
+ test_link_status_respected (ParamsInit (Dict (@varname (a) => 1.0 )))
583
+
562
584
@testset " given full set of parameters" begin
563
585
# test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
564
586
my_x, my_y = 1.0 , [2.0 , 3.0 ]
0 commit comments