@@ -51,12 +51,12 @@ chain = multi_layer_feed_forward(2, 2)
5151
5252eqs = [connect (model. nn_in, nn. output)
5353 connect (model. nn_out, nn. input)]
54-
54+ eqs = [model . nn_in . u ~ nn . output . u, model . nn_out . u ~ nn . input . u]
5555ude_sys = complete (ODESystem (
5656 eqs, ModelingToolkit. t_nounits, systems = [model, nn],
5757 name = :ude_sys ))
5858
59- sys = structural_simplify (ude_sys)
59+ sys = structural_simplify (ude_sys, allow_symbolic = true )
6060
6161prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 1.0 ), [])
6262
@@ -103,7 +103,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
103103@test all (.! isnan .(∇l1))
104104@test ! iszero (∇l1)
105105
106- @test ∇l1≈ ∇l2 rtol= 1e-2
106+ @test ∇l1≈ ∇l2 rtol= 1e-3
107107@test ∇l1≈ ∇l3 rtol= 1e-5
108108
109109op = OptimizationProblem (of, x0, ps)
@@ -135,3 +135,34 @@ res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
135135# plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])
136136
137137@test SciMLBase. successful_retcode (res_sol)
138+
139+ function lotka_ude2 ()
140+ @variables t x (t)= 3.1 y (t)= 1.5 pred (t)[1 : 2 ]
141+ @parameters α= 1.3 [tunable = false ] δ= 1.8 [tunable = false ]
142+ chain = multi_layer_feed_forward (2 , 2 )
143+ NN, p = SymbolicNeuralNetwork (; chain, n_input= 2 , n_output= 2 , rng = StableRNG (42 ))
144+ Dt = ModelingToolkit. D_nounits
145+
146+ eqs = [
147+ pred ~ NN ([x, y], p)
148+ Dt (x) ~ α * x + pred[1 ]
149+ Dt (y) ~ - δ * y + pred[2 ]
150+ ]
151+ return ODESystem (eqs, ModelingToolkit. t_nounits, name = :lotka )
152+ end
153+
154+ sys2 = structural_simplify (lotka_ude2 ())
155+
156+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 1.0 ), [])
157+
158+ sol = solve (prob, Rodas5P (), abstol = 1e-10 , reltol = 1e-8 )
159+
160+ @test SciMLBase. successful_retcode (sol)
161+
162+ set_x2 = setp_oop (sys, sys. p)
163+ ps2 = (prob, sol_ref, get_vars, get_refs, set_x2);
164+ op2 = OptimizationProblem (of, x0, ps2)
165+
166+ res2 = solve (op2, Adam (), maxiters = 10000 )
167+
168+ @test res. u ≈ res2. u
0 commit comments