4848
4949rbf (x) = exp .(- (x .^ 2 ))
5050
51- chain = Lux. Chain (
52- Lux. Dense (2 , 5 , rbf), Lux. Dense (5 , 5 , rbf), Lux. Dense (5 , 5 , rbf),
53- Lux. Dense (5 , 2 ))
51+ chain = multi_layer_feed_forward (2 , 2 , width = 5 , initial_scaling_factor = 1 )
5452ude_sys = lotka_ude (chain)
5553
56- sys = mtkcompile (ude_sys, allow_symbolic = true )
54+ sys = mtkcompile (ude_sys)
55+
56+ @test length (equations (sys)) == 2
5757
5858prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 5.0 ))
5959
6060model_true = mtkcompile (lotka_true ())
6161prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 5.0 ))
62- sol_ref = solve (prob_true, Vern9 (), abstol = 1e-12 , reltol = 1e-12 )
62+ sol_ref = solve (prob_true, Vern9 (), abstol = 1e-8 , reltol = 1e-8 )
6363
6464ts = range (0 , 5.0 , length = 21 )
6565data = reduce (hcat, sol_ref (ts, idxs = [model_true. x, model_true. y]). u)
@@ -69,11 +69,10 @@ x0 = default_values(sys)[sys.nn.p]
6969get_vars = getu (sys, [sys. x, sys. y])
7070set_x = setsym_oop (sys, sys. nn. p)
7171
72- function loss (x, (prob, sol_ref, get_vars, data, ts, set_x))
73- # new_u0, new_p = set_x(prob, 1, x)
72+ function loss (x, (prob, get_vars, data, ts, set_x))
7473 new_u0, new_p = set_x (prob, x)
7574 new_prob = remake (prob, p = new_p, u0 = new_u0)
76- new_sol = solve (new_prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 , saveat = ts)
75+ new_sol = solve (new_prob, Vern9 (), abstol = 1e-8 , reltol = 1e-8 , saveat = ts)
7776
7877 if SciMLBase. successful_retcode (new_sol)
7978 mean (abs2 .(reduce (hcat, get_vars (new_sol)) .- data))
8483
8584of = OptimizationFunction {true} (loss, AutoZygote ())
8685
87- ps = (prob, sol_ref, get_vars, data, ts, set_x);
86+ ps = (prob, get_vars, data, ts, set_x);
8887
8988@test_call target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
9089@test_opt target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
@@ -106,34 +105,36 @@ op = OptimizationProblem(of, x0, ps)
106105# oh = []
107106
108107# plot_cb = (opt_state, loss) -> begin
108+ # opt_state.iter % 500 ≠ 0 && return false
109109# @info "step $(opt_state.iter), loss: $loss"
110110# push!(oh, opt_state)
111111# new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u)
112112# new_prob = remake(prob, p = new_p)
113- # sol = solve(new_prob, Rodas4() )
113+ # sol = solve(new_prob, Vern9(), abstol = 1e-8, reltol = 1e-8 )
114114# display(plot(sol))
115115# false
116116# end
117117
118- res = solve (op, Adam (), maxiters = 10000 )# , callback = plot_cb)
118+ res = solve (op, Adam (1e-3 ), maxiters = 25_000 )# , callback = plot_cb)
119119
120120display (res. stats)
121- @test res. objective < 1
121+ @test res. objective < 1.5e-4
122+
123+ u0, p = set_x (prob, res. u)
124+ res_prob = remake (prob; u0, p)
125+ res_sol = solve (res_prob, Vern9 (), abstol = 1e-8 , reltol = 1e-8 , saveat = ts)
122126
123- res_p = set_x (prob, res. u)
124- res_prob = remake (prob, p = res_p)
125- res_sol = solve (res_prob, Vern9 ())
127+ @test SciMLBase. successful_retcode (res_sol)
128+ @test mean (abs2 .(reduce (hcat, get_vars (res_sol)) .- data)) ≈ res. objective
126129
127130# using Plots
128131# plot(sol_ref, idxs = [model_true.x, model_true.y])
129132# plot!(res_sol, idxs = [sys.x, sys.y])
130133
131- @test SciMLBase. successful_retcode (res_sol)
132-
133134function lotka_ude2 ()
134135 @variables t x (t)= 3.1 y (t)= 1.5 pred (t)[1 : 2 ]
135136 @parameters α= 1.3 [tunable = false ] δ= 1.8 [tunable = false ]
136- chain = multi_layer_feed_forward (2 , 2 )
137+ chain = multi_layer_feed_forward (2 , 2 ; width = 5 , initial_scaling_factor = 1 )
137138 NN, p = SymbolicNeuralNetwork (; chain, n_input = 2 , n_output = 2 , rng = StableRNG (42 ))
138139 Dt = ModelingToolkit. D_nounits
139140
@@ -145,16 +146,16 @@ end
145146
146147sys2 = mtkcompile (lotka_ude2 ())
147148
148- prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 1 .0 ))
149+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 5 .0 ))
149150
150151sol = solve (prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 )
151152
152153@test SciMLBase. successful_retcode (sol)
153154
154- set_x2 = setp_oop (sys2, sys2. p)
155- ps2 = (prob, sol_ref, get_vars, get_refs , set_x2);
155+ set_x2 = setsym_oop (sys2, sys2. p)
156+ ps2 = (prob, get_vars, data, ts , set_x2);
156157op2 = OptimizationProblem (of, x0, ps2)
157158
158- res2 = solve (op2, Adam (), maxiters = 10000 )
159+ res2 = solve (op2, Adam (1e-3 ), maxiters = 25_000 )
159160
160161@test res. u ≈ res2. u
0 commit comments