|
113 | 113 | end
|
114 | 114 | end
|
115 | 115 |
|
116 |
| - @testset "SampleFromPrior and SampleUniform" begin |
117 |
| - @model function gdemo(x, y) |
118 |
| - s ~ InverseGamma(2, 3) |
119 |
| - m ~ Normal(2.0, sqrt(s)) |
120 |
| - x ~ Normal(m, sqrt(s)) |
121 |
| - return y ~ Normal(m, sqrt(s)) |
122 |
| - end |
123 |
| - |
124 |
| - model = gdemo(1.0, 2.0) |
125 |
| - N = 1_000 |
126 |
| - |
127 |
| - chains = sample(model, SampleFromPrior(), N; progress=false) |
128 |
| - @test chains isa Vector{<:VarInfo} |
129 |
| - @test length(chains) == N |
130 |
| - |
131 |
| - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. |
132 |
| - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 |
133 |
| - |
134 |
| - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. |
135 |
| - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 |
136 |
| - |
137 |
| - chains = sample(model, SampleFromUniform(), N; progress=false) |
138 |
| - @test chains isa Vector{<:VarInfo} |
139 |
| - @test length(chains) == N |
140 |
| - |
141 |
| - # `m` is Gaussian, i.e. no transformation is used, so it |
142 |
| - # will be drawn from U[-2, 2] and its mean should be 0. |
143 |
| - @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 |
144 |
| - |
145 |
| - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. |
146 |
| - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 |
147 |
| - end |
148 |
| - |
149 |
| - @testset "init" begin |
150 |
| - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS |
151 |
| - N = 1000 |
152 |
| - chain_init = sample(model, SampleFromUniform(), N; progress=false) |
153 |
| - |
154 |
| - for vn in keys(first(chain_init)) |
155 |
| - if AbstractPPL.subsumes(@varname(s), vn) |
156 |
| - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. |
157 |
| - dist = InverseGamma(2, 3) |
158 |
| - b = DynamicPPL.link_transform(dist) |
159 |
| - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 |
160 |
| - elseif AbstractPPL.subsumes(@varname(m), vn) |
161 |
| - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. |
162 |
| - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 |
163 |
| - else |
164 |
| - error("Unknown variable name: $vn") |
165 |
| - end |
166 |
| - end |
167 |
| - end |
168 |
| - end |
169 |
| - |
170 | 116 | @testset "Initial parameters" begin
|
171 | 117 | # dummy algorithm that just returns initial value and does not perform any sampling
|
172 | 118 | abstract type OnlyInitAlg end
|
|
0 commit comments