Advice on UDE modelling workflow (for future Catalyst implementations) #54

@TorkelE

OK, I think I have managed to get a full UDE workflow working, based on what Chris told me. The advantage of this approach (as opposed to the one in the docs) is that it will be much easier to integrate with Catalyst.

I was wondering if someone who knows this area better could confirm that what I am doing makes sense. If it does, I will look at making this work in conjunction with Catalyst (it doesn't work directly, as we make some type assumptions that prevent the neural networks from working, but this should be changeable).

# Fetch packages.
using Lux, ComponentArrays, ModelingToolkit, ModelingToolkitNeuralNets, OrdinaryDiffEq, Plots, StableRNGs
import ModelingToolkitNeuralNets: lazyconvert
import ModelingToolkit.t_nounits as t
import ModelingToolkit.D_nounits as D
import ModelingToolkit: getu, setp_oop
using OptimizationOptimisers: Adam

# Prepare the true model (X is produced as a function of itself, and degraded linearly).
@variables X(t)
@parameters d
prod(x) = (x^2)/(0.5 + x^2)
eq = D(X) ~ prod(X) - d*X
@mtkbuild osys_true = ODESystem([eq], t)

# Generate the synthetic measurements (samples a simulation and adds multiplicative noise; the plot comes later on).
u0_true = [X => 0.5]
tend = 10.0
ps_true = [d => 0.5]
oprob_true = ODEProblem(osys_true, u0_true, tend, ps_true)
sol_true = solve(oprob_true)
sample_t = 0.0:0.1:tend
sample_x = sol_true(sample_t; idxs = X).u .* (0.9 .+ 0.2*rand(length(sample_t)))

# Create the Neural Network.
NN = Lux.Chain(
    Lux.Dense(1 => 10, Lux.mish, use_bias = false),
    Lux.Dense(10 => 10, Lux.mish, use_bias = false),
    Lux.Dense(10 => 1, use_bias = false)
)
ps_nn_init, st_nn = Lux.setup(StableRNG(1234), NN)
ca_nn = ComponentArray{Float64}(ps_nn_init)
@parameters p_nn[1:length(ca_nn)] = Vector(ca_nn)
@parameters T_nn::typeof(typeof(ca_nn))=typeof(ca_nn) [tunable = false]

# Create the UDE model.
nn_func(x) = LuxCore.stateless_apply(NN, [x], lazyconvert(T_nn, p_nn))[1]
eq_nn = D(X) ~ nn_func(X) - d*X
@mtkbuild osys_nn = ODESystem([eq_nn], t, [X], [d, p_nn, T_nn])

# Create the optimization function (sum of squared differences between the new simulation and the data).
function loss(ps, (prob_base, sample_t, sample_x, set_ps, get_us))
    p = set_ps(prob_base, ps)
    new_prob = remake(prob_base; p)
    new_sol = solve(new_prob; saveat = sample_t)
    SciMLBase.successful_retcode(new_sol) || return Inf
    return sum(abs2.(get_us(new_sol) .- sample_x))
end
of = OptimizationFunction{true}(loss, AutoForwardDiff())

# Prepares the optimisation problem.
ps_init = [1.0; ModelingToolkit.default_values(osys_nn)[osys_nn.p_nn]]
prob_base = ODEProblem(osys_nn, u0_true, tend, [d => ps_init[1]])
set_ps = setp_oop(osys_nn, [d; osys_nn.p_nn])
get_us = getu(osys_nn, X)
opt_prob = OptimizationProblem(of, ps_init, (prob_base, sample_t, sample_x, set_ps, get_us))

# Fits the UDE to the measured data.
function callback(opt_state, loss)
    (opt_state.iter % 100 == 0) && (@info "step $(opt_state.iter), loss: $loss")
    return false
end
@time opt_sol = solve(opt_prob, Adam(5e-3); maxiters = 5000, callback)

# Simulates the fitted solution.
fitted_ps = set_ps(prob_base, opt_sol.u)
fitted_prob = remake(prob_base, p = fitted_ps)
fitted_sol = solve(fitted_prob)
init_sol = solve(prob_base)

# Evaluate the fit (plots true solution, data, and the pre/post fitted NN model simulation).
plot(sol_true; label = "True solution", lw = 5)
plot!(sample_t, sample_x; label = "Measured values", seriestype = :scatter, color = 1, ms = 6, markeralpha = 0.6)
plot!(fitted_sol, lw = 4, la = 0.9, linestyle = :dash, label = "Trained solution", color = :blue)
plot!(init_sol, lw = 4, la = 0.9, linestyle = :dot, label = "Pre-training solution", color = :royalblue4)

[Image: true solution, measured values, and the pre- and post-training UDE simulations]

# Plots the fitted/true functions.
X_grid = minimum(sample_x):0.1:maximum(sample_x)
prod_true = prod.(X_grid)
get_nn_val(x) = LuxCore.stateless_apply(NN, [x], convert(fitted_sol.ps[osys_nn.T_nn], fitted_sol.ps[osys_nn.p_nn]))[1]
prod_nn = get_nn_val.(X_grid)
plot(X_grid, prod_true; label = "True production rate", lw = 8, la = 0.8)
p2 = plot!(X_grid, prod_nn; label = "Fitted production rate", lw = 8, la = 0.8, xlimit = (X_grid[1], X_grid[end]))

[Image: true and fitted production rates as functions of X]

(In this case we get a good fit, but because d is fitted alongside the network, we don't necessarily recover the true production function. That is beside the point here, though.)
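
To make the identifiability remark concrete, here is a minimal, package-free sketch. It assumes a hypothetical fitted degradation rate `d_hat = 0.7` (an illustrative value, not a result from the fit above), and the names `prod_true`, `prod_alt`, `rhs_true`, and `rhs_alt` are introduced only for this example. If the network learns `prod_true(x) + (d_hat - d_true)*x` instead of the true production term, the right-hand side of the ODE is unchanged, so the data cannot tell the two apart.

```julia
# True production term and degradation rate, as in the model above.
prod_true(x) = x^2 / (0.5 + x^2)
d_true = 0.5

# Hypothetical fitted degradation rate (assumed for illustration only).
d_hat = 0.7

# A different production function that compensates for the wrong degradation rate.
prod_alt(x) = prod_true(x) + (d_hat - d_true) * x

# Right-hand sides of dX/dt under the two parameterisations.
rhs_true(x) = prod_true(x) - d_true * x
rhs_alt(x) = prod_alt(x) - d_hat * x

# Compare the two on a grid of X values.
xs = 0.0:0.1:2.0
max_rhs_gap = maximum(abs.(rhs_true.(xs) .- rhs_alt.(xs)))     # zero up to round-off
max_prod_gap = maximum(abs.(prod_true.(xs) .- prod_alt.(xs)))  # grows linearly with x

println("max RHS difference:        ", max_rhs_gap)
println("max production difference: ", max_prod_gap)
```

The dynamics agree to round-off while the production functions differ, which is why a good trajectory fit does not by itself imply that the true production rate was recovered. Fixing `d` to a known value, if one is available, would remove this particular degeneracy.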
