Description
As stated in the title, `accumulate_param_gradients!` does not support a `scale_factor` argument for static generative functions. Calling `accumulate_param_gradients!` with a third argument raises `ERROR: Not implemented`, because dispatch falls back to the default definition in the abstract generative function interface (GFI).
This has two causes. (1) No generated method accepts a `scale_factor` argument; only the two-argument method below is generated:
Gen.jl/src/static_ir/backprop.jl
Lines 508 to 512 in e5ed96f
```julia
push!(generated_functions, quote
    @generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))}
        $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad)
    end
end)
```
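The dispatch failure can be reproduced without Gen in a minimal sketch. All names below (`ToyTrace`, `BrokenTrace`, `FixedTrace`, `accumulate_grads!`, `try_call`) are hypothetical stand-ins: the generic three-argument method plays the role of the abstract GFI fallback, and `FixedTrace` shows one possible shape of a fix, assuming the two-argument form should delegate with a scale factor of `1.0`.

```julia
# Toy reproduction of the dispatch gap. Every name here is hypothetical;
# none of it is Gen's real API.
abstract type ToyTrace end

# Generic fallback, playing the role of the abstract GFI definition.
accumulate_grads!(trace::ToyTrace, retval_grad, scale_factor) =
    error("Not implemented")

# Mirrors the current situation: only a two-argument method is defined.
struct BrokenTrace <: ToyTrace end
accumulate_grads!(trace::BrokenTrace, retval_grad) = :two_arg_only

# Possible fix (assumption): define the three-argument method and let the
# two-argument form delegate to it with a unit scale factor.
struct FixedTrace <: ToyTrace end
accumulate_grads!(trace::FixedTrace, retval_grad, scale_factor) =
    (:scaled, scale_factor)
accumulate_grads!(trace::FixedTrace, retval_grad) =
    accumulate_grads!(trace, retval_grad, 1.0)

# Returns the error message if the call throws, otherwise the call's result.
function try_call(f, args...)
    try
        f(args...)
    catch err
        sprint(showerror, err)
    end
end

println(try_call(accumulate_grads!, BrokenTrace(), nothing))       # two_arg_only
println(try_call(accumulate_grads!, BrokenTrace(), nothing, 0.5))  # Not implemented
println(try_call(accumulate_grads!, FixedTrace(), nothing, 0.5))   # (:scaled, 0.5)
println(try_call(accumulate_grads!, FixedTrace(), nothing))        # (:scaled, 1.0)
```

In the static IR, the analogous change would presumably be an additional generated method whose signature also takes `scale_factor`; that is an assumption about a possible patch, not confirmed Gen code.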
(2) The backward-pass code generator for trainable parameter nodes has no logic for a scale factor; it adds the unscaled gradient to the stored parameter gradient:
Gen.jl/src/static_ir/backprop.jl
Lines 169 to 185 in e5ed96f
```julia
function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode)
    # handle case when it is the return node
    if node === ir.return_node && node in fwd_marked
        @assert node in back_marked
        push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing")))
        push!(stmts, :($(gradient_var(node)) += retval_grad))
    end
    if node in fwd_marked && node in back_marked
        cur_param_grad = :($(QuoteNode(get_param_grad))(trace.$static_ir_gen_fn_ref,
                                                         $(QuoteNode(node.name))))
        push!(stmts, :($(QuoteNode(set_param_grad!))(trace.$static_ir_gen_fn_ref,
                                                     $(QuoteNode(node.name)),
                                                     $cur_param_grad + $(gradient_var(node)))))
    end
end
```
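A toy sketch of what scale-aware accumulation could look like, assuming the intended semantics is to add `scale_factor` times the local gradient to the stored parameter gradient. It imitates the expression-building style of `back_codegen!` (functions interpolated via `QuoteNode` into statements pushed onto `stmts`), but `ToyGenFn`, `get_grad`, `set_grad!`, `codegen_param_accum` and `accum_theta!` are hypothetical and not part of Gen.

```julia
# Toy sketch of scale-aware parameter-gradient accumulation.
# ToyGenFn, get_grad, set_grad!, codegen_param_accum and accum_theta! are
# hypothetical stand-ins, not Gen internals.
struct ToyGenFn
    param_grads::Dict{Symbol,Float64}
end

get_grad(gf::ToyGenFn, name::Symbol) = gf.param_grads[name]
set_grad!(gf::ToyGenFn, name::Symbol, g) = (gf.param_grads[name] = g)

# Emit one statement that adds `scale_factor * <grad_var>` to the stored
# gradient of parameter `name`, in the same style as back_codegen!.
function codegen_param_accum(name::Symbol, grad_var::Symbol)
    cur = :($(QuoteNode(get_grad))(gf, $(QuoteNode(name))))
    return :($(QuoteNode(set_grad!))(gf, $(QuoteNode(name)),
                                     $cur + scale_factor * $grad_var))
end

stmts = Expr[]
push!(stmts, codegen_param_accum(:theta, :theta_grad))

# Splice the generated statements into a method that takes scale_factor.
@eval function accum_theta!(gf, theta_grad, scale_factor)
    $(stmts...)
    return nothing
end

gf = ToyGenFn(Dict(:theta => 0.0))
accum_theta!(gf, 2.0, 1.0)   # unit scale adds 2.0
accum_theta!(gf, 2.0, 0.25)  # scaled call adds 0.25 * 2.0 = 0.5
println(gf.param_grads[:theta])  # 2.5
```

Making `scale_factor` an argument of the generated method keeps the emitted statement a single multiply-add per parameter, so the unscaled case is just `scale_factor = 1.0`.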