|
4 | 4 | @inline fold_sum(::Nothing, a::A) where A = a
|
5 | 5 | @inline fold_sum(a::A, b::A) where A = a + b
|
6 | 6 |
|
| 7 | +@inline _sum(::Nothing) = nothing |
| 8 | +@inline _sum(x::Vector) = sum(x) |
| 9 | + |
7 | 10 | function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selection, retval_grad) where {T,U}
|
8 | 11 | kernel_has_grads = has_argument_grads(trace.gen_fn.kernel)
|
9 | 12 | if kernel_has_grads[1]
|
@@ -44,7 +47,7 @@ function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selecti
|
44 | 47 | end
|
45 | 48 | end
|
46 | 49 | end
|
47 |
| - ((nothing, kernel_arg_grads[2], params_grad...), value_choices, gradient_choices) |
| 50 | + ((nothing, kernel_arg_grads[2], map(_sum, params_grad)...), value_choices, gradient_choices) |
48 | 51 | end
|
49 | 52 |
|
50 | 53 | function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_grad) where {T,U}
|
@@ -82,5 +85,5 @@ function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_
|
82 | 85 | end
|
83 | 86 | end
|
84 | 87 | end
|
85 |
| - (nothing, kernel_arg_grads[2], params_grad...) |
| 88 | + (nothing, kernel_arg_grads[2], map(_sum, params_grad)...) |
86 | 89 | end
|
0 commit comments