Skip to content

Commit e82dbb8

Browse files
committed
test: add tests for SymbolicNeuralNetwork
1 parent b182114 commit e82dbb8

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

test/lotka_volterra.jl

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ chain = multi_layer_feed_forward(2, 2)
5151

5252
eqs = [connect(model.nn_in, nn.output)
5353
connect(model.nn_out, nn.input)]
54-
54+
eqs = [model.nn_in.u ~ nn.output.u, model.nn_out.u ~ nn.input.u]
5555
ude_sys = complete(ODESystem(
5656
eqs, ModelingToolkit.t_nounits, systems = [model, nn],
5757
name = :ude_sys))
5858

59-
sys = structural_simplify(ude_sys)
59+
sys = structural_simplify(ude_sys, allow_symbolic = true)
6060

6161
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])
6262

@@ -103,7 +103,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
103103
@test all(.!isnan.(∇l1))
104104
@test !iszero(∇l1)
105105

106-
@test ∇l1∇l2 rtol=1e-2
106+
@test ∇l1∇l2 rtol=1e-3
107107
@test ∇l1∇l3 rtol=1e-5
108108

109109
op = OptimizationProblem(of, x0, ps)
@@ -135,3 +135,32 @@ res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
135135
# plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])
136136

137137
@test SciMLBase.successful_retcode(res_sol)
138+
139+
function lotka_ude2()
140+
@variables t x(t)=3.1 y(t)=1.5 pred(t)[1:2]
141+
@parameters α=1.3 [tunable = false] δ=1.8 [tunable = false]
142+
chain = multi_layer_feed_forward(2, 2)
143+
NN, p = SymbolicNeuralNetwork(; chain, n_input = 2, n_output = 2, rng = StableRNG(42))
144+
Dt = ModelingToolkit.D_nounits
145+
146+
eqs = [pred ~ NN([x, y], p)
147+
Dt(x) ~ α * x + pred[1]
148+
Dt(y) ~ -δ * y + pred[2]]
149+
return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka)
150+
end
151+
152+
sys2 = structural_simplify(lotka_ude2())
153+
154+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 1.0), [])
155+
156+
sol = solve(prob, Rodas5P(), abstol = 1e-10, reltol = 1e-8)
157+
158+
@test SciMLBase.successful_retcode(sol)
159+
160+
set_x2 = setp_oop(sys2, sys2.p)
161+
ps2 = (prob, sol_ref, get_vars, get_refs, set_x2);
162+
op2 = OptimizationProblem(of, x0, ps2)
163+
164+
res2 = solve(op2, Adam(), maxiters = 10000)
165+
166+
@test res.u res2.u

0 commit comments

Comments
 (0)