Ok, I think I have managed to get a fully working UDE workflow based on what Chris told me. The advantage of this one (as opposed to the one in the docs) is that it will be much easier to integrate with Catalyst.
I was wondering if someone who knows this stuff better could confirm that what I am doing makes sense. If it does, I will look at making this work in conjunction with Catalyst (it doesn't work directly, as we make some assumptions about types that mean the neural networks do not work, but this should be changeable).
# Fetch packages.
using Lux, ComponentArrays, ModelingToolkit, ModelingToolkitNeuralNets, OrdinaryDiffEq, Plots, StableRNGs
import LuxCore, SciMLBase # used below via LuxCore.stateless_apply and SciMLBase.successful_retcode
import ModelingToolkitNeuralNets: lazyconvert
import ModelingToolkit.t_nounits as t
import ModelingToolkit.D_nounits as D
import ModelingToolkit: getu, setp_oop
using OptimizationOptimisers: Adam
# Prepare the true model (X is produced as a function of itself and degraded linearly).
@variables X(t)
@parameters d
prod(x) = (x^2)/(0.5 + x^2)
eq = D(X) ~ prod(X) - d*X
@mtkbuild osys_true = ODESystem([eq], t)
# Generate the synthetic measurements (sample a simulation and add some noise; the plot comes later).
u0_true = [X => 0.5]
tend = 10.0
ps_true = [d => 0.5]
oprob_true = ODEProblem(osys_true, u0_true, tend, ps_true)
sol_true = solve(oprob_true)
sample_t = 0.0:0.1:tend
sample_x = sol_true(sample_t; idxs = X).u .* (0.9 .+ 0.2*rand(length(sample_t)))
# Create the Neural Network.
NN = Lux.Chain(
Lux.Dense(1 => 10, Lux.mish, use_bias = false),
Lux.Dense(10 => 10, Lux.mish, use_bias = false),
Lux.Dense(10 => 1, use_bias = false)
)
ps_nn_init, st_nn = Lux.setup(StableRNG(1234), NN)
ca_nn = ComponentArray{Float64}(ps_nn_init)
@parameters p_nn[1:length(ca_nn)] = Vector(ca_nn)
@parameters T_nn::typeof(typeof(ca_nn))=typeof(ca_nn) [tunable = false]
# Create the UDE model.
nn_func(x) = LuxCore.stateless_apply(NN, [x], lazyconvert(T_nn, p_nn))[1]
eq_nn = D(X) ~ nn_func(X) - d*X
@mtkbuild osys_nn = ODESystem([eq_nn], t, [X], [d, p_nn, T_nn])
# Create the optimization function (l2 distance between new simulation and data).
function loss(ps, (prob_base, sample_t, sample_x, set_ps, get_us))
p = set_ps(prob_base, ps)
new_prob = remake(prob_base; p)
new_sol = solve(new_prob; saveat = sample_t)
SciMLBase.successful_retcode(new_sol) || return Inf
return sum(abs2.(get_us(new_sol) .- sample_x))
end
of = OptimizationFunction{true}(loss, AutoForwardDiff())
# Prepares the optimisation problem.
ps_init = [1.0; ModelingToolkit.default_values(osys_nn)[osys_nn.p_nn]]
prob_base = ODEProblem(osys_nn, u0_true, tend, [d => ps_init[1]])
set_ps = setp_oop(osys_nn, [d; osys_nn.p_nn])
get_us = getu(osys_nn, X)
opt_prob = OptimizationProblem(of, ps_init, (prob_base, sample_t, sample_x, set_ps, get_us))
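Before training, it can be worth evaluating the loss once at the initial parameters as a sanity check (a small sketch using only the objects defined above; `l0` is a hypothetical name I'm introducing here):

```julia
# Evaluate the loss at the initial parameter vector. If the setters/getters
# and the base problem are wired up correctly, this should return a finite,
# positive number before any training has happened.
l0 = loss(ps_init, (prob_base, sample_t, sample_x, set_ps, get_us))
@assert isfinite(l0) && l0 > 0
```

This catches wiring mistakes (e.g. a mismatched parameter ordering in `setp_oop`) before spending time on the Adam iterations.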
# Fit the UDE to the measured data.
function callback(opt_state, loss)
(opt_state.iter % 100 == 0) && (@info "step $(opt_state.iter), loss: $loss")
return false
end
@time opt_sol = solve(opt_prob, Adam(5e-3); maxiters = 5000, callback)
# Simulates the fitted solution.
fitted_ps = set_ps(prob_base, opt_sol.u)
fitted_prob = remake(prob_base, p = fitted_ps)
fitted_sol = solve(fitted_prob)
init_sol = solve(prob_base)
# Evaluate the fit (plots true solution, data, and the pre/post fitted NN model simulation).
plot(sol_true; label = "True solution", lw = 5)
plot!(sample_t, sample_x; label = "Measured values", seriestype = :scatter, color = 1, ms = 6, markeralpha = 0.6)
plot!(fitted_sol, lw = 4, la = 0.9, linestyle = :dash, label = "Trained solution", color = :blue)
plot!(init_sol, lw = 4, la = 0.9, linestyle = :dot, label = "Pre-training solution", color = :royalblue4)
# Plot the fitted/true production functions.
X_grid = minimum(sample_x):0.1:maximum(sample_x)
prod_true = prod.(X_grid)
get_nn_val(x) = LuxCore.stateless_apply(NN, [x], convert(fitted_sol.ps[osys_nn.T_nn], fitted_sol.ps[osys_nn.p_nn]))[1]
prod_nn = get_nn_val.(X_grid)
plot(X_grid, prod_true; label = "True production rate", lw = 8, la = 0.8)
p2 = plot!(X_grid, prod_nn; label = "Fitted production rate", lw = 8, la = 0.8, xlimit = (X_grid[1], X_grid[end]))
(In this case we can fit well, but the addition of d means we don't necessarily recover the true function; that is beside the point here.)