|
192 | 192 | end |
193 | 193 |
|
194 | 194 | @testset "accumulate parameter gradients" begin |
195 | | - tr = simulate(bam, (1, )) |
196 | | - zero_param_grad!(bang1, :std) |
197 | | - input_grads = accumulate_param_gradients!(tr, 1.0) |
198 | | - tr = simulate(bam, (2, )) |
199 | | - zero_param_grad!(fuzz1, :std) |
200 | | - input_grads = accumulate_param_gradients!(tr, 1.0) |
| 195 | + for z in [1.0, 3.0, 5.0, 10.0] |
| 196 | + chm = choicemap((:z, z)) |
| 197 | + tr, _ = generate(bam, (1, ), chm) |
| 198 | + zero_param_grad!(bang1, :std) |
| 199 | + input_grads = accumulate_param_gradients!(tr, 1.0) |
| 200 | + expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 3.0, 3.0)[3] |
| 201 | + @test isapprox(get_param_grad(bang1, :std), expected_std_grad) |
| 202 | + tr, _ = generate(bam, (2, ), chm) |
| 203 | + zero_param_grad!(fuzz1, :std) |
| 204 | + input_grads = accumulate_param_gradients!(tr, 1.0) |
| 205 | + expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)[3] |
| 206 | + @test isapprox(get_param_grad(fuzz1, :std), expected_std_grad) |
| 207 | + end |
201 | 208 | end |
202 | 209 |
|
203 | 210 | # ------------ (More complex) hierarchy ------------ # |
|
0 commit comments