Skip to content

Commit 97473d0

Browse files
committed
Added accumulate_param_gradients! tests.
1 parent cb62fb5 commit 97473d0

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

test/modeling_library/switch.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,19 @@
192192
end
193193

194194
@testset "accumulate parameter gradients" begin
195-
tr = simulate(bam, (1, ))
196-
zero_param_grad!(bang1, :std)
197-
input_grads = accumulate_param_gradients!(tr, 1.0)
198-
tr = simulate(bam, (2, ))
199-
zero_param_grad!(fuzz1, :std)
200-
input_grads = accumulate_param_gradients!(tr, 1.0)
195+
for z in [1.0, 3.0, 5.0, 10.0]
196+
chm = choicemap((:z, z))
197+
tr, _ = generate(bam, (1, ), chm)
198+
zero_param_grad!(bang1, :std)
199+
input_grads = accumulate_param_gradients!(tr, 1.0)
200+
expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 3.0, 3.0)[3]
201+
@test isapprox(get_param_grad(bang1, :std), expected_std_grad)
202+
tr, _ = generate(bam, (2, ), chm)
203+
zero_param_grad!(fuzz1, :std)
204+
input_grads = accumulate_param_gradients!(tr, 1.0)
205+
expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)[3]
206+
@test isapprox(get_param_grad(fuzz1, :std), expected_std_grad)
207+
end
201208
end
202209

203210
# ------------ (More complex) hierarchy ------------ #

0 commit comments

Comments
 (0)