@@ -40,3 +40,68 @@ Random.seed!(100)
40
40
@test mean (vi[@varname (s)] for vi in chains) ≈ 1.8 atol = 0.1
41
41
end
42
42
43
+ @testset " Initial parameters" begin
44
+ # dummy algorithm that just returns initial value and does not perform any sampling
45
+ struct OnlyInitAlg end
46
+ function DynamicPPL. initialstep (
47
+ rng:: Random.AbstractRNG ,
48
+ model:: Model ,
49
+ :: Sampler{OnlyInitAlg} ,
50
+ vi:: AbstractVarInfo ;
51
+ kwargs... ,
52
+ )
53
+ return vi, nothing
54
+ end
55
+ DynamicPPL. getspace (:: OnlyInitAlg ) = ()
56
+
57
+ # model with one variable: initialization p = 0.2
58
+ @model function coinflip ()
59
+ p ~ Beta (1 , 1 )
60
+ 10 ~ Binomial (25 , p)
61
+ end
62
+ model = coinflip ()
63
+ sampler = Sampler (OnlyInitAlg ())
64
+ lptrue = logpdf (Binomial (25 , 0.2 ), 10 )
65
+ chain = sample (model, sampler, 1 ; init_params = 0.2 )
66
+ @test chain[1 ]. metadata. p. vals == [0.2 ]
67
+ @test getlogp (chain[1 ]) == lptrue
68
+
69
+ # parallel sampling
70
+ chains = sample (model, sampler, MCMCThreads (), 1 , 10 ; init_params = 0.2 )
71
+ for c in chains
72
+ @test c[1 ]. metadata. p. vals == [0.2 ]
73
+ @test getlogp (c[1 ]) == lptrue
74
+ end
75
+
76
+ # model with two variables: initialization s = 4, m = -1
77
+ @model function twovars ()
78
+ s ~ InverseGamma (2 , 3 )
79
+ m ~ Normal (0 , sqrt (s))
80
+ end
81
+ model = twovars ()
82
+ lptrue = logpdf (InverseGamma (2 , 3 ), 4 ) + logpdf (Normal (0 , 2 ), - 1 )
83
+ chain = sample (model, sampler, 1 ; init_params = [4 , - 1 ])
84
+ @test chain[1 ]. metadata. s. vals == [4 ]
85
+ @test chain[1 ]. metadata. m. vals == [- 1 ]
86
+ @test getlogp (chain[1 ]) == lptrue
87
+
88
+ # parallel sampling
89
+ chains = sample (model, sampler, MCMCThreads (), 1 , 10 ; init_params = [4 , - 1 ])
90
+ for c in chains
91
+ @test c[1 ]. metadata. s. vals == [4 ]
92
+ @test c[1 ]. metadata. m. vals == [- 1 ]
93
+ @test getlogp (c[1 ]) == lptrue
94
+ end
95
+
96
+ # set only m = -1
97
+ chain = sample (model, sampler, 1 ; init_params = [missing , - 1 ])
98
+ @test ! ismissing (chain[1 ]. metadata. s. vals[1 ])
99
+ @test chain[1 ]. metadata. m. vals == [- 1 ]
100
+
101
+ # parallel sampling
102
+ chains = sample (model, sampler, MCMCThreads (), 1 , 10 ; init_params = [missing , - 1 ])
103
+ for c in chains
104
+ @test ! ismissing (c[1 ]. metadata. s. vals[1 ])
105
+ @test c[1 ]. metadata. m. vals == [- 1 ]
106
+ end
107
+ end
0 commit comments