diff --git a/test/lotka_volterra.jl b/test/lotka_volterra.jl index c61fbdf..f4e9096 100644 --- a/test/lotka_volterra.jl +++ b/test/lotka_volterra.jl @@ -15,13 +15,13 @@ using DifferentiationInterface using SciMLSensitivity using Zygote: Zygote using Statistics +using Lux -function lotka_ude() +function lotka_ude(chain) @variables t x(t)=3.1 y(t)=1.5 @parameters α=1.3 [tunable = false] δ=1.8 [tunable = false] Dt = ModelingToolkit.D_nounits - chain = multi_layer_feed_forward(2, 2) @named nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42)) eqs = [ @@ -36,40 +36,47 @@ end function lotka_true() @variables t x(t)=3.1 y(t)=1.5 - @parameters α=1.3 β=0.9 γ=0.8 δ=1.8 + @parameters α=1.3 [tunable = false] β=0.9 γ=0.8 δ=1.8 [tunable = false] Dt = ModelingToolkit.D_nounits eqs = [ Dt(x) ~ α * x - β * x * y, - Dt(y) ~ -δ * y + δ * x * y + Dt(y) ~ -δ * y + γ * x * y ] return System(eqs, ModelingToolkit.t_nounits, name = :lotka_true) end -ude_sys = lotka_ude() +rbf(x) = exp.(-(x .^ 2)) + +chain = Lux.Chain( + Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf), + Lux.Dense(5, 2)) +ude_sys = lotka_ude(chain) sys = mtkcompile(ude_sys, allow_symbolic = true) -prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0)) +prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 5.0)) model_true = mtkcompile(lotka_true()) -prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0)) -sol_ref = solve(prob_true, Vern9(), abstol = 1e-10, reltol = 1e-8) +prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 5.0)) +sol_ref = solve(prob_true, Vern9(), abstol = 1e-12, reltol = 1e-12) + +ts = range(0, 5.0, length = 21) +data = reduce(hcat, sol_ref(ts, idxs = [model_true.x, model_true.y]).u) x0 = default_values(sys)[sys.nn.p] get_vars = getu(sys, [sys.x, sys.y]) -get_refs = getu(model_true, [model_true.x, model_true.y]) -set_x = setp_oop(sys, sys.nn.p) +set_x = setsym_oop(sys, sys.nn.p) -function loss(x, (prob, sol_ref, get_vars, get_refs, set_x)) - new_p = set_x(prob, x) - new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0)) - ts = sol_ref.t +function loss(x, (prob, sol_ref, get_vars, data, ts, set_x)) + # new_u0, new_p = set_x(prob, 1, x) + new_u0, new_p = set_x(prob, x) + new_prob = remake(prob, p = new_p, u0 = new_u0) new_sol = solve(new_prob, Vern9(), abstol = 1e-10, reltol = 1e-8, saveat = ts) if SciMLBase.successful_retcode(new_sol) - mean(abs2.(reduce(hcat, get_vars(new_sol)) .- reduce(hcat, get_refs(sol_ref)))) + mean(abs2.(reduce(hcat, get_vars(new_sol)) .- data)) else Inf end @@ -77,7 +84,7 @@ end of = OptimizationFunction{true}(loss, AutoZygote()) -ps = (prob, sol_ref, get_vars, get_refs, set_x); +ps = (prob, sol_ref, get_vars, data, ts, set_x); @test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps) @test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps) @@ -89,7 +96,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x); @test all(.!isnan.(∇l1)) @test !iszero(∇l1) -@test ∇l1≈∇l2 rtol=1e-5 +@test ∇l1≈∇l2 rtol=1e-4 @test ∇l1 ≈ ∇l3 op = OptimizationProblem(of, x0, ps)