@@ -8,7 +8,8 @@ using SymbolicIndexingInterface
88using Optimization
99using OptimizationOptimisers: Adam
1010using SciMLStructures
11- using SciMLStructures: Tunable
11+ using SciMLStructures: Tunable, canonicalize
12+ using PreallocationTools
1213using ForwardDiff
1314using StableRNGs
1415
@@ -51,7 +52,7 @@ eqs = [connect(model.nn_in, nn.output)
5152
5253ude_sys = complete (ODESystem (
5354 eqs, ModelingToolkit. t_nounits, systems = [model, nn],
54- name = :ude_sys , defaults = [nn . input . u => [ 0.0 , 0.0 ]] ))
55+ name = :ude_sys ))
5556
5657sys = structural_simplify (ude_sys)
5758
@@ -61,13 +62,18 @@ model_true = structural_simplify(lotka_true())
6162prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ), [])
6263sol_ref = solve (prob_true, Rodas4 ())
6364
64- x0 = reduce (vcat, getindex .(( default_values (sys),), tunable_parameters (sys)))
65+ x0 = default_values (sys)[nn . p]
6566
6667get_vars = getu (sys, [sys. lotka. x, sys. lotka. y])
6768get_refs = getu (model_true, [model_true. x, model_true. y])
68-
69- function loss (x, (prob, sol_ref, get_vars, get_refs))
70- new_p = SciMLStructures. replace (Tunable (), prob. p, x)
69+ set_x = setu (sys, nn. p)
70+ diffcache = DiffCache (canonicalize (Tunable (), parameter_values (prob))[1 ])
71+
72+ function loss (x, (prob, sol_ref, get_vars, get_refs, set_x, diffcache))
73+ tunables = get_tmp (diffcache, x)
74+ copyto! (tunables, canonicalize (Tunable (), prob. p)[1 ])
75+ new_p = SciMLStructures. replace (Tunable (), prob. p, tunables)
76+ set_x (new_p, x)
7177 new_prob = remake (prob, p = new_p, u0 = eltype (x).(prob. u0))
7278 ts = sol_ref. t
7379 new_sol = solve (new_prob, Rodas4 (), saveat = ts)
8793
8894of = OptimizationFunction {true} (loss, AutoForwardDiff ())
8995
90- ps = (prob, sol_ref, get_vars, get_refs);
96+ ps = (prob, sol_ref, get_vars, get_refs, set_x, diffcache );
9197
9298@test_call target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
9399@test_opt target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
94100
95101@test all (.! isnan .(ForwardDiff. gradient (Base. Fix2 (of, ps), x0)))
96102
97- op = OptimizationProblem (of, x0, (prob, sol_ref, get_vars, get_refs) )
103+ op = OptimizationProblem (of, x0, ps )
98104
99105# using Plots
100106
@@ -114,7 +120,8 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
114120
115121@test res. objective < 1
116122
117- res_p = SciMLStructures. replace (Tunable (), prob. p, res. u)
123+ res_p = copy (prob. p)
124+ set_x (res_p, res. u)
118125res_prob = remake (prob, p = res_p)
119126res_sol = solve (res_prob, Rodas4 (), saveat = sol_ref. t)
120127
0 commit comments