Description
After fitting a UDE defined using ModelingToolkitNeuralNets, it could be useful to look at what kind of function one actually has recovered (i.e. not just how the simulated dynamics compare to the data). Right now this is possible, but the interface is non-intuitive and could probably be improved.
Starting from the example here (code from the example in a single block):
# Define the ground-truth model.
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
@variables X(t) Y(t)
@parameters v=1.0 K=1.0 n=1.0 d=1.0
eqs = [D(X) ~ v * (Y^n) / (K^n + Y^n) - d*X
       D(Y) ~ X - d*Y]
@mtkcompile xy_model = System(eqs, t)
# Simulate the true model and plot the solution.
using OrdinaryDiffEqDefault, Plots
u0 = [X => 2.0, Y => 0.1]
ps_true = [v => 1.1, K => 2.0, n => 3.0, d => 0.5]
sim_cond = [u0; ps_true]
tend = 45.0
oprob_true = ODEProblem(xy_model, sim_cond, (0.0, tend))
sol_true = solve(oprob_true)
plot(sol_true; lw = 6, idxs = [X, Y])
# Generate noisy measurements of X and Y and add them to the plot.
sample_t = range(0.0, tend; length = 20)
sample_X = [(0.8 + 0.4rand()) * X_sample for X_sample in sol_true(sample_t; idxs = X)]
sample_Y = [(0.8 + 0.4rand()) * Y_sample for Y_sample in sol_true(sample_t; idxs = Y)]
plot!(sample_t, sample_X, seriestype = :scatter,
    label = "X (data)", color = 1, ms = 6, alpha = 0.7)
plot!(sample_t, sample_Y, seriestype = :scatter,
    label = "Y (data)", color = 2, ms = 6, alpha = 0.7)
# Neural network that will replace the unknown production term.
using Lux
nn_arch = Lux.Chain(
    Lux.Dense(1 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 3, Lux.softplus, use_bias = false),
    Lux.Dense(3 => 1, Lux.softplus, use_bias = false)
)
# Build the UDE model, replacing the Hill-function production term with the neural network.
using ModelingToolkitNeuralNets
sym_nn, θ = SymbolicNeuralNetwork(; nn_p_name = :θ, chain = nn_arch, n_input = 1, n_output = 1)
sym_nn_func(x) = sym_nn([x], θ)[1]
eqs_ude = [D(X) ~ sym_nn_func(Y) - d*X
           D(Y) ~ X - d*Y]
@mtkcompile xy_model_ude = System(eqs_ude, t)
# Loss function: simulate the UDE with the given parameters and compare to the sampled data.
function loss(ps, (oprob_base, set_ps, sample_t, sample_X, sample_Y))
    p = set_ps(oprob_base, ps)
    new_oprob = remake(oprob_base; p)
    new_osol = solve(new_oprob; saveat = sample_t, verbose = false, maxiters = 10000)
    SciMLBase.successful_retcode(new_osol) || return Inf # Simulation failed -> Inf loss.
    x_error = sum((x_sim - x_data)^2 for (x_sim, x_data) in zip(new_osol[X], sample_X))
    y_error = sum((y_sim - y_data)^2 for (y_sim, y_data) in zip(new_osol[Y], sample_Y))
    return x_error + y_error
end
# Set up the optimisation problem and fit d and the neural network parameters θ.
using Optimization
oprob_base = ODEProblem(xy_model_ude, u0, (0.0, tend))
set_ps = ModelingToolkit.setp_oop(oprob_base, [d, θ...])
loss_params = (oprob_base, set_ps, sample_t, sample_X, sample_Y)
ps_init = oprob_base.ps[[d, θ...]]
of = OptimizationFunction{true}(loss, AutoForwardDiff())
opt_prob = OptimizationProblem(of, ps_init, loss_params)
import OptimizationOptimisers: Adam
@time opt_sol = solve(opt_prob, Adam(0.01); maxiters = 10000)
# Simulate the fitted UDE and plot it on top of the data.
oprob_fitted = remake(oprob_base; p = set_ps(oprob_base, opt_sol.u))
sol_fitted = solve(oprob_fitted)
plot!(sol_fitted; lw = 4, la = 0.7, linestyle = :dash, idxs = [X, Y], color = [:blue :red],
    label = ["X (UDE)" "Y (UDE)"])

We can do the following to compare the function we recovered from the data to the true one:
# Defines Julia functions for the true and fitted functions.
true_func(y) = 1.1 * (y^3) / (2^3 + y^3)
fitted_func(y) = ModelingToolkit.getdefault(sym_nn)([y], oprob_fitted.ps[θ])[1]
# Plots the true and fitted functions (we mostly recovered the correct function, but it is less accurate in some regions).
plot(true_func, 0.0, 5.0; lw = 8, label = "True function", color = :lightblue)
plot!(fitted_func, 0.0, 5.0; lw = 6, label = "Fitted function", color = :blue, la = 0.7, linestyle = :dash)
In e.g. systems biology (whether one knows the true function or not), I would want to do this kind of comparison to see how (fitted) gene expression depends on transcription factors.
Currently, the fitted_func(y) = ModelingToolkit.getdefault(sym_nn)([y], oprob_fitted.ps[θ])[1] notation feels overly messy, and some helper function here could definitely be useful. A really simple one would just be something like:
function make_function(U, p_vals)
    # Return a callable that evaluates the network stored in U with the given parameter values.
    return function(x...)
        ModelingToolkit.getdefault(U)([x...], p_vals)
    end
end

Better versions to achieve this might exist, though. It would have been really cool if one could actually evaluate the symbolic neural network directly, but I think this by nature won't be possible (?). If you know what would make sense, I am also happy to help with a PR.
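For illustration, usage of such a helper could then look something like the sketch below (reusing sym_nn, θ, oprob_fitted, and true_func from the example above; since the helper as written returns the network's output vector, the single output is extracted with [1]):

# Rough usage sketch for the make_function helper proposed above.
fitted_nn = make_function(sym_nn, oprob_fitted.ps[θ])
fitted_func_simple(y) = fitted_nn(y)[1] # The network has a single output; extract it.
plot(true_func, 0.0, 5.0; lw = 8, label = "True function", color = :lightblue)
plot!(fitted_func_simple, 0.0, 5.0; lw = 6, label = "Fitted function", color = :blue, la = 0.7, linestyle = :dash)

The names fitted_nn and fitted_func_simple are just placeholders; the point is only that the getdefault plumbing stays inside the helper while the call site works with an ordinary Julia function.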