Skip to content

Commit 07ae12e

Browse files
committed
Fix accumulate_param_gradients! for Map and Unfold.
1 parent 614a10a commit 07ae12e

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

src/gen_fn_interface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,9 @@ If an argument is not annotated with `(grad)`, the corresponding value in
356356
Also increment the gradient accumulators for the trainable parameters \$Θ\$ of
357357
the function by:
358358
```math
359-
∇_Θ \\left( \\log P(t; x) + J \\right)
359+
s * ∇_Θ \\left( \\log P(t; x) + J \\right)
360360
```
361+
where \$s\$ is `scale_factor`.
361362
"""
362363
function accumulate_param_gradients!(trace, retgrad, scale_factor)
363364
error("Not implemented")

src/modeling_library/map/backprop.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function choice_gradients(trace::VectorTrace{MapType,T,U}, selection::Selection,
3535
((arg_grad...,), value_choices, gradient_choices)
3636
end
3737

38-
function accumulate_param_gradients!(trace::VectorTrace{MapType,T,U}, retval_grad) where {T,U}
38+
function accumulate_param_gradients!(trace::VectorTrace{MapType,T,U}, retval_grad, scale_factor) where {T,U}
3939

4040
args = get_args(trace)
4141
n_args = length(args)
@@ -54,7 +54,7 @@ function accumulate_param_gradients!(trace::VectorTrace{MapType,T,U}, retval_gra
5454
for key=1:len
5555
subtrace = trace.subtraces[key]
5656
kernel_retval_grad = (retval_grad == nothing) ? nothing : retval_grad[key]
57-
kernel_arg_grad::Tuple = accumulate_param_gradients!(subtrace, kernel_retval_grad)
57+
kernel_arg_grad::Tuple = accumulate_param_gradients!(subtrace, kernel_retval_grad, scale_factor)
5858
for (i, grad, has_grad) in zip(1:n_args, kernel_arg_grad, has_grads)
5959
if has_grad
6060
arg_grad[i][key] = grad

src/modeling_library/unfold/backprop.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selecti
5050
((nothing, kernel_arg_grads[2], map(_sum, params_grad)...), value_choices, gradient_choices)
5151
end
5252

53-
function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_grad) where {T,U}
53+
function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_grad, scale_factor) where {T,U}
5454
kernel_has_grads = has_argument_grads(trace.gen_fn.kernel)
5555
if kernel_has_grads[1]
5656
error("Cannot differentiate with respect to index in unfold")
@@ -76,7 +76,7 @@ function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_
7676
if state_has_grad
7777
kernel_retval_grad = fold_sum(kernel_retval_grad, kernel_arg_grads[2])
7878
end
79-
kernel_arg_grads = accumulate_param_gradients!(subtrace, kernel_retval_grad)
79+
kernel_arg_grads = accumulate_param_gradients!(subtrace, kernel_retval_grad, scale_factor)
8080
@assert kernel_arg_grads[1] == nothing
8181
state_has_grad || @assert kernel_arg_grads[2] == nothing
8282
for (i, (grad, has_grad)) in enumerate(zip(kernel_arg_grads[3:end], params_has_grad))

0 commit comments

Comments
 (0)