Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"

[compat]
Adapt = "3"
Expand All @@ -48,6 +49,7 @@ TerminalLoggers = "0.1"
Zygote = "0.5, 0.6"
ZygoteRules = "0.2"
julia = "1.5"
GenericLinearAlgebra = "0.2.5"

[extras]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Expand Down
28 changes: 28 additions & 0 deletions docs/src/examples/lagrangian_nn.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# One point test
using Flux, ReverseDiff, LagrangianNN
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be moved to test folder and included in runtests.jl


m, k, b = 1, 1, 1

X = rand(2,1)
Y = -k.*X[1]/m

g = Chain(Dense(2, 10, σ), Dense(10,1))
model = LagrangianNN(g)
params = model.params
re = model.re

# some toy loss function
function loss(x, y, p)
nn = x -> model(x,p)
out = sum((y .- (nn(x))).^2)
out
end
opt = ADAM(0.01)
epochs = 100

for epoch in 1:epochs
x, y = X, Y
gs = ReverseDiff.gradient(p -> loss(x, y, p), params)
Flux.Optimise.update!(opt, params, gs)
@show loss(x,y,params)
end
4 changes: 3 additions & 1 deletion src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module DiffEqFlux

using GalacticOptim, DataInterpolations, DiffEqBase, DiffResults, DiffEqSensitivity,
Distributions, ForwardDiff, Flux, Requires, Adapt, LinearAlgebra, RecursiveArrayTools,
StaticArrays, Base.Iterators, Printf, Zygote
StaticArrays, Base.Iterators, Printf, Zygote, GenericLinearAlgebra

using DistributionsAD
import ProgressLogging, ZygoteRules
Expand Down Expand Up @@ -82,11 +82,13 @@ include("tensor_product_basis.jl")
include("tensor_product_layer.jl")
include("collocation.jl")
include("hnn.jl")
include("lnn.jl")
include("multiple_shooting.jl")

export diffeq_fd, diffeq_rd, diffeq_adjoint
export DeterministicCNF, FFJORD, NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, NeuralODEMM, TensorLayer, AugmentedNDELayer, SplineLayer, NeuralHamiltonianDE
export HamiltonianNN
export LagrangianNN
export ChebyshevBasis, SinBasis, CosBasis, FourierBasis, LegendreBasis, PolynomialBasis
export neural_ode, neural_ode_rd
export neural_dmsde
Expand Down
42 changes: 42 additions & 0 deletions src/lnn.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Constructs a Lagrangian Neural Network [1].

References:
[1] Miles Cranmer, Sam Greydanus, Stephan Hoyer, Peter Battaglia, David Spergel, and Shirley Ho.Lagrangian Neural Networks.
InICLR 2020 Workshop on Integration of Deep Neural Modelsand Differential Equations, 2020.
"""

struct LagrangianNN
model
re
params

# Define inner constructor method
function LagrangianNN(model; p = nothing)
_p, re = Flux.destructure(model)
if p === nothing
p = _p
end
return new(model, re, p)
end
end

function (nn::LagrangianNN)(x, p = nn.params)
@assert size(x,1) % 2 === 0 # velocity df should be equal to coords degree of freedom
M = div(size(x,1), 2) # number of velocities degrees of freedom
re = nn.re
hess = x -> Zygote.hessian_reverse(x->sum(re(p)(x)), x) # we have to compute the whole hessian
hess = hess(x)[M+1:end, M+1:end] # takes only velocities
inv_hess = GenericLinearAlgebra.pinv(hess)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why pinv?


_grad_q = x -> Zygote.gradient(x->sum(re(p)(x)), x)[end]
_grad_q = _grad_q(x)[1:M,:] # take only coord derivatives
out1 =_grad_q

# Second term
_grad_qv = x -> Zygote.gradient(x->sum(re(p)(x)), x)[end]
_jac_qv = x -> Zygote.jacobian(x->_grad_qv(x), x)[end]
out2 = _jac_qv(x)[1:M,M+1:end] * x[M+1:end] # take only dqdq_dot derivatives

return inv_hess * (out1 .+ out2)
end