@@ -51,12 +51,12 @@ chain = multi_layer_feed_forward(2, 2)
51
51
52
52
eqs = [connect (model. nn_in, nn. output)
53
53
connect (model. nn_out, nn. input)]
54
-
54
+ eqs = [model . nn_in . u ~ nn . output . u, model . nn_out . u ~ nn . input . u]
55
55
ude_sys = complete (ODESystem (
56
56
eqs, ModelingToolkit. t_nounits, systems = [model, nn],
57
57
name = :ude_sys ))
58
58
59
- sys = structural_simplify (ude_sys)
59
+ sys = structural_simplify (ude_sys, allow_symbolic = true )
60
60
61
61
prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 1.0 ), [])
62
62
@@ -103,7 +103,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
103
103
@test all (.! isnan .(∇l1))
104
104
@test ! iszero (∇l1)
105
105
106
- @test ∇l1≈ ∇l2 rtol= 1e-2
106
+ @test ∇l1≈ ∇l2 rtol= 1e-3
107
107
@test ∇l1≈ ∇l3 rtol= 1e-5
108
108
109
109
op = OptimizationProblem (of, x0, ps)
@@ -135,3 +135,32 @@ res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
135
135
# plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])
136
136
137
137
@test SciMLBase. successful_retcode (res_sol)
138
+
139
+ function lotka_ude2 ()
140
+ @variables t x (t)= 3.1 y (t)= 1.5 pred (t)[1 : 2 ]
141
+ @parameters α= 1.3 [tunable = false ] δ= 1.8 [tunable = false ]
142
+ chain = multi_layer_feed_forward (2 , 2 )
143
+ NN, p = SymbolicNeuralNetwork (; chain, n_input = 2 , n_output = 2 , rng = StableRNG (42 ))
144
+ Dt = ModelingToolkit. D_nounits
145
+
146
+ eqs = [pred ~ NN ([x, y], p)
147
+ Dt (x) ~ α * x + pred[1 ]
148
+ Dt (y) ~ - δ * y + pred[2 ]]
149
+ return ODESystem (eqs, ModelingToolkit. t_nounits, name = :lotka )
150
+ end
151
+
152
+ sys2 = structural_simplify (lotka_ude2 ())
153
+
154
+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 1.0 ), [])
155
+
156
+ sol = solve (prob, Rodas5P (), abstol = 1e-10 , reltol = 1e-8 )
157
+
158
+ @test SciMLBase. successful_retcode (sol)
159
+
160
+ set_x2 = setp_oop (sys2, sys2. p)
161
+ ps2 = (prob, sol_ref, get_vars, get_refs, set_x2);
162
+ op2 = OptimizationProblem (of, x0, ps2)
163
+
164
+ res2 = solve (op2, Adam (), maxiters = 10000 )
165
+
166
+ @test res. u ≈ res2. u
0 commit comments