|
4 | 4 | @inline fold_sum(::Nothing, a::A) where A = a |
5 | 5 | @inline fold_sum(a::A, b::A) where A = a + b |
6 | 6 |
|
| 7 | +@inline _sum(::Nothing) = nothing |
| 8 | +@inline _sum(x::Vector) = sum(x) |
| 9 | + |
7 | 10 | function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selection, retval_grad) where {T,U} |
8 | 11 | kernel_has_grads = has_argument_grads(trace.gen_fn.kernel) |
9 | 12 | if kernel_has_grads[1] |
@@ -44,7 +47,7 @@ function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selecti |
44 | 47 | end |
45 | 48 | end |
46 | 49 | end |
47 | | - ((nothing, kernel_arg_grads[2], params_grad...), value_choices, gradient_choices) |
| 50 | + ((nothing, kernel_arg_grads[2], map(_sum, params_grad)...), value_choices, gradient_choices) |
48 | 51 | end |
49 | 52 |
|
50 | 53 | function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_grad) where {T,U} |
@@ -82,5 +85,5 @@ function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_ |
82 | 85 | end |
83 | 86 | end |
84 | 87 | end |
85 | | - (nothing, kernel_arg_grads[2], params_grad...) |
| 88 | + (nothing, kernel_arg_grads[2], map(_sum, params_grad)...) |
86 | 89 | end |
0 commit comments