@@ -51,12 +51,12 @@ chain = multi_layer_feed_forward(2, 2)
5151
5252eqs = [connect (model. nn_in, nn. output)
5353 connect (model. nn_out, nn. input)]
54-
54+ eqs = [model . nn_in . u ~ nn . output . u, model . nn_out . u ~ nn . input . u]
5555ude_sys = complete (ODESystem (
5656 eqs, ModelingToolkit. t_nounits, systems = [model, nn],
5757 name = :ude_sys ))
5858
59- sys = structural_simplify (ude_sys)
59+ sys = structural_simplify (ude_sys, allow_symbolic = true )
6060
6161prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 1.0 ), [])
6262
@@ -103,7 +103,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
103103@test all (.! isnan .(∇l1))
104104@test ! iszero (∇l1)
105105
106- @test ∇l1≈ ∇l2 rtol= 1e-2
106+ @test ∇l1≈ ∇l2 rtol= 1e-3
107107@test ∇l1≈ ∇l3 rtol= 1e-5
108108
109109op = OptimizationProblem (of, x0, ps)
@@ -135,3 +135,32 @@ res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
135135# plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])
136136
137137@test SciMLBase. successful_retcode (res_sol)
138+
139+ function lotka_ude2 ()
140+ @variables t x (t)= 3.1 y (t)= 1.5 pred (t)[1 : 2 ]
141+ @parameters α= 1.3 [tunable = false ] δ= 1.8 [tunable = false ]
142+ chain = multi_layer_feed_forward (2 , 2 )
143+ NN, p = SymbolicNeuralNetwork (; chain, n_input = 2 , n_output = 2 , rng = StableRNG (42 ))
144+ Dt = ModelingToolkit. D_nounits
145+
146+ eqs = [pred ~ NN ([x, y], p)
147+ Dt (x) ~ α * x + pred[1 ]
148+ Dt (y) ~ - δ * y + pred[2 ]]
149+ return ODESystem (eqs, ModelingToolkit. t_nounits, name = :lotka )
150+ end
151+
152+ sys2 = structural_simplify (lotka_ude2 ())
153+
154+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 1.0 ), [])
155+
156+ sol = solve (prob, Rodas5P (), abstol = 1e-10 , reltol = 1e-8 )
157+
158+ @test SciMLBase. successful_retcode (sol)
159+
160+ set_x2 = setp_oop (sys2, sys2. p)
161+ ps2 = (prob, sol_ref, get_vars, get_refs, set_x2);
162+ op2 = OptimizationProblem (of, x0, ps2)
163+
164+ res2 = solve (op2, Adam (), maxiters = 10000 )
165+
166+ @test res. u ≈ res2. u
0 commit comments