Description
What kind of problems is it mostly used for? Please describe.
Sometimes, the explicit form of the gradient of a neural network is required.
A specific but good example is the optimal experimental design of UDEs, where explicit sensitivity equations are needed to augment the system (if possible in a functional form). Right now, this is hardly possible (or only for very small nets), mostly due to compilation issues with the resulting gradients (at least on my machine, an M3 with 18 GB RAM).
I think this would be a good place to store it :). Otherwise, I would move it into a separate repository.
Describe the algorithm you’d like
Instead of deriving the gradient of the full chain in a single sweep, the symbolic augmentation of a simple Chain consisting of Dense layers can be done very easily using the chain rule. In a sense, this would provide the forward sensitivity of a full chain in a structured manner.
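Concretely, writing the chain as u_i = f_i(u_{i-1}, p_i) for i = 1, …, L, each augmented layer only has to implement the recursion (a sketch of what the MWE below does)
∂u_i/∂u_0 = (∂f_i/∂u_{i-1}) · ∂u_{i-1}/∂u_0
∂u_i/∂(p_1, …, p_i) = [ (∂f_i/∂u_{i-1}) · ∂u_{i-1}/∂(p_1, …, p_{i-1}) , ∂f_i/∂p_i ],
i.e. every layer only needs its own local Jacobians with respect to its input and its parameters, and the full-chain sensitivities are assembled by matrix products and concatenation.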
An MWE I have been cooking up:
using Lux
using Random
using ComponentArrays
using Symbolics
using Symbolics.SymbolicUtils.Code
using Symbolics.SymbolicUtils
# Just register all the activations from NNlib and their gradients here.
∇swish(x::T) where T = (1+exp(-x)+x*exp(-x))/(1+exp(-x))^2
myswish(x::T) where T = x/(1+exp(-x))
@register_symbolic myswish(x::Real)::Real
@register_symbolic ∇swish(x::Real)::Real
Symbolics.derivative(::typeof(myswish), (x,)::Tuple, ::Base.Val{1}) = ∇swish(x)
model = Lux.Chain(
Dense(7, 10, myswish),
Dense(10, 10, myswish),
Dense(10, 3, myswish)
)
p0, st = Lux.setup(Random.default_rng(), model)
p0 = ComponentArray(p0)
using MacroTools
# Maybe this is not needed; I've added it to speed up compilation.
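# Builds (the expression of) a function (x, p) -> Matrix from the symbolic matrix `val`,
# replacing each symbolic state in `us` and parameter in `ps` with an indexed access into x and p.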
function simplify_expression(val, us, ps)
varmatcher = let psyms = toexpr.(ps), usyms = toexpr.(us)
(x) -> begin
if x ∈ usyms
id = findfirst(==(x), usyms)
return :(getindex(x, $(id)))
end
if x ∈ psyms
id = findfirst(==(x), psyms)
return :(getindex(p, $(id)))
end
return x
end
end
returns = [gensym() for i in eachindex(val)]
body = Expr[]
# Convert each symbolic entry to an expression and rewrite variables to indexed accesses into x and p
for i in eachindex(returns)
subex = toexpr(val[i])
subex = MacroTools.postwalk(varmatcher, subex)
push!(body,
:($(returns[i]) = $(subex))
)
end
# Return the right shape
push!(
body,
:(return reshape([$(returns...)], $(size(val)))::AbstractMatrix{promote_type(T, P)})
);
# Make the right signature
:(function (x::AbstractVector{T},p::AbstractVector{P}) where {T, P}
$(body...)
end)
end
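# Wraps a layer together with compiled functions for its Jacobians w.r.t. the input (du)
# and the flattened parameters (dp).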
struct GradientLayer{L, DU, DP} <: Lux.AbstractExplicitLayer
layer::L
du::DU
dp::DP
end
function GradientLayer(layer::Lux.Dense)
(; in_dims, out_dims, activation) = layer
p = Lux.LuxCore.initialparameters(Random.default_rng(), layer)
# Make symbolic parameters
nparams = sum(prod ∘ size, p)
ps = Symbolics.variables(gensym(),Base.OneTo(nparams))
us = Symbolics.variables(gensym(), Base.OneTo(in_dims))
W = reshape(ps[1:in_dims*out_dims], out_dims, in_dims)
b = ps[(in_dims*out_dims+1):end]
ex = activation.(W*us+b)
dfdu = Symbolics.jacobian(ex, us)
dfdp = Symbolics.jacobian(ex, ps)
# Build the gradient w.r.t. to input and parameters
dfduex = simplify_expression(dfdu, us, ps)
dfdpex = simplify_expression(dfdp, us, ps)
# Build the function
dfdu = eval(dfduex)
dfdp = eval(dfdpex)
return GradientLayer{typeof(layer), typeof(dfdu), typeof(dfdp)}(
layer, dfdu, dfdp
)
end
Lux.LuxCore.initialparameters(rng::Random.AbstractRNG, layer::GradientLayer) = Lux.LuxCore.initialparameters(rng, layer.layer)
Lux.LuxCore.initialstates(rng::Random.AbstractRNG, layer::GradientLayer) = Lux.LuxCore.initialstates(rng, layer.layer)
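# Entry point for the first layer: return the output together with the local Jacobians w.r.t. input and parameters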
function (glayer::GradientLayer)(u::AbstractArray, ps, st::NamedTuple)
# Flatten the Dense parameters in the same (column-major) order as the symbolic parameter vector
pvec = vcat(vec(ps.weight), vec(ps.bias))
∇u_i = glayer.du(u, pvec)
∇p_i = glayer.dp(u, pvec)
next, st = glayer.layer(u, ps, st)
return (next, ∇u_i, ∇p_i), st
end
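# Subsequent layers: propagate the accumulated sensitivities (du, dp) via the chain rule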
function (glayer::GradientLayer)((u, du, dp)::Tuple, ps, st::NamedTuple)
pvec = vcat(vec(ps.weight), vec(ps.bias))
∇u_i = glayer.du(u, pvec)
∇p_i = glayer.dp(u, pvec)
next, st = glayer.layer(u, ps, st)
# Note: this currently assumes a sequential chain. More advanced layers would probably need their own dispatch.
return (next, ∇u_i * du, hcat(∇u_i*dp, ∇p_i)), st
end
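# Mapping function for `Lux.Experimental.layer_map`: replace each layer with its gradient-augmented counterpart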
function symbolify(chain::Lux.Chain, p, st, name)
new_layers = map(GradientLayer, chain.layers)
new_chain = Lux.Chain(new_layers...; name)
(new_chain, Lux.setup(Random.default_rng(), new_chain)...)
end
function symbolify(layer::Lux.Dense, p, st, name)
new_layer = GradientLayer(layer)
(new_layer, Lux.setup(Random.default_rng(), new_layer)...)
end
# In general, we can just `symbolify` a Chain or add a `GradientChain` constructor here.
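# For the latter, a convenience constructor could look roughly like this (just a sketch
# mirroring `symbolify` above; the name `GradientChain` is only suggested by the comment):
GradientChain(chain::Lux.Chain) = Lux.Chain(map(GradientLayer, chain.layers)...)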
newmodel, newps, newst = Lux.Experimental.layer_map(symbolify, model, p0, st);
u0 = rand(7)
ret, _ = newmodel(u0, newps, newst)
@code_warntype newmodel(u0, newps, newst)
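# Compare against Jacobians computed with reverse-mode AD (Zygote) as a reference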
using Zygote
du_zyg = Zygote.jacobian(u->first(model(u, newps, newst)), u0)
dp_zyg = Zygote.jacobian(u->first(model(u0, u, newst)), newps)
# Returns a triplet (y, dy/du, dy/dp)
rets, _ = newmodel(u0, newps, newst);
Other implementations to know about
I don't know of any.
References
Just for completeness, the preprint on the optimal experimental design of UDEs.