Skip to content

Commit 56115e3

Browse files
committed
Added default params grad tests for Unfold.
1 parent 5e01c62 commit 56115e3

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

test/modeling_library/unfold.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,13 @@ foo = Unfold(kernel)
506506

507507
zero_param_grad!(kernel, :std)
508508
input_grads = accumulate_param_gradients!(trace, nothing)
509+
expected_xs_grad = logpdf_grad(normal, x1, x_init * alpha + beta, std)[2] * x_init + logpdf_grad(normal, x2, x1 * alpha + beta, std)[2] * x1
510+
expected_ys_grad = logpdf_grad(normal, x1, x_init * alpha + beta, std)[2] + logpdf_grad(normal, x2, x1 * alpha + beta, std)[2]
511+
509512
@test input_grads[1] == nothing # length
510513
@test input_grads[2] == nothing # inital state
511-
#@test isapprox(input_grads[3], expected_xs_grad) # alpha
512-
#@test isapprox(input_grads[4], expected_ys_grad) # beta
514+
@test isapprox(input_grads[3], expected_xs_grad) # alpha
515+
@test isapprox(input_grads[4], expected_ys_grad) # beta
513516
expected_std_grad = (logpdf_grad(normal, x1, x_init * alpha + beta, std)[3]
514517
+ logpdf_grad(normal, x2, x1 * alpha + beta, std)[3])
515518
@test isapprox(get_param_grad(kernel, :std), expected_std_grad)

0 commit comments

Comments
 (0)