@@ -15,13 +15,13 @@ using DifferentiationInterface
1515using SciMLSensitivity
1616using Zygote: Zygote
1717using Statistics
18+ using Lux
1819
19- function lotka_ude ()
20+ function lotka_ude (chain )
2021 @variables t x (t)= 3.1 y (t)= 1.5
2122 @parameters α= 1.3 [tunable= false ] δ= 1.8 [tunable= false ]
2223 Dt = ModelingToolkit. D_nounits
2324
24- chain = multi_layer_feed_forward (2 , 2 )
2525 @named nn = NeuralNetworkBlock (2 , 2 ; chain, rng = StableRNG (42 ))
2626
2727 eqs = [
3636
3737function lotka_true ()
3838 @variables t x (t)= 3.1 y (t)= 1.5
39- @parameters α= 1.3 β= 0.9 γ= 0.8 δ= 1.8
39+ @parameters α= 1.3 [tunable = false ] β= 0.9 γ= 0.8 δ= 1.8 [tunable = false ]
4040 Dt = ModelingToolkit. D_nounits
4141
4242 eqs = [
4343 Dt (x) ~ α * x - β * x * y,
44- Dt (y) ~ - δ * y + δ * x * y
44+ Dt (y) ~ - δ * y + γ * x * y
4545 ]
4646 return System (eqs, ModelingToolkit. t_nounits, name = :lotka_true )
4747end
4848
49- ude_sys = lotka_ude ( )
49+ rbf (x) = exp .( - (x .^ 2 ) )
5050
51- sys = mtkcompile (ude_sys, allow_symbolic = true )
51+ chain = multi_layer_feed_forward (2 , 2 , width = 5 , initial_scaling_factor = 1 )
52+ ude_sys = lotka_ude (chain)
5253
53- prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 1.0 ))
54+ sys = mtkcompile (ude_sys)
55+
56+ @test length (equations (sys)) == 2
57+
58+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 5.0 ))
5459
5560model_true = mtkcompile (lotka_true ())
56- prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ))
57- sol_ref = solve (prob_true, Vern9 (), abstol = 1e-10 , reltol = 1e-8 )
61+ prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 5.0 ))
62+ sol_ref = solve (prob_true, Vern9 (), abstol = 1e-8 , reltol = 1e-8 )
63+
64+ ts = range (0 , 5.0 , length = 21 )
65+ data = reduce (hcat, sol_ref (ts, idxs = [model_true. x, model_true. y]). u)
5866
5967x0 = default_values (sys)[sys. nn. p]
6068
6169get_vars = getu (sys, [sys. x, sys. y])
62- get_refs = getu (model_true, [model_true. x, model_true. y])
63- set_x = setp_oop (sys, sys. nn. p)
70+ set_x = setsym_oop (sys, sys. nn. p)
6471
65- function loss (x, (prob, sol_ref, get_vars, get_refs, set_x))
66- new_p = set_x (prob, x)
67- new_prob = remake (prob, p = new_p, u0 = eltype (x).(prob. u0))
68- ts = sol_ref. t
69- new_sol = solve (new_prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 , saveat = ts)
72+ function loss (x, (prob, get_vars, data, ts, set_x))
73+ new_u0, new_p = set_x (prob, x)
74+ new_prob = remake (prob, p = new_p, u0 = new_u0)
75+ new_sol = solve (new_prob, Vern9 (), abstol = 1e-8 , reltol = 1e-8 , saveat = ts)
7076
7177 if SciMLBase. successful_retcode (new_sol)
72- mean (abs2 .(reduce (hcat, get_vars (new_sol)) .- reduce (hcat, get_refs (sol_ref)) ))
78+ mean (abs2 .(reduce (hcat, get_vars (new_sol)) .- data ))
7379 else
7480 Inf
7581 end
7682end
7783
7884of = OptimizationFunction {true} (loss, AutoZygote ())
7985
80- ps = (prob, sol_ref, get_vars, get_refs , set_x);
86+ ps = (prob, get_vars, data, ts , set_x);
8187
8288@test_call target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
8389@test_opt target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
@@ -89,7 +95,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
8995@test all (.! isnan .(∇l1))
9096@test ! iszero (∇l1)
9197
92- @test ∇l1≈ ∇l2 rtol= 1e-5
98+ @test ∇l1≈ ∇l2 rtol= 1e-4
9399@test ∇l1 ≈ ∇l3
94100
95101op = OptimizationProblem (of, x0, ps)
@@ -99,34 +105,36 @@ op = OptimizationProblem(of, x0, ps)
99105# oh = []
100106
101107# plot_cb = (opt_state, loss) -> begin
108+ # opt_state.iter % 500 ≠ 0 && return false
102109# @info "step $(opt_state.iter), loss: $loss"
103110# push!(oh, opt_state)
104111# new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u)
105112# new_prob = remake(prob, p = new_p)
106- # sol = solve(new_prob, Rodas4() )
113+ # sol = solve(new_prob, Vern9(), abstol = 1e-8, reltol = 1e-8 )
107114# display(plot(sol))
108115# false
109116# end
110117
111- res = solve (op, Adam (), maxiters = 10000 )# , callback = plot_cb)
118+ res = solve (op, Adam (1e-3 ), maxiters = 25_000 )# , callback = plot_cb)
112119
113120display (res. stats)
114- @test res. objective < 1
121+ @test res. objective < 1.5e-4
115122
116- res_p = set_x (prob, res. u)
117- res_prob = remake (prob, p = res_p)
118- res_sol = solve (res_prob, Vern9 ())
123+ u0, p = set_x (prob, res. u)
124+ res_prob = remake (prob; u0, p)
125+ res_sol = solve (res_prob, Vern9 (), abstol = 1e-8 , reltol = 1e-8 , saveat = ts)
126+
127+ @test SciMLBase. successful_retcode (res_sol)
128+ @test mean (abs2 .(reduce (hcat, get_vars (res_sol)) .- data)) ≈ res. objective
119129
120130# using Plots
121131# plot(sol_ref, idxs = [model_true.x, model_true.y])
122132# plot!(res_sol, idxs = [sys.x, sys.y])
123133
124- @test SciMLBase. successful_retcode (res_sol)
125-
126134function lotka_ude2 ()
127135 @variables t x (t)= 3.1 y (t)= 1.5 pred (t)[1 : 2 ]
128136 @parameters α= 1.3 [tunable= false ] δ= 1.8 [tunable= false ]
129- chain = multi_layer_feed_forward (2 , 2 )
137+ chain = multi_layer_feed_forward (2 , 2 ; width = 5 , initial_scaling_factor = 1 )
130138 NN, p = SymbolicNeuralNetwork (; chain, n_input = 2 , n_output = 2 , rng = StableRNG (42 ))
131139 Dt = ModelingToolkit. D_nounits
132140
@@ -138,16 +146,16 @@ end
138146
139147sys2 = mtkcompile (lotka_ude2 ())
140148
141- prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 1 .0 ))
149+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 5 .0 ))
142150
143151sol = solve (prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 )
144152
145153@test SciMLBase. successful_retcode (sol)
146154
147- set_x2 = setp_oop (sys2, sys2. p)
148- ps2 = (prob, sol_ref, get_vars, get_refs , set_x2);
155+ set_x2 = setsym_oop (sys2, sys2. p)
156+ ps2 = (prob, get_vars, data, ts , set_x2);
149157op2 = OptimizationProblem (of, x0, ps2)
150158
151- res2 = solve (op2, Adam (), maxiters = 10000 )
159+ res2 = solve (op2, Adam (1e-3 ), maxiters = 25_000 )
152160
153161@test res. u ≈ res2. u
0 commit comments