
Symbolic Gradients of Neural Networks #44

@AlCap23

Description


What kind of problems is it mostly used for? Please describe.

Sometimes an explicit form of the gradient of a neural network is required.

A specific but good example is the optimal experimental design of UDEs, where explicit sensitivity equations are needed to augment the system (ideally in functional form). Right now this is hardly possible (or only for very small nets), mostly because the resulting gradient expressions are very expensive to compile (at least on my machine, an M3 with 18 GB RAM).
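For context, these are the standard forward sensitivity equations: for a UDE $\dot u = f(u, p, t)$ whose right-hand side contains a neural network, the sensitivity $S = \partial u / \partial p$ is obtained by integrating the augmented system

```math
\dot{u} = f(u, p, t), \qquad \dot{S} = \frac{\partial f}{\partial u}\, S + \frac{\partial f}{\partial p}, \qquad S(t_0) = \frac{\partial u(t_0)}{\partial p},
```

so the network's Jacobians with respect to its input and its parameters enter $\partial f / \partial u$ and $\partial f / \partial p$ directly.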

I think this repository would be a good place for it :). Otherwise I would move it into a separate repository.

Describe the algorithm you’d like

Instead of deriving the gradient of the full chain in a single sweep, the symbolic augmentation of a simple Chain of Dense layers can be done easily, layer by layer, using the chain rule. In a sense, this provides the forward sensitivity of the full chain in a structured manner.
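Spelled out for a chain of $L$ Dense layers, the recursion looks as follows. The symbols $y_i$, $z_i$, $p_i$, $m_i$ (output dimension of layer $i$), $J_u$ and $J_p$ are notation introduced here; $p_i = [\operatorname{vec}(W_i); b_i]$ stacks the weights column-major followed by the bias, which is the ordering the `reshape` in `GradientLayer` assumes.

```math
\begin{aligned}
y_0 &= u, \qquad z_i = W_i\, y_{i-1} + b_i, \qquad y_i = \sigma(z_i), \qquad i = 1, \dots, L,\\
\frac{\partial y_i}{\partial y_{i-1}} &= \operatorname{diag}\big(\sigma'(z_i)\big)\, W_i, \qquad
\frac{\partial y_i}{\partial p_i} = \operatorname{diag}\big(\sigma'(z_i)\big) \begin{bmatrix} y_{i-1}^\top \otimes I_{m_i} & I_{m_i} \end{bmatrix},\\
J_u^{(1)} &= \frac{\partial y_1}{\partial y_0}, \qquad J_p^{(1)} = \frac{\partial y_1}{\partial p_1},\\
J_u^{(i)} &= \frac{\partial y_i}{\partial y_{i-1}}\, J_u^{(i-1)}, \qquad
J_p^{(i)} = \begin{bmatrix} \dfrac{\partial y_i}{\partial y_{i-1}}\, J_p^{(i-1)} & \dfrac{\partial y_i}{\partial p_i} \end{bmatrix}, \qquad i = 2, \dots, L.
\end{aligned}
```

Each layer only needs its own two small Jacobians; $J_u^{(L)} = \partial y_L / \partial u$ and $J_p^{(L)} = \partial y_L / \partial [p_1; \dots; p_L]$ are then assembled by matrix products, which corresponds to `∇u_i * du` and `hcat(∇u_i*dp, ∇p_i)` in the tuple method of `GradientLayer` below.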

An MWE I have been cooking up:

using Lux 
using Random
using ComponentArrays
using Symbolics
using Symbolics.SymbolicUtils.Code
using Symbolics.SymbolicUtils

# Just register all the activations from NNlib and their gradients here.
∇swish(x::T) where T = (1+exp(-x)+x*exp(-x))/(1+exp(-x))^2
myswish(x::T) where T =  x/(1+exp(-x))

@register_symbolic myswish(x::Real)::Real
@register_symbolic ∇swish(x::Real)::Real

Symbolics.derivative(::typeof(myswish), (x,)::Tuple, ::Base.Val{1}) = begin
    ∇swish(x)
end

model = Lux.Chain(
    Dense(7, 10, myswish), 
    Dense(10, 10, myswish), 
    Dense(10, 3, myswish)
)

p0, st = Lux.setup(Random.default_rng(), model)
p0 = ComponentArray(p0)

using MacroTools

# Maybe this is not needed; I added it to speed up compilation
function simplify_expression(val, us, ps)
    varmatcher = let psyms = toexpr.(ps), usyms = toexpr.(us)
        (x) -> begin
            if x ∈ usyms
                id = findfirst(==(x), usyms)
                return :(getindex(x, $(id)))
            end
            if x ∈ psyms
                id = findfirst(==(x), psyms)
                return :(getindex(p, $(id)))
            end
            return x
        end
    end    

    returns = [gensym() for i in eachindex(val)]
    body = Expr[]
    # Replace the symbolic inputs and parameters by getindex calls on x and p
    for i in eachindex(returns)
        subex = toexpr(val[i])
        subex = MacroTools.postwalk(varmatcher, subex)
        push!(body, 
            :($(returns[i]) = $(subex))
        )
    end
    # Return the right shape
    push!(
        body, 
        :(return reshape([$(returns...)], $(size(val)))::AbstractMatrix{promote_type(T, P)})
    );
    # Make the right signature
    :(function (x::AbstractVector{T},p::AbstractVector{P}) where {T, P}
        $(body...)
    end)
end 

struct GradientLayer{L, DU, DP} <: Lux.AbstractExplicitLayer
    layer::L
    du::DU 
    dp::DP
end

function GradientLayer(layer::Lux.Dense)
    (; in_dims, out_dims, activation) = layer
    p = Lux.LuxCore.initialparameters(Random.default_rng(), layer)
    # Make symbolic parameters 
    nparams = sum(prod ∘ size, p)
    ps = Symbolics.variables(gensym(),Base.OneTo(nparams))
    us = Symbolics.variables(gensym(), Base.OneTo(in_dims))
    W = reshape(ps[1:in_dims*out_dims], out_dims, in_dims)
    b = ps[(in_dims*out_dims+1):end]
    ex = activation.(W*us+b)
    dfdu = Symbolics.jacobian(ex, us)
    dfdp = Symbolics.jacobian(ex, ps)
    # Build the gradient w.r.t. to input and parameters
    dfduex = simplify_expression(dfdu, us, ps)
    dfdpex = simplify_expression(dfdp, us, ps)
    # Build the function 
    dfdu = eval(dfduex)
    dfdp = eval(dfdpex)
    return GradientLayer{typeof(layer), typeof(dfdu), typeof(dfdp)}(
        layer, dfdu, dfdp
    )
end

Lux.LuxCore.initialparameters(rng::Random.AbstractRNG, layer::GradientLayer) = Lux.LuxCore.initialparameters(rng, layer.layer)
Lux.LuxCore.initialstates(rng::Random.AbstractRNG, layer::GradientLayer) = Lux.LuxCore.initialstates(rng, layer.layer)

function (glayer::GradientLayer)(u::AbstractArray, ps, st::NamedTuple)
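    # Flatten as [vec(weight); bias]; works for both a NamedTuple and a ComponentVector
    pvec = ps isa AbstractVector ? ps : mapreduce(vec, vcat, values(ps))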
    ∇u_i = glayer.du(u, pvec)
    ∇p_i = glayer.dp(u, pvec)
    next, st = glayer.layer(u, ps, st)
    return (next, ∇u_i, ∇p_i), st
end

function (glayer::GradientLayer)((u, du, dp)::Tuple, ps, st::NamedTuple)
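    # Flatten as [vec(weight); bias]; works for both a NamedTuple and a ComponentVector
    pvec = ps isa AbstractVector ? ps : mapreduce(vec, vcat, values(ps))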
    ∇u_i = glayer.du(u, pvec) 
    ∇p_i = glayer.dp(u, pvec)
    next, st = glayer.layer(u, ps, st)
    # Chain rule: d(next)/du = ∇u_i * du and d(next)/dp = [∇u_i * dp  ∇p_i]
    # Note: this currently assumes a sequential chain; more advanced layers would probably need their own dispatch
    return (next, ∇u_i * du, hcat(∇u_i*dp, ∇p_i)), st
end

function symbolify(chain::Lux.Chain, p, st, name)
    new_layers = map(GradientLayer, chain.layers)
    new_chain = Lux.Chain(new_layers...; name)
    (new_chain, Lux.setup(Random.default_rng(), new_chain)...)
end

function symbolify(layer::Lux.Dense, p, st, name)
    new_layer = GradientLayer(layer)
    (new_layer, Lux.setup(Random.default_rng(), new_layer)...)
end

# In general, we can just `symbolify` a Chain or add a `GradientChain` constructor here.
newmodel, newps, newst = Lux.Experimental.layer_map(symbolify, model, p0, st);
u0 = rand(7)
ret, _ = newmodel(u0, newps, newst)
using InteractiveUtils  # provides @code_warntype outside the REPL
@code_warntype newmodel(u0, newps, newst)
using Zygote

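du_zyg = Zygote.jacobian(u -> first(model(u, newps, newst)), u0)
# Flatten the parameters so Zygote returns a single matrix with columns ordered layer by layer
dp_zyg = Zygote.jacobian(p -> first(model(u0, p, newst)), ComponentArray(newps))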

# `ret` is the triplet (y, dy/du, dy/dp); its Jacobians should match Zygote's
ret[2] ≈ only(du_zyg)
ret[3] ≈ only(dp_zyg)
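To sanity-check the layer-wise recursion independently of Symbolics and Lux, here is a minimal plain-Julia sketch that uses only the LinearAlgebra and Random standard libraries. It assumes swish activations on every layer and stores each layer's parameters as [vec(W); b], mirroring how `GradientLayer` reshapes its symbolic parameter vector. `DenseParams`, `forward_sensitivity`, `flatten_params`, `unflatten_params` and `fd_jacobian` are hypothetical helpers for illustration, not Lux API. The sketch propagates (y, dy/du, dy/dp) the same way as the tuple method of `GradientLayer` and compares the result against central finite differences.

```julia
using LinearAlgebra
using Random

# Swish and its derivative, same formulas as `myswish` / `∇swish` in the MWE
swish(x) = x / (1 + exp(-x))
dswish(x) = (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2

# Hypothetical plain-Julia stand-in for the parameters of one Dense layer
struct DenseParams
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Flatten as [vec(W_1); b_1; vec(W_2); b_2; ...] (column-major weights, then bias)
flatten_params(layers) = reduce(vcat, [vcat(vec(l.W), l.b) for l in layers])

# Inverse of `flatten_params`; `dims` holds (in_dims, out_dims) per layer
function unflatten_params(p, dims)
    layers = DenseParams[]
    k = 0
    for (n, m) in dims
        W = reshape(p[k+1:k+n*m], m, n); k += n * m
        b = p[k+1:k+m]; k += m
        push!(layers, DenseParams(W, b))
    end
    return layers
end

# Forward pass returning (y, dy/du, dy/dp) via the layer-wise chain rule
function forward_sensitivity(layers::Vector{DenseParams}, u::Vector{Float64})
    y = u
    Ju = Matrix{Float64}(I, length(u), length(u))  # dy_0/du is the identity
    Jp = zeros(length(u), 0)                       # dy_0/dp has no columns yet
    for l in layers
        m = length(l.b)
        Im = Matrix{Float64}(I, m, m)
        z = l.W * y .+ l.b
        D = Diagonal(dswish.(z))                        # diag(σ'(z_i))
        dy_dy = D * l.W                                 # ∂y_i/∂y_{i-1}
        dy_dp = D * hcat(kron(permutedims(y), Im), Im)  # ∂y_i/∂p_i, p_i = [vec(W_i); b_i]
        Ju = dy_dy * Ju
        Jp = hcat(dy_dy * Jp, dy_dp)
        y = swish.(z)                                   # y_{i-1} is only replaced after use
    end
    return y, Ju, Jp
end

# Central finite-difference Jacobian, used only as a reference
function fd_jacobian(f, x; h = 1e-6)
    J = zeros(length(f(x)), length(x))
    for j in eachindex(x)
        e = zeros(length(x)); e[j] = h
        J[:, j] = (f(x .+ e) .- f(x .- e)) ./ (2h)
    end
    return J
end

Random.seed!(42)
dims = [(7, 10), (10, 10), (10, 3)]  # same layer sizes as the Lux chain in the MWE
layers = [DenseParams(0.3 .* randn(m, n), 0.1 .* randn(m)) for (n, m) in dims]
u = randn(7)
p = flatten_params(layers)

y, Ju, Jp = forward_sensitivity(layers, u)
Ju_fd = fd_jacobian(x -> forward_sensitivity(layers, x)[1], u)
Jp_fd = fd_jacobian(q -> forward_sensitivity(unflatten_params(q, dims), u)[1], p)
```

Since these Jacobians are evaluated numerically per layer, this is not an explicit functional form; it is only meant as a cheap reference for checking the symbolically generated `du` / `dp` functions (here `Jp` has 3 rows and 80 + 110 + 33 = 223 columns).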

Other implementations to know about

I don't know of any.

References

Just for completeness, the preprint on the optimal experimental design.
