From a Slack comment by @koehlerson: how does one implement a custom gradient calculation for a multi-argument function?
Such cases are common with autodiff, so it would be good to have a clear way of doing this.
The solution I can come up with now is:
```julia
using Tensors
import ForwardDiff: Dual

# General setup for any function f(x, args...)
struct Foo{F,T<:Tuple} <: Function # <:Function optional
    f::F
    args::T
end

struct FooGrad{FT<:Foo} <: Function # <:Function required
    foo::FT
end

function (foo::Foo)(x)
    println("Foo with Any: ", typeof(x)) # To show that it works
    return foo.f(x, foo.args...)
end

function (foo::Foo)(x::AbstractTensor{<:Any,<:Any,<:Dual})
    println("Foo with Dual: ", typeof(x)) # To show that it works
    return Tensors._propagate_gradient(FooGrad(foo), x)
end

function (fg::FooGrad)(x)
    println("FooGrad: ", typeof(x)) # To show that it works
    return f_dfdx(fg.foo.f, x, fg.foo.args...)
end

# Specific example to set up for bar(x, a, b); one must then also define f_dfdx(::typeof(bar), x, a, b):
bar(x, a, b) = norm(a * x)^b
dbar_dx(x, a, b) = b * (a^b) * norm(x)^(b - 2) * x
f_dfdx(::typeof(bar), args...) = (bar(args...), dbar_dx(args...))

# At the location in the code where the derivative will be calculated
t = rand(SymmetricTensor{2,3}); a = π; b = 2 # Typically inputs
foo = Foo(bar, (a, b))
gradient(foo, t) == dbar_dx(t, a, b)
```

But it is quite cumbersome, especially if only needed for one function, so a better method would be good.
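To see the two-layer dispatch in isolation, here is a dependency-free sketch of the same pattern (no Tensors.jl, plain floats instead of tensors): `Foo` captures the trailing arguments so `gradient` only ever sees a single-argument callable, and `FooGrad` routes to the hand-written `f_dfdx`. The scalar `bar`/`dbar_dx` below are illustrative stand-ins, not the issue's tensor-valued example.

```julia
# Callable wrapper capturing f and its extra arguments
struct Foo{F,T<:Tuple} <: Function
    f::F
    args::T
end

# Companion wrapper that returns (value, analytic gradient) instead
struct FooGrad{FT<:Foo} <: Function
    foo::FT
end

(foo::Foo)(x) = foo.f(x, foo.args...)
(fg::FooGrad)(x) = f_dfdx(fg.foo.f, x, fg.foo.args...)

# Scalar stand-in example: bar(x, a, b) = (a*x)^b and its x-derivative
bar(x, a, b) = (a * x)^b
dbar_dx(x, a, b) = b * a * (a * x)^(b - 1)
f_dfdx(::typeof(bar), args...) = (bar(args...), dbar_dx(args...))

foo = Foo(bar, (2.0, 3))
foo(1.5)          # == bar(1.5, 2.0, 3) == 27.0
FooGrad(foo)(1.5) # == (27.0, 54.0), i.e. (value, derivative)
```

In the Tensors.jl version above, `_propagate_gradient` is the piece that consumes this `(value, gradient)` pair and inserts it into the dual numbers, so only `f_dfdx` needs to be defined per function.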
(Tensors._propagate_gradient is renamed to propagate_gradient, exported, and documented in #181)