Skip to content

Commit dce003c

Browse files
authored
Merge pull request #448 from probcomp/mrb/unfold_gradient_default_args
Fix Unfold propagating gradients to default parameters.
2 parents 3cc93fd + 56115e3 commit dce003c

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
Manifest.toml
12
*.pdf
23
*.png
34
*.jld
@@ -7,4 +8,4 @@
78
*.log
89
docs/build/
910
docs/site/
10-
.DS_Store
11+
.DS_Store

src/modeling_library/unfold/backprop.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
@inline fold_sum(::Nothing, a::A) where A = a
55
@inline fold_sum(a::A, b::A) where A = a + b
66

7+
@inline _sum(::Nothing) = nothing
8+
@inline _sum(x::Vector) = sum(x)
9+
710
function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selection, retval_grad) where {T,U}
811
kernel_has_grads = has_argument_grads(trace.gen_fn.kernel)
912
if kernel_has_grads[1]
@@ -44,7 +47,7 @@ function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selecti
4447
end
4548
end
4649
end
47-
((nothing, kernel_arg_grads[2], params_grad...), value_choices, gradient_choices)
50+
((nothing, kernel_arg_grads[2], map(_sum, params_grad)...), value_choices, gradient_choices)
4851
end
4952

5053
function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_grad) where {T,U}
@@ -82,5 +85,5 @@ function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_
8285
end
8386
end
8487
end
85-
(nothing, kernel_arg_grads[2], params_grad...)
88+
(nothing, kernel_arg_grads[2], map(_sum, params_grad)...)
8689
end

test/modeling_library/unfold.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,13 @@ foo = Unfold(kernel)
506506

507507
zero_param_grad!(kernel, :std)
508508
input_grads = accumulate_param_gradients!(trace, nothing)
509+
expected_xs_grad = logpdf_grad(normal, x1, x_init * alpha + beta, std)[2] * x_init + logpdf_grad(normal, x2, x1 * alpha + beta, std)[2] * x1
510+
expected_ys_grad = logpdf_grad(normal, x1, x_init * alpha + beta, std)[2] + logpdf_grad(normal, x2, x1 * alpha + beta, std)[2]
511+
509512
@test input_grads[1] == nothing # length
510513
@test input_grads[2] == nothing # inital state
511-
#@test isapprox(input_grads[3], expected_xs_grad) # alpha
512-
#@test isapprox(input_grads[4], expected_ys_grad) # beta
514+
@test isapprox(input_grads[3], expected_xs_grad) # alpha
515+
@test isapprox(input_grads[4], expected_ys_grad) # beta
513516
expected_std_grad = (logpdf_grad(normal, x1, x_init * alpha + beta, std)[3]
514517
+ logpdf_grad(normal, x2, x1 * alpha + beta, std)[3])
515518
@test isapprox(get_param_grad(kernel, :std), expected_std_grad)

0 commit comments

Comments
 (0)