diff --git a/Project.toml b/Project.toml
index 02b567b..cfcea2b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "1.6.1"
 
 [deps]
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
@@ -17,6 +18,7 @@ Aqua = "0.8"
 ComponentArrays = "0.15.11"
 DifferentiationInterface = "0.6"
 ForwardDiff = "0.10.36"
+IntervalSets = "0.7.10"
 JET = "0.8, 0.9"
 Lux = "1"
 LuxCore = "1"
@@ -31,7 +33,7 @@ SciMLSensitivity = "7.72"
 SciMLStructures = "1.1.0"
 StableRNGs = "1"
 SymbolicIndexingInterface = "0.3.15"
-Symbolics = "6.22"
+Symbolics = "6.36"
 Test = "1.10"
 Zygote = "0.6.73"
 julia = "1.10"
diff --git a/src/ModelingToolkitNeuralNets.jl b/src/ModelingToolkitNeuralNets.jl
index 36395cd..3ce8533 100644
--- a/src/ModelingToolkitNeuralNets.jl
+++ b/src/ModelingToolkitNeuralNets.jl
@@ -1,6 +1,7 @@
 module ModelingToolkitNeuralNets
 
 using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
+using IntervalSets: var".."
 using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
 using Symbolics: Symbolics, @register_array_symbolic, @wrapped
 using LuxCore: stateless_apply
@@ -8,7 +9,7 @@ using Lux: Lux
 using Random: Xoshiro
 using ComponentArrays: ComponentArray
 
-export NeuralNetworkBlock, multi_layer_feed_forward
+export NeuralNetworkBlock, SymbolicNeuralNetwork, multi_layer_feed_forward, get_network
 
 include("utils.jl")
 
@@ -32,16 +33,17 @@ function NeuralNetworkBlock(; n_input = 1, n_output = 1,
     @parameters p[1:length(ca)] = Vector(ca)
     @parameters T::typeof(typeof(ca))=typeof(ca) [tunable = false]
+    @parameters lux_model::typeof(chain) = chain
 
     @named input = RealInputArray(nin = n_input)
     @named output = RealOutputArray(nout = n_output)
 
-    out = stateless_apply(chain, input.u, lazyconvert(T, p))
+    out = stateless_apply(lux_model, input.u, lazyconvert(T, p))
 
     eqs = [output.u ~ out]
 
     ude_comp = ODESystem(
-        eqs, t_nounits, [], [p, T]; systems = [input, output], name)
+        eqs, t_nounits, [], [lux_model, p, T]; systems = [input, output], name)
     return ude_comp
 end
@@ -55,4 +57,74 @@ function lazyconvert(T, x::Symbolics.Arr)
     Symbolics.array_term(convert, T, x, size = size(x))
 end
 
+"""
+    SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
+        chain = multi_layer_feed_forward(n_input, n_output),
+        rng = Xoshiro(0),
+        init_params = Lux.initialparameters(rng, chain),
+        nn_name = :NN,
+        nn_p_name = :p,
+        eltype = Float64)
+
+Create a symbolic parameter representing a neural network and one representing its parameters.
+Example:
+
+```
+chain = multi_layer_feed_forward(2, 2)
+NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42))
+```
+
+`NN` and `p` are symbolic parameters that can later be used in the equations of a system.
+To change the names of the symbolic variables, use `nn_name` and `nn_p_name`.
+To get the predictions of the neural network, use
+
+```
+pred ~ NN(input, p)
+```
+
+where `pred` and `input` are symbolic vector variables of length `n_output` and `n_input`, respectively.
+
+To use this outside of an equation, you can get the default values for the symbols and make a similar call:
+
+```
+defaults(sys)[sys.NN](input, nn_p)
+```
+
+where `sys` is a system (e.g. `ODESystem`) that contains `NN`, `input` is a vector of length `n_input`, and
+`nn_p` is a vector of parameter values for the neural network.
+
+To get the underlying Lux model, use `get_network(defaults(sys)[sys.NN])`.
+"""
+function SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
+        chain = multi_layer_feed_forward(n_input, n_output),
+        rng = Xoshiro(0),
+        init_params = Lux.initialparameters(rng, chain),
+        nn_name = :NN,
+        nn_p_name = :p,
+        eltype = Float64)
+    ca = ComponentArray{eltype}(init_params)
+    wrapper = StatelessApplyWrapper(chain, typeof(ca))
+
+    p = @parameters $(nn_p_name)[1:length(ca)] = Vector(ca)
+    NN = @parameters ($(nn_name)::typeof(wrapper))(..)[1:n_output] = wrapper
+
+    return only(NN), only(p)
+end
+
+struct StatelessApplyWrapper{NN}
+    lux_model::NN
+    T::DataType
+end
+
+function (wrapper::StatelessApplyWrapper)(input::AbstractArray, nn_p::AbstractVector)
+    stateless_apply(get_network(wrapper), input, convert(wrapper.T, nn_p))
+end
+
+function Base.show(io::IO, m::MIME"text/plain", wrapper::StatelessApplyWrapper)
+    printstyled(io, "LuxCore.stateless_apply wrapper for:\n", color = :gray)
+    show(io, m, get_network(wrapper))
+end
+
+get_network(wrapper::StatelessApplyWrapper) = wrapper.lux_model
+
 end
diff --git a/test/lotka_volterra.jl b/test/lotka_volterra.jl
index 39e79bd..f85992b 100644
--- a/test/lotka_volterra.jl
+++ b/test/lotka_volterra.jl
@@ -51,12 +51,12 @@ chain = multi_layer_feed_forward(2, 2)
 
 eqs = [connect(model.nn_in, nn.output)
        connect(model.nn_out, nn.input)]
-
+eqs = [model.nn_in.u ~ nn.output.u, model.nn_out.u ~ nn.input.u]
 ude_sys = complete(ODESystem(
     eqs, ModelingToolkit.t_nounits, systems = [model, nn], name = :ude_sys))
 
-sys = structural_simplify(ude_sys)
+sys = structural_simplify(ude_sys, allow_symbolic = true)
 
 prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])
 
@@ -103,7 +103,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
 
 @test all(.!isnan.(∇l1))
 @test !iszero(∇l1)
-@test ∇l1≈∇l2 rtol=1e-2
+@test ∇l1≈∇l2 rtol=1e-3
 @test ∇l1≈∇l3 rtol=1e-5
 
 op = OptimizationProblem(of, x0, ps)
@@ -135,3 +135,32 @@ res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
 # plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])
 
 @test SciMLBase.successful_retcode(res_sol)
+
+function lotka_ude2()
+    @variables t x(t)=3.1 y(t)=1.5 pred(t)[1:2]
+    @parameters α=1.3 [tunable = false] δ=1.8 [tunable = false]
+    chain = multi_layer_feed_forward(2, 2)
+    NN, p = SymbolicNeuralNetwork(; chain, n_input = 2, n_output = 2, rng = StableRNG(42))
+    Dt = ModelingToolkit.D_nounits
+
+    eqs = [pred ~ NN([x, y], p)
+           Dt(x) ~ α * x + pred[1]
+           Dt(y) ~ -δ * y + pred[2]]
+    return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka)
+end
+
+sys2 = structural_simplify(lotka_ude2())
+
+prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 1.0), [])
+
+sol = solve(prob, Rodas5P(), abstol = 1e-10, reltol = 1e-8)
+
+@test SciMLBase.successful_retcode(sol)
+
+set_x2 = setp_oop(sys2, sys2.p)
+ps2 = (prob, sol_ref, get_vars, get_refs, set_x2);
+op2 = OptimizationProblem(of, x0, ps2)
+
+res2 = solve(op2, Adam(), maxiters = 10000)
+
+@test res.u ≈ res2.u
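A minimal end-to-end sketch of the workflow this diff enables, condensed from `test/lotka_volterra.jl` and the new docstring. The helper name `lotka_ude`, the sample input `[1.0, 2.0]`, and the solver settings are illustrative choices, not part of the change:

```julia
using ModelingToolkit, ModelingToolkitNeuralNets, SciMLBase
using OrdinaryDiffEq: Rodas5P, solve
using StableRNGs: StableRNG

function lotka_ude()
    @variables t x(t)=3.1 y(t)=1.5 pred(t)[1:2]
    @parameters α=1.3 [tunable = false] δ=1.8 [tunable = false]
    chain = multi_layer_feed_forward(2, 2)
    # NN is a callable symbolic parameter; p holds the flattened network weights.
    NN, p = SymbolicNeuralNetwork(; chain, n_input = 2, n_output = 2, rng = StableRNG(42))
    Dt = ModelingToolkit.D_nounits
    eqs = [pred ~ NN([x, y], p)       # network output used directly in the equations
           Dt(x) ~ α * x + pred[1]
           Dt(y) ~ -δ * y + pred[2]]
    return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka)
end

sys = structural_simplify(lotka_ude())
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0.0, 1.0), [])
sol = solve(prob, Rodas5P())

# Outside of equations, the default value of NN is the StatelessApplyWrapper defined
# above; it can be called like a function and unwrapped with get_network.
wrapper = ModelingToolkit.defaults(sys)[sys.NN]
wrapper([1.0, 2.0], prob.ps[sys.p])  # numeric forward pass through the network
lux_chain = get_network(wrapper)     # recover the underlying Lux chain
```

Registering the wrapper as a callable symbolic parameter (the `(..)` arity marker comes from `IntervalSets`, hence the new dependency) is what lets `NN` appear directly in equations instead of being wired in through connector blocks.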