implement custom gradient with multi-argument functions #197

@KnutAM

Description

From a Slack comment by @koehlerson: how to implement a custom gradient calculation for a multi-argument function.
This case is common in autodiff, so it would be good to have a clear way of doing it.
The best solution I can come up with now is

using Tensors
import ForwardDiff: Dual

# General setup for any function f(x, args...)
struct Foo{F,T<:Tuple} <: Function # <:Function optional
    f::F
    args::T
end
struct FooGrad{FT<:Foo} <: Function # <: Function required
    foo::FT
end

function (foo::Foo)(x)
    println("Foo with Any: ", typeof(x))  # To show that it works
    return foo.f(x, foo.args...)
end
function (foo::Foo)(x::AbstractTensor{<:Any,<:Any,<:Dual})
    println("Foo with Dual: ", typeof(x))  # To show that it works
    return Tensors._propagate_gradient(FooGrad(foo), x)
end
function (fg::FooGrad)(x)
    println("FooGrad: ", typeof(x)) # To show that it works
    return f_dfdx(fg.foo.f, x, fg.foo.args...)
end

# Specific example: to set this up for bar(x, a, b), one must also define f_dfdx(::typeof(bar), x, a, b):
bar(x, a, b) = norm(a*x)^b
dbar_dx(x, a, b) = b*(a^b)*norm(x)^(b-2)*x # analytical gradient (valid for scalar a > 0)
f_dfdx(::typeof(bar), args...) = (bar(args...), dbar_dx(args...))

# At the location in the code where the derivative will be calculated
t = rand(SymmetricTensor{2,3}); a = π; b = 2 # Typically inputs
foo = Foo(bar, (a, b))
gradient(foo, t) == dbar_dx(t, a, b)
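As a sanity check (not part of the original snippet), the hand-derived `dbar_dx` can be compared against differentiating `bar` directly with `gradient`, which uses Dual numbers without any custom rule:

```julia
using Tensors

bar(x, a, b) = norm(a * x)^b
dbar_dx(x, a, b) = b * (a^b) * norm(x)^(b - 2) * x

t = rand(SymmetricTensor{2,3}); a = π; b = 2

# Forward-mode AD straight through bar, no custom gradient involved,
# should agree with the analytical expression up to roundoff.
gradient(x -> bar(x, a, b), t) ≈ dbar_dx(t, a, b)
```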

But this is quite cumbersome, especially when it is only needed for a single function, so a better method would be good.
(Tensors._propagate_gradient is renamed to propagate_gradient, exported, and documented in #181.)
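For a one-off case, the same dispatch trick can be written without the wrapper structs by defining a two-method function that closes over the extra arguments. This is a hypothetical shortcut relying on the same internal `Tensors._propagate_gradient`, not an existing API:

```julia
using Tensors
import ForwardDiff: Dual

bar(x, a, b) = norm(a * x)^b
dbar_dx(x, a, b) = b * (a^b) * norm(x)^(b - 2) * x

a = π; b = 2

# Two-method wrapper: plain tensors call bar directly, Dual-valued
# tensors divert into the custom (value, gradient) rule.
bar_x(x) = bar(x, a, b)
bar_x(x::AbstractTensor{<:Any,<:Any,<:Dual}) =
    Tensors._propagate_gradient(y -> (bar(y, a, b), dbar_dx(y, a, b)), x)

t = rand(SymmetricTensor{2,3})
gradient(bar_x, t) ≈ dbar_dx(t, a, b)
```

This avoids `Foo`/`FooGrad` entirely, at the cost of capturing `a` and `b` at definition time rather than passing them in a reusable wrapper.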
