@@ -8,7 +8,7 @@ using SymbolicIndexingInterface
88using Optimization
99using OptimizationOptimisers: Adam
1010using SciMLStructures
11- using SciMLStructures: Tunable
11+ using SciMLStructures: Tunable, canonicalize
1212using ForwardDiff
1313using StableRNGs
1414
@@ -51,7 +51,7 @@ eqs = [connect(model.nn_in, nn.output)
5151
5252ude_sys = complete (ODESystem (
5353 eqs, ModelingToolkit. t_nounits, systems = [model, nn],
54- name = :ude_sys , defaults = [nn . input . u => [ 0.0 , 0.0 ]] ))
54+ name = :ude_sys ))
5555
5656sys = structural_simplify (ude_sys)
5757
@@ -61,13 +61,14 @@ model_true = structural_simplify(lotka_true())
6161prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ), [])
6262sol_ref = solve (prob_true, Rodas4 ())
6363
64- x0 = reduce (vcat, getindex .(( default_values (sys),), tunable_parameters (sys)))
64+ x0 = default_values (sys)[nn . p]
6565
6666get_vars = getu (sys, [sys. lotka. x, sys. lotka. y])
6767get_refs = getu (model_true, [model_true. x, model_true. y])
68+ set_x = setp_oop (sys, nn. p)
6869
69- function loss (x, (prob, sol_ref, get_vars, get_refs))
70- new_p = SciMLStructures . replace ( Tunable (), prob. p , x)
70+ function loss (x, (prob, sol_ref, get_vars, get_refs, set_x ))
71+ new_p = set_x ( prob, x)
7172 new_prob = remake (prob, p = new_p, u0 = eltype (x).(prob. u0))
7273 ts = sol_ref. t
7374 new_sol = solve (new_prob, Rodas4 (), saveat = ts)
8788
8889of = OptimizationFunction {true} (loss, AutoForwardDiff ())
8990
90- ps = (prob, sol_ref, get_vars, get_refs);
91+ ps = (prob, sol_ref, get_vars, get_refs, set_x );
9192
9293@test_call target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
9394@test_opt target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
9495
9596@test all (.! isnan .(ForwardDiff. gradient (Base. Fix2 (of, ps), x0)))
9697
97- op = OptimizationProblem (of, x0, (prob, sol_ref, get_vars, get_refs) )
98+ op = OptimizationProblem (of, x0, ps )
9899
99100# using Plots
100101
@@ -114,7 +115,7 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
114115
115116@test res. objective < 1
116117
117- res_p = SciMLStructures . replace ( Tunable (), prob. p , res. u)
118+ res_p = set_x ( prob, res. u)
118119res_prob = remake (prob, p = res_p)
119120res_sol = solve (res_prob, Rodas4 (), saveat = sol_ref. t)
120121
0 commit comments