Skip to content

Commit d7fec52

Browse files
test: fix initialization, account for Initial parameters in test
1 parent 420299c commit d7fec52

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

test/lotka_volterra.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using SymbolicIndexingInterface
88
using Optimization
99
using OptimizationOptimisers: Adam
1010
using SciMLStructures
11-
using SciMLStructures: Tunable
11+
using SciMLStructures: Tunable, canonicalize
1212
using ForwardDiff
1313
using StableRNGs
1414

@@ -51,7 +51,7 @@ eqs = [connect(model.nn_in, nn.output)
5151

5252
ude_sys = complete(ODESystem(
5353
eqs, ModelingToolkit.t_nounits, systems = [model, nn],
54-
name = :ude_sys, defaults = [nn.input.u => [0.0, 0.0]]))
54+
name = :ude_sys))
5555

5656
sys = structural_simplify(ude_sys)
5757

@@ -61,13 +61,14 @@ model_true = structural_simplify(lotka_true())
6161
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), [])
6262
sol_ref = solve(prob_true, Rodas4())
6363

64-
x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys)))
64+
x0 = default_values(sys)[nn.p]
6565

6666
get_vars = getu(sys, [sys.lotka.x, sys.lotka.y])
6767
get_refs = getu(model_true, [model_true.x, model_true.y])
68+
set_x = setp_oop(sys, nn.p)
6869

69-
function loss(x, (prob, sol_ref, get_vars, get_refs))
70-
new_p = SciMLStructures.replace(Tunable(), prob.p, x)
70+
function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
71+
new_p = set_x(prob, x)
7172
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
7273
ts = sol_ref.t
7374
new_sol = solve(new_prob, Rodas4(), saveat = ts)
@@ -87,14 +88,14 @@ end
8788

8889
of = OptimizationFunction{true}(loss, AutoForwardDiff())
8990

90-
ps = (prob, sol_ref, get_vars, get_refs);
91+
ps = (prob, sol_ref, get_vars, get_refs, set_x);
9192

9293
@test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
9394
@test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
9495

9596
@test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0)))
9697

97-
op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
98+
op = OptimizationProblem(of, x0, ps)
9899

99100
# using Plots
100101

@@ -114,7 +115,7 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
114115

115116
@test res.objective < 1
116117

117-
res_p = SciMLStructures.replace(Tunable(), prob.p, res.u)
118+
res_p = set_x(prob, res.u)
118119
res_prob = remake(prob, p = res_p)
119120
res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
120121

0 commit comments

Comments
 (0)