Skip to content

Commit 419d485

Browse files
committed
Fix Unfold propagating gradients to default parameters.
1 parent 8e00d7f commit 419d485

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
Manifest.toml
12
*.pdf
23
*.png
34
*.jld
@@ -7,4 +8,4 @@
78
*.log
89
docs/build/
910
docs/site/
10-
.DS_Store
11+
.DS_Store

src/modeling_library/unfold/backprop.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
@inline fold_sum(::Nothing, a::A) where A = a
55
@inline fold_sum(a::A, b::A) where A = a + b
66

7+
@inline _sum(::Nothing) = nothing
8+
@inline _sum(x::Vector) = sum(x)
9+
710
function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selection, retval_grad) where {T,U}
811
kernel_has_grads = has_argument_grads(trace.gen_fn.kernel)
912
if kernel_has_grads[1]
@@ -44,7 +47,7 @@ function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selecti
4447
end
4548
end
4649
end
47-
((nothing, kernel_arg_grads[2], params_grad...), value_choices, gradient_choices)
50+
((nothing, kernel_arg_grads[2], map(_sum, params_grad)...), value_choices, gradient_choices)
4851
end
4952

5053
function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_grad) where {T,U}
@@ -82,5 +85,5 @@ function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_
8285
end
8386
end
8487
end
85-
(nothing, kernel_arg_grads[2], params_grad...)
88+
(nothing, kernel_arg_grads[2], map(_sum, params_grad)...)
8689
end

0 commit comments

Comments
 (0)