@@ -435,12 +435,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
435
435
436
436
@testset " InitContext" begin
437
437
empty_varinfos = [
438
- VarInfo (),
439
- DynamicPPL. typed_varinfo (VarInfo ()),
440
- VarInfo (DynamicPPL. VarNamedVector ()),
441
- DynamicPPL. typed_vector_varinfo (DynamicPPL. typed_varinfo (VarInfo ())),
442
- SimpleVarInfo (),
443
- SimpleVarInfo (Dict {VarName,Any} ()),
438
+ (" untyped+metadata" , VarInfo ()),
439
+ (" typed+metadata" , DynamicPPL. typed_varinfo (VarInfo ())),
440
+ (" untyped+VNV" , VarInfo (DynamicPPL. VarNamedVector ())),
441
+ (
442
+ " typed+VNV" ,
443
+ DynamicPPL. typed_vector_varinfo (DynamicPPL. typed_varinfo (VarInfo ())),
444
+ ),
445
+ (" SVI+NamedTuple" , SimpleVarInfo ()),
446
+ (" Svi+Dict" , SimpleVarInfo (Dict {VarName,Any} ())),
444
447
]
445
448
446
449
@model function test_init_model ()
@@ -455,7 +458,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
455
458
# Check that init!! can generate values that weren't there
456
459
# previously.
457
460
model = test_init_model ()
458
- for empty_vi in empty_varinfos
461
+ @testset " $vi_name " for (vi_name, empty_vi) in empty_varinfos
459
462
this_vi = deepcopy (empty_vi)
460
463
_, vi = DynamicPPL. init!! (model, this_vi, strategy)
461
464
@test Set (keys (vi)) == Set ([@varname (x), @varname (y)])
@@ -475,7 +478,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
475
478
@testset " replacing old values: $(typeof (strategy)) " begin
476
479
# Check that init!! can overwrite values that were already there.
477
480
model = test_init_model ()
478
- for empty_vi in empty_varinfos
481
+ @testset " $vi_name " for (vi_name, empty_vi) in empty_varinfos
479
482
# start by generating some rubbish values
480
483
vi = deepcopy (empty_vi)
481
484
old_x, old_y = 100000.00 , [300000.00 , 500000.00 ]
@@ -494,7 +497,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
494
497
function test_rng_respected (strategy:: AbstractInitStrategy )
495
498
@testset " check that RNG is respected: $(typeof (strategy)) " begin
496
499
model = test_init_model ()
497
- for empty_vi in empty_varinfos
500
+ @testset " $vi_name " for (vi_name, empty_vi) in empty_varinfos
498
501
_, vi1 = DynamicPPL. init!! (
499
502
Xoshiro (468 ), model, deepcopy (empty_vi), strategy
500
503
)
@@ -613,29 +616,60 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
613
616
end
614
617
615
618
@testset " given only partial parameters" begin
616
- # In this case, we expect `ParamsInit` to use the value of x, and
617
- # generate a new value for y.
618
619
my_x = 1.0
619
620
params_nt = (; x= my_x)
620
621
params_dict = Dict (@varname (x) => my_x)
621
622
model = test_init_model ()
622
- for empty_vi in empty_varinfos
623
- _, vi = DynamicPPL. init!! (
624
- Xoshiro (468 ), model, deepcopy (empty_vi), ParamsInit (params_nt)
625
- )
626
- @test vi[@varname (x)] == my_x
627
- nt_y = vi[@varname (y)]
628
- @test nt_y isa AbstractVector{<: Real }
629
- @test length (nt_y) == 2
630
- _, vi = DynamicPPL. init!! (
631
- Xoshiro (469 ), model, deepcopy (empty_vi), ParamsInit (params_dict)
632
- )
633
- @test vi[@varname (x)] == my_x
634
- dict_y = vi[@varname (y)]
635
- @test dict_y isa AbstractVector{<: Real }
636
- @test length (dict_y) == 2
637
- # the values should be different since we used different seeds
638
- @test dict_y != nt_y
623
+ @testset " $vi_name " for (vi_name, empty_vi) in empty_varinfos
624
+ @testset " with PriorInit fallback" begin
625
+ _, vi = DynamicPPL. init!! (
626
+ Xoshiro (468 ),
627
+ model,
628
+ deepcopy (empty_vi),
629
+ ParamsInit (params_nt, PriorInit ()),
630
+ )
631
+ @test vi[@varname (x)] == my_x
632
+ nt_y = vi[@varname (y)]
633
+ @test nt_y isa AbstractVector{<: Real }
634
+ @test length (nt_y) == 2
635
+ _, vi = DynamicPPL. init!! (
636
+ Xoshiro (469 ),
637
+ model,
638
+ deepcopy (empty_vi),
639
+ ParamsInit (params_dict, PriorInit ()),
640
+ )
641
+ @test vi[@varname (x)] == my_x
642
+ dict_y = vi[@varname (y)]
643
+ @test dict_y isa AbstractVector{<: Real }
644
+ @test length (dict_y) == 2
645
+ # the values should be different since we used different seeds
646
+ @test dict_y != nt_y
647
+ end
648
+
649
+ @testset " with no fallback" begin
650
+ # These just don't have an entry for `y`.
651
+ @test_throws ErrorException DynamicPPL. init!! (
652
+ model, deepcopy (empty_vi), ParamsInit (params_nt, nothing )
653
+ )
654
+ @test_throws ErrorException DynamicPPL. init!! (
655
+ model, deepcopy (empty_vi), ParamsInit (params_dict, nothing )
656
+ )
657
+ # We also explicitly test the case where `y = missing`.
658
+ params_nt_missing = (; x= my_x, y= missing )
659
+ params_dict_missing = Dict (
660
+ @varname (x) => my_x, @varname (y) => missing
661
+ )
662
+ @test_throws ErrorException DynamicPPL. init!! (
663
+ model,
664
+ deepcopy (empty_vi),
665
+ ParamsInit (params_nt_missing, nothing ),
666
+ )
667
+ @test_throws ErrorException DynamicPPL. init!! (
668
+ model,
669
+ deepcopy (empty_vi),
670
+ ParamsInit (params_dict_missing, nothing ),
671
+ )
672
+ end
639
673
end
640
674
end
641
675
end
0 commit comments