From 576b0d5deeac35efb3d75679dcc2e7b38356031a Mon Sep 17 00:00:00 2001 From: Sharv <72983931+sharvmurgai@users.noreply.github.com> Date: Fri, 19 Sep 2025 23:55:10 -0700 Subject: [PATCH 1/4] Create stiff_nn_bruss.md --- docs/src/examples/stiff_nn_bruss.md | 420 ++++++++++++++++++++++++++++ 1 file changed, 420 insertions(+) create mode 100644 docs/src/examples/stiff_nn_bruss.md diff --git a/docs/src/examples/stiff_nn_bruss.md b/docs/src/examples/stiff_nn_bruss.md new file mode 100644 index 000000000..84a0f0182 --- /dev/null +++ b/docs/src/examples/stiff_nn_bruss.md @@ -0,0 +1,420 @@ +# Learning the Brusselator Equation with a Universal Differential Equation (UDE) + +This document walks through an example of using a Universal Differential Equation (UDE) to learn the dynamics of a reaction-diffusion system known as the Brusselator. + +### Context: The Brusselator and UDEs + +The **Brusselator** is a partial differential equation (PDE) that models a theoretical autocatalytic chemical reaction. It describes how the concentrations of two chemical species evolve over space and time, governed by two main processes: +1. **Reaction**: The species interact with each other, changing their concentrations locally. +2. **Diffusion**: The species spread out over the spatial domain. + +A **Universal Differential Equation (UDE)** is a hybrid modeling approach that merges known physical laws with machine learning. The core idea is to encode the parts of the system you understand (e.g., diffusion) directly into the equations and use a neural network to learn the parts you don't (e.g., the complex reaction kinetics). + +In this example, we will: +1. Generate "ground truth" data by solving the full, known Brusselator PDE. +2. Define a UDE where the diffusion term is explicitly coded, but the reaction term is replaced by a neural network. +3. Train the neural network's parameters by requiring the UDE's solution to match the ground truth data. +4. Visualize the results to confirm that our UDE has successfully learned the unknown reaction dynamics. + +This showcases the power of scientific machine learning (SciML) to discover governing equations from data. + +--- + +### 1. Problem Setup and Dependencies + +First, we import the necessary Julia libraries. We'll use `DifferentialEquations.jl` for solving ODEs, `Lux.jl` for the neural network, `Optimization.jl` for training, and `Plots.jl` for visualization. We then define the simulation constants, such as grid size and simulation time. + + +```@example stiff_bruss +using OrdinaryDiffEq, DifferentialEquations +using LinearSolve +using SciMLSensitivity +using Lux +using Optimization, OptimizationOptimisers +using Random, Zygote +using Plots, Statistics + +# We disable the default plot saving and display them directly. +default(show = true) + +# --------------------------------------------------------------------------- +# ## 1. Problem Setup: Constants, Grid, and Initial Conditions +# --------------------------------------------------------------------------- + +# -- Simulation Parameters -- +const N = 16 # Grid size will be N x N +const TEND = 3.0f0 # End time of the simulation +const MAXITERS = 50 # Number of training iterations for the optimizer +const H = 16 # Hidden layer size for the neural network + +# -- Grid and Discretization -- +const xyd = range(0.0f0, 1.0f0, length = N) # Spatial domain +const dx = step(xyd) # Spatial step size + +""" + limit(a::Int, N::Int) + +Enforces periodic boundary conditions by wrapping indices around the grid. +If an index `a` goes past the boundary (1 or N), it wraps to the other side. +""" +@inline limit(a::Int, N::Int) = a == N + 1 ? 1 : (a == 0 ? N : a) + +""" + brusselator_f(x, y, t) + +A forcing term for the Brusselator equation, which is active in a circular +region for t ≥ 1.1. +""" +@inline function brusselator_f(x, y, t) + forcing = (((x - 0.3f0)^2 + (y - 0.6f0)^2) <= 0.1f0^2) && (t >= 1.1f0) + return forcing ? 5.0f0 : 0.0f0 +end + +""" + init_u0(xyd) + +Generates the initial condition `u0` for the two species on the grid. +""" +function init_u0(xyd) + N = length(xyd) + u = zeros(Float32, N, N, 2) + @inbounds for I in CartesianIndices((N, N)) + x = xyd[I[1]] + y = xyd[I[2]] + u[I, 1] = 22.0f0 * (y * (1.0f0 - y))^(3.0f0 / 2.0f0) + u[I, 2] = 27.0f0 * (x * (1.0f0 - x))^(3.0f0 / 2.0f0) + end + return u +end + +# Initialize the state vector `u0` +u0 = init_u0(xyd) +``` + +### 2. Generating the Reference Solution (Ground Truth) +To train our UDE, we need data to learn from. We generate this by solving the full Brusselator PDE with its known equations. This solution will serve as our "ground truth" that we will try to replicate with the UDE. The rhs_ref! function defines the complete dynamics, including both diffusion and reaction terms. + +```@example stiff_bruss +# --------------------------------------------------------------------------- +# ## 2. Reference Solution (Ground Truth) +# --------------------------------------------------------------------------- +# Here, we solve the full PDE with the known reaction terms to generate +# the data we will use to train our neural network. + +# -- Brusselator PDE Parameters -- +const A = 3.4f0 +const B = 1.0f0 +const α = 10.0f0 +const αdx2 = α / dx^2 + +""" + rhs_ref!(du, u, p, t) + +The right-hand side (RHS) function for the Brusselator PDE, including both the +diffusion (Laplacian) and the known reaction terms. This defines the true dynamics. +""" +function rhs_ref!(du, u, p, t) + @inbounds for I in CartesianIndices((N, N)) + i, j = Tuple(I) + x, y = xyd[i], xyd[j] + + # Neighbor indices with periodic boundaries + ip1 = limit(i + 1, N) + im1 = limit(i - 1, N) + jp1 = limit(j + 1, N) + jm1 = limit(j - 1, N) + + u1 = u[i, j, 1] + v1 = u[i, j, 2] + + # Discretized Laplacian for diffusion + lap_u = (u[im1, j, 1] + u[ip1, j, 1] + u[i, jm1, 1] + u[i, jp1, 1] - 4.0f0 * u1) + lap_v = (u[im1, j, 2] + u[ip1, j, 2] + u[i, jm1, 2] + u[i, jp1, 2] - 4.0f0 * v1) + + # Known Brusselator reaction terms + reaction1 = B + u1 * u1 * v1 - (A + 1.0f0) * u1 + reaction2 = A * u1 - u1 * u1 * v1 + + # Combine diffusion, reaction, and forcing term + du[i, j, 1] = αdx2 * lap_u + reaction1 + brusselator_f(x, y, t) + du[i, j, 2] = αdx2 * lap_v + reaction2 + end + return nothing +end + +println("Stage 1/4: Generating reference solution...") +prob_ref = ODEProblem(rhs_ref!, u0, (0.0f0, TEND)) +sol_ref = solve(prob_ref, KenCarp47(linsolve = KrylovJL_GMRES()); + saveat = 0.0f0:0.5f0:TEND, reltol = 1e-5, abstol = 1e-5, + save_everystep = false, progress = true) + +# Store the reference solution and time points for training and comparison +Yref = Array(sol_ref) +ts = sol_ref.t +println("Reference solution generated.") +``` + +### 3. Defining the Neural Network +Next, we define the neural network architecture that will learn the unknown reaction term. The model is a simple multi-layer perceptron with a custom SigmaLayer that applies a learnable, exponentially decaying weight to its inputs. We also create helper functions (flatten_ps, to_ps) to convert the network's parameters between Lux's structured format and the flat vector format required by the optimizer. + +```@example stiff_bruss +# --------------------------------------------------------------------------- +# ## 3. Neural Network (UDE Component) +# --------------------------------------------------------------------------- +# This section defines the neural network architecture that will learn the +# unknown reaction term. + +""" + SigmaLayer + +A custom Lux layer that applies a learnable, exponentially decaying weight +to the activations. The decay rate `p` is a learnable parameter. +""" +struct SigmaLayer <: Lux.AbstractLuxLayer end + +# Initialize the layer's parameters and state +Lux.initialparameters(rng::AbstractRNG, ::SigmaLayer) = (p = 2.0f0,) +Lux.initialstates(rng::AbstractRNG, ::SigmaLayer) = NamedTuple() + +# Define the layer's forward pass +function (ℓ::SigmaLayer)(z, ps, st) + H = size(z, 1) # Height (number of neurons) + Tz = eltype(z) + # Use softplus to ensure the decay rate `σ` is positive + σ = NNlib.softplus(Tz(ps.p)) + decay = exp.(-σ .* (1:H)) + + # Apply decay, broadcasting over a batch if necessary + z = z .* reshape(decay, H, ntuple(Returns(1), ndims(z) - 1)...) + return z, st +end + +# -- Define the full neural network model -- +# The model takes the concentrations of the two species `[u, v]` as input +# and outputs the predicted reaction terms `[reaction1, reaction2]`. +model = Chain(Dense(2 => H, tanh), SigmaLayer(), Dense(H => 2)) + +# Initialize model parameters (ps) and state (st) +rng = Random.default_rng() +ps0, st0 = Lux.setup(rng, model) +const ST = st0 # State is constant during training + +""" + flatten_ps(ps) + +Converts the nested Lux parameter structure `ps` into a flat `Vector{Float32}`. +This is required by the `Optimization.jl` interface. +""" +function flatten_ps(ps)::Vector{Float32} + w1 = vec(ps.layer_1.weight) + b1 = ps.layer_1.bias + p = ps.layer_2.p + w3 = vec(ps.layer_3.weight) + b3 = ps.layer_3.bias + return vcat(w1, b1, p, w3, b3) +end + +""" + to_ps(θ::AbstractVector) + +Reconstructs the Lux parameter structure `ps` from a flat vector `θ`. +""" +function to_ps(θ::AbstractVector) + # Calculate expected length to catch errors if model architecture changes + expected_len = (2 * H) + H + 1 + (H * 2) + 2 # W1 + b1 + p2 + W3 + b3 + @assert length(θ) == expected_len "Incorrect parameter vector length." + + T = eltype(θ) + i = 1 + # Layer 1: Dense + w1_end = i + 2 * H - 1 + w1 = reshape(θ[i:w1_end], H, 2) + i = w1_end + 1 + # Layer 1: Bias + b1_end = i + H - 1 + b1 = θ[i:b1_end] + i = b1_end + 1 + # Layer 2: SigmaLayer + p2 = θ[i] + i += 1 + # Layer 3: Dense + w3_end = i + 2 * H - 1 + w3 = reshape(θ[i:w3_end], 2, H) + i = w3_end + 1 + # Layer 3: Bias + b3 = θ[i:end] + + return (layer_1 = (weight = T.(w1), bias = T.(b1)), + layer_2 = (p = T(p2),), + layer_3 = (weight = T.(w3), bias = T.(b3))) +end + +# Get the initial flat parameter vector +θ0 = flatten_ps(ps0) +``` + +### 4. Constructing the Universal Differential Equation (UDE) +Here is the core of the UDE. The rhs_ude! function defines the hybrid dynamics. It explicitly calculates the diffusion term (the known physics) and calls the neural network to approximate the reaction term (the unknown physics). This function is then used to create an ODEProblem that can be solved and differentiated. + +```@example stiff_bruss +# --------------------------------------------------------------------------- +# ## 4. Universal Differential Equation (UDE) +# --------------------------------------------------------------------------- +# The UDE combines the known physics (diffusion) with the neural network. + +const COUT = 5.0f0 # Clamp NN output to prevent explosions during training + +""" + rhs_ude!(du, u, θ_vec, t) + +The RHS function for the UDE. It computes the diffusion term analytically and +uses the neural network `model` (with parameters `θ_vec`) to approximate the +reaction term. +""" +function rhs_ude!(du, u, θ_vec, t) + psθ = to_ps(θ_vec) # Reconstruct NN parameters for Lux + Tz = eltype(u) + + @inbounds for I in CartesianIndices((N, N)) + i, j = Tuple(I) + x, y = xyd[i], xyd[j] + + # Neighbor indices with periodic boundaries + ip1 = limit(i + 1, N) + im1 = limit(i - 1, N) + jp1 = limit(j + 1, N) + jm1 = limit(j - 1, N) + + u1 = u[i, j, 1] + v1 = u[i, j, 2] + + # Part 1: Known Physics (Diffusion) + lap_u = (u[im1, j, 1] + u[ip1, j, 1] + u[i, jm1, 1] + u[i, jp1, 1] - 4.0f0 * u1) + lap_v = (u[im1, j, 2] + u[ip1, j, 2] + u[i, jm1, 2] + u[i, jp1, 2] - 4.0f0 * v1) + + # Part 2: Unknown Physics (Learned by NN) + # Input to the NN is the state [u1, v1] at a single grid point + nn_input = Tz[u1, v1] + reaction_pred, _ = model(nn_input, psθ, ST) + + # Clamp the NN output for stability + y1 = clamp(reaction_pred[1], -COUT, COUT) + y2 = clamp(reaction_pred[2], -COUT, COUT) + + # Combine known physics, learned reaction, and forcing term + du[i, j, 1] = αdx2 * lap_u + y1 + brusselator_f(x, y, t) + du[i, j, 2] = αdx2 * lap_v + y2 + end + return nothing +end + +# Define the ODE problem for the UDE, passing the NN parameters `θ0` +prob_ude = ODEProblem(rhs_ude!, u0, (0.0f0, TEND), θ0) +``` + +### 5. Training the UDE +With the UDE defined, we can now train it. The loss function solves the UDE with the current neural network parameters and computes the mean squared error against the reference data. We use Optimization.jl with the Adam optimizer to minimize this loss. SciMLSensitivity.jl provides the magic to efficiently compute gradients of the loss function with respect to the network parameters, even though the parameters are inside a differential equation solver. + +```@example stiff_bruss +# --------------------------------------------------------------------------- +# ## 5. Training the UDE +# --------------------------------------------------------------------------- + +println("\nStage 2/4: Setting up loss function and optimizer...") + +# Define the sensitivity algorithm for calculating gradients efficiently +sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()) + +""" + loss(θ_vec) + +Computes the mean squared error between the UDE solution (using parameters `θ_vec`) +and the ground truth solution `Yref`. +""" +function loss(θ_vec) + # Solve the UDE with the current parameter vector + sol = solve(remake(prob_ude; p = θ_vec), KenCarp47(linsolve = KrylovJL_GMRES()); + saveat = ts, reltol = 1e-4, abstol = 1e-4, + save_everystep = false, sensealg = sensealg) + + # Return Inf if the solver failed to produce a solution of the correct size + if size(sol) != size(Yref) + return Inf32 + end + + # Return the mean squared error + Y = Array(sol) + return sum(abs2, Y .- Yref) / length(Yref) +end + +# -- Setup the optimization problem -- +optf = OptimizationFunction((θ, _) -> loss(θ), AutoZygote()) +optprob = OptimizationProblem(optf, θ0) + +# -- Define a callback to monitor training progress -- +println("Stage 3/4: Starting training...") +k_iter = 0 +function cb(θ, f_val) + global k_iter += 1 + if k_iter % 5 == 0 + println(" Iter: $(k_iter) \t Loss: $(round(f_val, digits=6))") + flush(stdout) + end + # Return false to continue optimization + return false +end + +# -- Run the optimization -- +solopt = solve(optprob, Optimisers.Adam(1e-2); maxiters = MAXITERS, callback = cb) +θ★ = solopt.u # The optimal parameters + +println("Training finished.") +``` + +### 6. Evaluation and Visualization +Finally, after training, we evaluate the performance of our UDE. We solve the UDE one last time using the final, optimized parameters (θ★). We then create two plots to compare the UDE's solution to the ground truth: + +A heatmap showing the spatial concentration of one species at the final time point. + +A time-series plot showing the evolution of the mean concentration over the entire simulation. + +If the training was successful, the UDE's output should closely match the true simulation. + +```@example stiff_bruss +# --------------------------------------------------------------------------- +# ## 6. Evaluation and Visualization +# --------------------------------------------------------------------------- + +println("\nStage 4/4: Evaluating final model and generating plots...") + +# Solve the UDE one last time with the optimized parameters `θ★` +sol_ude = solve(remake(prob_ude; p = θ★), KenCarp47(linsolve = KrylovJL_GMRES()); + saveat = ts, reltol = 1e-5, abstol = 1e-5, save_everystep = false) + +# Calculate the final relative mean squared error +final_loss = sum(abs2, Array(sol_ude) .- Yref) / sum(abs2, Yref) +println("Done. Final relative MSE = ", final_loss) + +# -- Create comparison plots -- + +# 1. Heatmap comparison of the final state +final_state_true = sol_ref.u[end][:, :, 1] +final_state_ude = sol_ude.u[end][:, :, 1] + +p1 = heatmap(final_state_true, title = "True Simulation (t=$(TEND))") +p2 = heatmap(final_state_ude, title = "Final UDE (it=$(k_iter))") +comparison_plot = plot(p1, p2, layout = (1, 2), size = (900, 400)) +display(comparison_plot) + +# 2. Time series comparison of the mean concentration +mean_true = [mean(u[:, :, 1]) for u in sol_ref.u] +mean_ude = [mean(u[:, :, 1]) for u in sol_ude.u] + +metric_plot = plot(ts, mean_true, label = "True Simulation", lw = 2, + xlabel = "Time (t)", ylabel = "Mean Concentration", + title = "Model Performance vs. Ground Truth") +plot!(ts, mean_ude, label = "UDE Prediction", lw = 2, linestyle = :dash) +display(metric_plot) + +println("\nPlots are displayed.") +``` From d185c7d92628f130aaca5689ae6466e5e90f9c7f Mon Sep 17 00:00:00 2001 From: Sharv <72983931+sharvmurgai@users.noreply.github.com> Date: Fri, 19 Sep 2025 23:57:18 -0700 Subject: [PATCH 2/4] Update stiff_nn_bruss.md --- docs/src/examples/stiff_nn_bruss.md | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/docs/src/examples/stiff_nn_bruss.md b/docs/src/examples/stiff_nn_bruss.md index 84a0f0182..d06d5d0e2 100644 --- a/docs/src/examples/stiff_nn_bruss.md +++ b/docs/src/examples/stiff_nn_bruss.md @@ -37,9 +37,7 @@ using Plots, Statistics # We disable the default plot saving and display them directly. default(show = true) -# --------------------------------------------------------------------------- -# ## 1. Problem Setup: Constants, Grid, and Initial Conditions -# --------------------------------------------------------------------------- +# 1. Problem Setup: Constants, Grid, and Initial Conditions # -- Simulation Parameters -- const N = 16 # Grid size will be N x N @@ -95,9 +93,8 @@ u0 = init_u0(xyd) To train our UDE, we need data to learn from. We generate this by solving the full Brusselator PDE with its known equations. This solution will serve as our "ground truth" that we will try to replicate with the UDE. The rhs_ref! function defines the complete dynamics, including both diffusion and reaction terms. ```@example stiff_bruss -# --------------------------------------------------------------------------- -# ## 2. Reference Solution (Ground Truth) -# --------------------------------------------------------------------------- +# 2. Reference Solution (Ground Truth) + # Here, we solve the full PDE with the known reaction terms to generate # the data we will use to train our neural network. @@ -158,9 +155,8 @@ println("Reference solution generated.") Next, we define the neural network architecture that will learn the unknown reaction term. The model is a simple multi-layer perceptron with a custom SigmaLayer that applies a learnable, exponentially decaying weight to its inputs. We also create helper functions (flatten_ps, to_ps) to convert the network's parameters between Lux's structured format and the flat vector format required by the optimizer. ```@example stiff_bruss -# --------------------------------------------------------------------------- -# ## 3. Neural Network (UDE Component) -# --------------------------------------------------------------------------- +# 3. Neural Network (UDE Component) + # This section defines the neural network architecture that will learn the # unknown reaction term. @@ -257,9 +253,8 @@ end Here is the core of the UDE. The rhs_ude! function defines the hybrid dynamics. It explicitly calculates the diffusion term (the known physics) and calls the neural network to approximate the reaction term (the unknown physics). This function is then used to create an ODEProblem that can be solved and differentiated. ```@example stiff_bruss -# --------------------------------------------------------------------------- -# ## 4. Universal Differential Equation (UDE) -# --------------------------------------------------------------------------- +# 4. Universal Differential Equation (UDE) + # The UDE combines the known physics (diffusion) with the neural network. const COUT = 5.0f0 # Clamp NN output to prevent explosions during training @@ -316,9 +311,7 @@ prob_ude = ODEProblem(rhs_ude!, u0, (0.0f0, TEND), θ0) With the UDE defined, we can now train it. The loss function solves the UDE with the current neural network parameters and computes the mean squared error against the reference data. We use Optimization.jl with the Adam optimizer to minimize this loss. SciMLSensitivity.jl provides the magic to efficiently compute gradients of the loss function with respect to the network parameters, even though the parameters are inside a differential equation solver. ```@example stiff_bruss -# --------------------------------------------------------------------------- -# ## 5. Training the UDE -# --------------------------------------------------------------------------- +# 5. Training the UDE println("\nStage 2/4: Setting up loss function and optimizer...") @@ -381,9 +374,7 @@ A time-series plot showing the evolution of the mean concentration over the enti If the training was successful, the UDE's output should closely match the true simulation. ```@example stiff_bruss -# --------------------------------------------------------------------------- -# ## 6. Evaluation and Visualization -# --------------------------------------------------------------------------- +# 6. Evaluation and Visualization println("\nStage 4/4: Evaluating final model and generating plots...") From 132db71417b1f9653222500d7abf5adfca8db2ac Mon Sep 17 00:00:00 2001 From: Sharv <72983931+sharvmurgai@users.noreply.github.com> Date: Sun, 21 Sep 2025 03:03:54 -0700 Subject: [PATCH 3/4] Update stiff_nn_bruss.md - Fixed SigmaLayer, and the wrapped parameters (components) --- docs/src/examples/stiff_nn_bruss.md | 315 ++++++++++++---------------- 1 file changed, 136 insertions(+), 179 deletions(-) diff --git a/docs/src/examples/stiff_nn_bruss.md b/docs/src/examples/stiff_nn_bruss.md index d06d5d0e2..472e19c7e 100644 --- a/docs/src/examples/stiff_nn_bruss.md +++ b/docs/src/examples/stiff_nn_bruss.md @@ -31,8 +31,13 @@ using LinearSolve using SciMLSensitivity using Lux using Optimization, OptimizationOptimisers +using StaticArrays +using NNlib using Random, Zygote using Plots, Statistics +using Base.Threads +using ComponentArrays +using ReverseDiff # We disable the default plot saving and display them directly. default(show = true) @@ -40,10 +45,11 @@ default(show = true) # 1. Problem Setup: Constants, Grid, and Initial Conditions # -- Simulation Parameters -- -const N = 16 # Grid size will be N x N -const TEND = 3.0f0 # End time of the simulation -const MAXITERS = 50 # Number of training iterations for the optimizer -const H = 16 # Hidden layer size for the neural network +const N = 32 +const TEND = 11.5f0 +const MAXITERS = 200 +const RTOL_REF = 1f-6 +const ATOL_REF = 1f-6 # -- Grid and Discretization -- const xyd = range(0.0f0, 1.0f0, length = N) # Spatial domain @@ -63,10 +69,7 @@ If an index `a` goes past the boundary (1 or N), it wraps to the other side. A forcing term for the Brusselator equation, which is active in a circular region for t ≥ 1.1. """ -@inline function brusselator_f(x, y, t) - forcing = (((x - 0.3f0)^2 + (y - 0.6f0)^2) <= 0.1f0^2) && (t >= 1.1f0) - return forcing ? 5.0f0 : 0.0f0 -end +@inline brusselator_f(x, y, t) = (((x - 0.3)^2 + (y - 0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.0 """ init_u0(xyd) @@ -77,8 +80,7 @@ function init_u0(xyd) N = length(xyd) u = zeros(Float32, N, N, 2) @inbounds for I in CartesianIndices((N, N)) - x = xyd[I[1]] - y = xyd[I[2]] + x = Float32(xyd[I[1]]); y = Float32(xyd[I[2]]) u[I, 1] = 22.0f0 * (y * (1.0f0 - y))^(3.0f0 / 2.0f0) u[I, 2] = 27.0f0 * (x * (1.0f0 - x))^(3.0f0 / 2.0f0) end @@ -99,8 +101,6 @@ To train our UDE, we need data to learn from. We generate this by solving the fu # the data we will use to train our neural network. # -- Brusselator PDE Parameters -- -const A = 3.4f0 -const B = 1.0f0 const α = 10.0f0 const αdx2 = α / dx^2 @@ -110,45 +110,31 @@ const αdx2 = α / dx^2 The right-hand side (RHS) function for the Brusselator PDE, including both the diffusion (Laplacian) and the known reaction terms. This defines the true dynamics. """ +println("Stage 1/4: Generating reference solution...") function rhs_ref!(du, u, p, t) + A, B, alpha_val = p + alpha_val = alpha_val / dx^2 @inbounds for I in CartesianIndices((N, N)) i, j = Tuple(I) - x, y = xyd[i], xyd[j] - - # Neighbor indices with periodic boundaries - ip1 = limit(i + 1, N) - im1 = limit(i - 1, N) - jp1 = limit(j + 1, N) - jm1 = limit(j - 1, N) - - u1 = u[i, j, 1] - v1 = u[i, j, 2] - - # Discretized Laplacian for diffusion - lap_u = (u[im1, j, 1] + u[ip1, j, 1] + u[i, jm1, 1] + u[i, jp1, 1] - 4.0f0 * u1) - lap_v = (u[im1, j, 2] + u[ip1, j, 2] + u[i, jm1, 2] + u[i, jp1, 2] - 4.0f0 * v1) - - # Known Brusselator reaction terms - reaction1 = B + u1 * u1 * v1 - (A + 1.0f0) * u1 - reaction2 = A * u1 - u1 * u1 * v1 - - # Combine diffusion, reaction, and forcing term - du[i, j, 1] = αdx2 * lap_u + reaction1 + brusselator_f(x, y, t) - du[i, j, 2] = αdx2 * lap_v + reaction2 + x, y = xyd[I[1]], xyd[I[2]] + ip1, im1, jp1, jm1 = limit(i + 1, N), limit(i - 1, N), limit(j + 1, N), limit(j - 1, N) + du[i, j, 1] = alpha_val * (u[im1, j, 1] + u[ip1, j, 1] + u[i, jp1, 1] + u[i, jm1, 1] - 4u[i, j, 1]) + + B + u[i, j, 1]^2 * u[i, j, 2] - (A + 1) * u[i, j, 1] + + brusselator_f(x, y, t) + du[i, j, 2] = alpha_val * (u[im1, j, 2] + u[ip1, j, 2] + u[i, jp1, 2] + u[i, jm1, 2] - 4u[i, j, 2]) + + A * u[i, j, 1] - u[i, j, 1]^2 * u[i, j, 2] end - return nothing end -println("Stage 1/4: Generating reference solution...") -prob_ref = ODEProblem(rhs_ref!, u0, (0.0f0, TEND)) -sol_ref = solve(prob_ref, KenCarp47(linsolve = KrylovJL_GMRES()); - saveat = 0.0f0:0.5f0:TEND, reltol = 1e-5, abstol = 1e-5, - save_everystep = false, progress = true) - -# Store the reference solution and time points for training and comparison -Yref = Array(sol_ref) -ts = sol_ref.t -println("Reference solution generated.") +p_ref = (3.4, 1.0, 10.0) +prob_ref = ODEProblem(rhs_ref!, u0, (0.0, TEND), p_ref) +sol_ref = solve(prob_ref, KenCarp47(linsolve=KrylovJL_GMRES()); + saveat=0.0:0.5:TEND, reltol=RTOL_REF, abstol=ATOL_REF, progress=true) + +const Yref = Array(sol_ref) +const ts = sol_ref.t +const mean_true = [mean(Yref[:,:,1,i]) for i in 1:size(Yref, 4)] +println("Ground truth generated. Size: ", size(Yref)) ``` ### 3. Defining the Neural Network @@ -160,93 +146,89 @@ Next, we define the neural network architecture that will learn the unknown reac # This section defines the neural network architecture that will learn the # unknown reaction term. +import LuxCore: initialparameters, initialstates +using Random, Lux, ComponentArrays + +const H = 16 + +# --- 1. Define the Custom Neural Network Layer --- """ - SigmaLayer + SigmaLayerNN{M} <: Lux.AbstractLuxLayer -A custom Lux layer that applies a learnable, exponentially decaying weight -to the activations. The decay rate `p` is a learnable parameter. +A custom Lux layer that contains an internal neural network (`net`). This +internal net learns the stiffening values (`σ`) which are then applied to the layer's input. """ -struct SigmaLayer <: Lux.AbstractLuxLayer end - -# Initialize the layer's parameters and state -Lux.initialparameters(rng::AbstractRNG, ::SigmaLayer) = (p = 2.0f0,) -Lux.initialstates(rng::AbstractRNG, ::SigmaLayer) = NamedTuple() - -# Define the layer's forward pass -function (ℓ::SigmaLayer)(z, ps, st) - H = size(z, 1) # Height (number of neurons) - Tz = eltype(z) - # Use softplus to ensure the decay rate `σ` is positive - σ = NNlib.softplus(Tz(ps.p)) - decay = exp.(-σ .* (1:H)) - - # Apply decay, broadcasting over a batch if necessary - z = z .* reshape(decay, H, ntuple(Returns(1), ndims(z) - 1)...) - return z, st +struct SigmaLayerNN{M} <: Lux.AbstractLuxLayer + net::M end -# -- Define the full neural network model -- -# The model takes the concentrations of the two species `[u, v]` as input -# and outputs the predicted reaction terms `[reaction1, reaction2]`. -model = Chain(Dense(2 => H, tanh), SigmaLayer(), Dense(H => 2)) - -# Initialize model parameters (ps) and state (st) -rng = Random.default_rng() -ps0, st0 = Lux.setup(rng, model) -const ST = st0 # State is constant during training - """ - flatten_ps(ps) + SigmaLayerNN(H::Int) -Converts the nested Lux parameter structure `ps` into a flat `Vector{Float32}`. -This is required by the `Optimization.jl` interface. +Constructor for the `SigmaLayerNN`. It initializes the internal neural network. """ -function flatten_ps(ps)::Vector{Float32} - w1 = vec(ps.layer_1.weight) - b1 = ps.layer_1.bias - p = ps.layer_2.p - w3 = vec(ps.layer_3.weight) - b3 = ps.layer_3.bias - return vcat(w1, b1, p, w3, b3) +function SigmaLayerNN(H::Int) + net = Dense(H => H, tanh) + return SigmaLayerNN(net) +end + +# --- 2. Define How Lux Interacts with the Custom Layer --- + +# Explicitly tell Lux how to get the parameters for the inner network. +function initialparameters(rng::AbstractRNG, ℓ::SigmaLayerNN) + return (net = initialparameters(rng, ℓ.net),) end +# Explicitly tell Lux how to get the state for the inner network. +function initialstates(rng::AbstractRNG, ℓ::SigmaLayerNN) + return (net = initialstates(rng, ℓ.net),) +end + +# --- 3. Define the Layer's Forward Pass --- + """ - to_ps(θ::AbstractVector) + (ℓ::SigmaLayerNN)(z, ps, st) -Reconstructs the Lux parameter structure `ps` from a flat vector `θ`. +The forward pass for the `SigmaLayerNN`. It takes an input `z`, passes it +through the internal net to get the stiffening values `σ`, and then applies +those values to `z`. """ -function to_ps(θ::AbstractVector) - # Calculate expected length to catch errors if model architecture changes - expected_len = (2 * H) + H + 1 + (H * 2) + 2 # W1 + b1 + p2 + W3 + b3 - @assert length(θ) == expected_len "Incorrect parameter vector length." - - T = eltype(θ) - i = 1 - # Layer 1: Dense - w1_end = i + 2 * H - 1 - w1 = reshape(θ[i:w1_end], H, 2) - i = w1_end + 1 - # Layer 1: Bias - b1_end = i + H - 1 - b1 = θ[i:b1_end] - i = b1_end + 1 - # Layer 2: SigmaLayer - p2 = θ[i] - i += 1 - # Layer 3: Dense - w3_end = i + 2 * H - 1 - w3 = reshape(θ[i:w3_end], 2, H) - i = w3_end + 1 - # Layer 3: Bias - b3 = θ[i:end] - - return (layer_1 = (weight = T.(w1), bias = T.(b1)), - layer_2 = (p = T(p2),), - layer_3 = (weight = T.(w3), bias = T.(b3))) +function (ℓ::SigmaLayerNN)(z, ps, st) + # Get the raw output from the internal network + σ_raw, st_net = ℓ.net(z, ps.net, st.net) + + # Apply the sigmoid function to ensure stiffening values are positive + σ = 1.0f0 ./ (1.0f0 .+ exp.(-σ_raw)) + + # Apply the learned stiffening values, handling batch dimensions if present + if ndims(z) == 1 + z = z .* σ + else + z = z .* reshape(σ, :, 1) + end + + # Return the result and the updated state of the internal network + return z, (net = st_net,) end -# Get the initial flat parameter vector -θ0 = flatten_ps(ps0) +# --- 4. Build and Initialize the Full Model --- + +# Create the full model by chaining the layers together +model = Chain( + Dense(2 => H, tanh), + SigmaLayerNN(H), + Dense(H => 2) +) + +# Initialize the model's parameters (ps0) and state (st0) +rng = Random.default_rng() +ps0, st0 = Lux.setup(rng, model) + +# Define the constant state for training +const ST = st0 + +# Create the initial flat parameter vector using ComponentArrays +θ0 = ComponentArray(ps0) ``` ### 4. Constructing the Universal Differential Equation (UDE) @@ -262,49 +244,34 @@ const COUT = 5.0f0 # Clamp NN output to prevent explosions during training """ rhs_ude!(du, u, θ_vec, t) -The RHS function for the UDE. It computes the diffusion term analytically and -uses the neural network `model` (with parameters `θ_vec`) to approximate the -reaction term. +The right-hand side (RHS) function for the Universal Differential Equation (UDE). + +This function combines known physical laws (diffusion) with a neural network that learns +the unknown reaction dynamics. It operates over a 2D grid in a single loop. """ function rhs_ude!(du, u, θ_vec, t) - psθ = to_ps(θ_vec) # Reconstruct NN parameters for Lux Tz = eltype(u) - - @inbounds for I in CartesianIndices((N, N)) - i, j = Tuple(I) - x, y = xyd[i], xyd[j] - - # Neighbor indices with periodic boundaries - ip1 = limit(i + 1, N) - im1 = limit(i - 1, N) - jp1 = limit(j + 1, N) - jm1 = limit(j - 1, N) - - u1 = u[i, j, 1] - v1 = u[i, j, 2] - - # Part 1: Known Physics (Diffusion) - lap_u = (u[im1, j, 1] + u[ip1, j, 1] + u[i, jm1, 1] + u[i, jp1, 1] - 4.0f0 * u1) - lap_v = (u[im1, j, 2] + u[ip1, j, 2] + u[i, jm1, 2] + u[i, jp1, 2] - 4.0f0 * v1) - - # Part 2: Unknown Physics (Learned by NN) - # Input to the NN is the state [u1, v1] at a single grid point - nn_input = Tz[u1, v1] - reaction_pred, _ = model(nn_input, psθ, ST) - - # Clamp the NN output for stability - y1 = clamp(reaction_pred[1], -COUT, COUT) - y2 = clamp(reaction_pred[2], -COUT, COUT) - - # Combine known physics, learned reaction, and forcing term - du[i, j, 1] = αdx2 * lap_u + y1 + brusselator_f(x, y, t) - du[i, j, 2] = αdx2 * lap_v + y2 + loop_body = I -> begin + i,j = Tuple(I) + x = Float32(xyd[i]); y = Float32(xyd[j]) + u1 = u[i,j,1]; v1 = u[i,j,2] + lap_u = u[limit(i-1,N),j,1]+u[limit(i+1,N),j,1]+u[i,limit(j+1,N),1]+u[i,limit(j-1,N),1]-4f0*u1 + lap_v = u[limit(i-1,N),j,2]+u[limit(i+1,N),j,2]+u[i,limit(j+1,N),2]+u[i,limit(j-1,N),2]-4f0*v1 + x_in = Tz[u1, v1] + ŷ, _ = model(x_in, θ_vec, ST) + y1 = clamp(ŷ[1], -COUT, COUT) + y2 = clamp(ŷ[2], -COUT, COUT) + du[i,j,1] = αdx2*lap_u + y1 + brusselator_f(x,y,t) + du[i,j,2] = αdx2*lap_v + y2 + end + @inbounds @threads for I in CartesianIndices((N,N)) + loop_body(I) end - return nothing + nothing end # Define the ODE problem for the UDE, passing the NN parameters `θ0` -prob_ude = ODEProblem(rhs_ude!, u0, (0.0f0, TEND), θ0) +prob_ude = ODEProblem(rhs_ude!, u0, (0.0, TEND), θ0) ``` ### 5. Training the UDE @@ -325,23 +292,17 @@ Computes the mean squared error between the UDE solution (using parameters `θ_v and the ground truth solution `Yref`. """ function loss(θ_vec) - # Solve the UDE with the current parameter vector - sol = solve(remake(prob_ude; p = θ_vec), KenCarp47(linsolve = KrylovJL_GMRES()); - saveat = ts, reltol = 1e-4, abstol = 1e-4, - save_everystep = false, sensealg = sensealg) - - # Return Inf if the solver failed to produce a solution of the correct size - if size(sol) != size(Yref) + sol = solve(remake(prob_ude; p=θ_vec), KenCarp47(linsolve=LinearSolve.KrylovJL_GMRES()); + saveat=ts, reltol=1f-4, abstol=1f-4, save_everystep=false, sensealg=sensealg) + Y = Array(sol) + if size(Y) != size(Yref) return Inf32 end - - # Return the mean squared error - Y = Array(sol) - return sum(abs2, Y .- Yref) / length(Yref) + sum(abs2, Y .- Yref) / length(Yref) end # -- Setup the optimization problem -- -optf = OptimizationFunction((θ, _) -> loss(θ), AutoZygote()) +optf = OptimizationFunction((θ, _)->loss(θ), AutoReverseDiff()) optprob = OptimizationProblem(optf, θ0) # -- Define a callback to monitor training progress -- @@ -358,7 +319,7 @@ function cb(θ, f_val) end # -- Run the optimization -- -solopt = solve(optprob, Optimisers.Adam(1e-2); maxiters = MAXITERS, callback = cb) +solopt = solve(optprob, Optimisers.Adam(1e-2); maxiters=MAXITERS, callback=cb) θ★ = solopt.u # The optimal parameters println("Training finished.") @@ -380,7 +341,7 @@ println("\nStage 4/4: Evaluating final model and generating plots...") # Solve the UDE one last time with the optimized parameters `θ★` sol_ude = solve(remake(prob_ude; p = θ★), KenCarp47(linsolve = KrylovJL_GMRES()); - saveat = ts, reltol = 1e-5, abstol = 1e-5, save_everystep = false) + saveat = ts, reltol = 1e-6, abstol = 1e-6, save_everystep = false) # Calculate the final relative mean squared error final_loss = sum(abs2, Array(sol_ude) .- Yref) / sum(abs2, Yref) @@ -389,22 +350,18 @@ println("Done. Final relative MSE = ", final_loss) # -- Create comparison plots -- # 1. Heatmap comparison of the final state -final_state_true = sol_ref.u[end][:, :, 1] -final_state_ude = sol_ude.u[end][:, :, 1] +final_state_true = Yref[:,:,1,end] +final_state_ude = sol_ude_final.u[end][:, :, 1] -p1 = heatmap(final_state_true, title = "True Simulation (t=$(TEND))") -p2 = heatmap(final_state_ude, title = "Final UDE (it=$(k_iter))") -comparison_plot = plot(p1, p2, layout = (1, 2), size = (900, 400)) +p1 = heatmap(final_state_true, title="True Simulation (t=$(TEND))") +p2 = heatmap(final_state_ude, title="Final SNN-UDE (it=$(k_iter))") +comparison_plot = plot(p1, p2, layout=(1, 2), size=(900, 400)) display(comparison_plot) # 2. Time series comparison of the mean concentration -mean_true = [mean(u[:, :, 1]) for u in sol_ref.u] -mean_ude = [mean(u[:, :, 1]) for u in sol_ude.u] - -metric_plot = plot(ts, mean_true, label = "True Simulation", lw = 2, - xlabel = "Time (t)", ylabel = "Mean Concentration", - title = "Model Performance vs. Ground Truth") -plot!(ts, mean_ude, label = "UDE Prediction", lw = 2, linestyle = :dash) +mean_ude = [mean(u[:,:,1]) for u in sol_ude_final.u] +metric_plot = plot(ts, mean_true, label="True Simulation", lw=2, xlabel="Time (t)", ylabel="Mean Concentration", title="Model Performance (Final)") +plot!(ts, mean_ude, label="SNN-UDE Prediction", lw=2, linestyle=:dash) display(metric_plot) println("\nPlots are displayed.") From 84af0f7d673ecd04f99b270834a9cd7fd0d09505 Mon Sep 17 00:00:00 2001 From: Sharv <72983931+sharvmurgai@users.noreply.github.com> Date: Thu, 9 Oct 2025 00:01:27 -0700 Subject: [PATCH 4/4] Fix stiff_nn_bruss.md --- docs/src/examples/stiff_nn_bruss.md | 382 +++++++++++++--------------- 1 file changed, 171 insertions(+), 211 deletions(-) diff --git a/docs/src/examples/stiff_nn_bruss.md b/docs/src/examples/stiff_nn_bruss.md index 472e19c7e..8defa101b 100644 --- a/docs/src/examples/stiff_nn_bruss.md +++ b/docs/src/examples/stiff_nn_bruss.md @@ -31,110 +31,81 @@ using LinearSolve using SciMLSensitivity using Lux using Optimization, OptimizationOptimisers -using StaticArrays +using Optimisers using NNlib using Random, Zygote using Plots, Statistics using Base.Threads using ComponentArrays -using ReverseDiff - +using JLD2 +using Dates # We disable the default plot saving and display them directly. -default(show = true) # 1. Problem Setup: Constants, Grid, and Initial Conditions # -- Simulation Parameters -- -const N = 32 -const TEND = 11.5f0 +const N = 16 +const TEND = 12.0f0 const MAXITERS = 200 const RTOL_REF = 1f-6 const ATOL_REF = 1f-6 # -- Grid and Discretization -- -const xyd = range(0.0f0, 1.0f0, length = N) # Spatial domain -const dx = step(xyd) # Spatial step size - -""" - limit(a::Int, N::Int) +const xyd = range(0f0, stop=1f0, length=N) +dx = step(xyd) -Enforces periodic boundary conditions by wrapping indices around the grid. -If an index `a` goes past the boundary (1 or N), it wraps to the other side. -""" @inline limit(a::Int, N::Int) = a == N + 1 ? 1 : (a == 0 ? N : a) +@inline brusselator_f(x, y, t) = + (((x - 0.3f0)^2 + (y - 0.6f0)^2) <= 0.1f0^2) * (t >= 1.1f0) * 5.0f0 -""" - brusselator_f(x, y, t) - -A forcing term for the Brusselator equation, which is active in a circular -region for t ≥ 1.1. -""" -@inline brusselator_f(x, y, t) = (((x - 0.3)^2 + (y - 0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.0 - -""" - init_u0(xyd) - -Generates the initial condition `u0` for the two species on the grid. -""" function init_u0(xyd) N = length(xyd) u = zeros(Float32, N, N, 2) - @inbounds for I in CartesianIndices((N, N)) + @inbounds for I in CartesianIndices((N,N)) x = Float32(xyd[I[1]]); y = Float32(xyd[I[2]]) - u[I, 1] = 22.0f0 * (y * (1.0f0 - y))^(3.0f0 / 2.0f0) - u[I, 2] = 27.0f0 * (x * (1.0f0 - x))^(3.0f0 / 2.0f0) + u[I,1] = 22f0 * (y*(1f0-y))^(3f0/2f0) + u[I,2] = 27f0 * (x*(1f0-x))^(3f0/2f0) end return u end - -# Initialize the state vector `u0` -u0 = init_u0(xyd) +const u0 = init_u0(xyd) +const α = 10.0f0 +const αdx2 = α / (dx*dx) ``` ### 2. Generating the Reference Solution (Ground Truth) To train our UDE, we need data to learn from. We generate this by solving the full Brusselator PDE with its known equations. This solution will serve as our "ground truth" that we will try to replicate with the UDE. The rhs_ref! function defines the complete dynamics, including both diffusion and reaction terms. ```@example stiff_bruss -# 2. Reference Solution (Ground Truth) - -# Here, we solve the full PDE with the known reaction terms to generate -# the data we will use to train our neural network. - -# -- Brusselator PDE Parameters -- -const α = 10.0f0 -const αdx2 = α / dx^2 - -""" - rhs_ref!(du, u, p, t) +println("Stage 1/4: Ground truth not found. Generating..."); flush(stdout) -The right-hand side (RHS) function for the Brusselator PDE, including both the -diffusion (Laplacian) and the known reaction terms. This defines the true dynamics. -""" -println("Stage 1/4: Generating reference solution...") function rhs_ref!(du, u, p, t) - A, B, alpha_val = p - alpha_val = alpha_val / dx^2 + A, B, αval = p + αval_dx2 = αval / (dx*dx) @inbounds for I in CartesianIndices((N, N)) i, j = Tuple(I) - x, y = xyd[I[1]], xyd[I[2]] - ip1, im1, jp1, jm1 = limit(i + 1, N), limit(i - 1, N), limit(j + 1, N), limit(j - 1, N) - du[i, j, 1] = alpha_val * (u[im1, j, 1] + u[ip1, j, 1] + u[i, jp1, 1] + u[i, jm1, 1] - 4u[i, j, 1]) + - B + u[i, j, 1]^2 * u[i, j, 2] - (A + 1) * u[i, j, 1] + - brusselator_f(x, y, t) - du[i, j, 2] = alpha_val * (u[im1, j, 2] + u[ip1, j, 2] + u[i, jp1, 2] + u[i, jm1, 2] - 4u[i, j, 2]) + - A * u[i, j, 1] - u[i, j, 1]^2 * u[i, j, 2] + x, y = xyd[i], xyd[j] + ip1, im1 = limit(i+1,N), limit(i-1,N) + jp1, jm1 = limit(j+1,N), limit(j-1,N) + u1 = u[i,j,1]; v1 = u[i,j,2] + lap_u = u[im1,j,1] + u[ip1,j,1] + u[i,jp1,1] + u[i,jm1,1] - 4f0*u1 + lap_v = u[im1,j,2] + u[ip1,j,2] + u[i,jp1,2] + u[i,jm1,2] - 4f0*v1 + du[i,j,1] = αval_dx2*lap_u + B + u1^2*v1 - (A+1f0)*u1 + brusselator_f(x,y,t) + du[i,j,2] = αval_dx2*lap_v + A*u1 - u1^2*v1 end + return nothing end -p_ref = (3.4, 1.0, 10.0) -prob_ref = ODEProblem(rhs_ref!, u0, (0.0, TEND), p_ref) -sol_ref = solve(prob_ref, KenCarp47(linsolve=KrylovJL_GMRES()); - saveat=0.0:0.5:TEND, reltol=RTOL_REF, abstol=ATOL_REF, progress=true) +p_ref = (3.4f0, 1.0f0, 10.0f0) +prob_ref = ODEProblem(rhs_ref!, u0, (0.0f0, TEND), p_ref) +sol_ref = solve(prob_ref, KenCarp47(linsolve=LinearSolve.KrylovJL_GMRES()); + saveat=0.0f0:0.5f0:TEND, reltol=RTOL_REF, abstol=ATOL_REF, progress=true) + +global Yref = Array(sol_ref) +global ts = sol_ref.t -const Yref = Array(sol_ref) -const ts = sol_ref.t -const mean_true = [mean(Yref[:,:,1,i]) for i in 1:size(Yref, 4)] -println("Ground truth generated. Size: ", size(Yref)) +const mean_true = [mean(Yref[:,:,1,i]) for i in 1:size(Yref,4)] +println("Ground truth loaded. Size: ", size(Yref)) ``` ### 3. Defining the Neural Network @@ -147,87 +118,67 @@ Next, we define the neural network architecture that will learn the unknown reac # unknown reaction term. import LuxCore: initialparameters, initialstates -using Random, Lux, ComponentArrays -const H = 16 - -# --- 1. Define the Custom Neural Network Layer --- """ - SigmaLayerNN{M} <: Lux.AbstractLuxLayer - -A custom Lux layer that contains an internal neural network (`net`). This -internal net learns the stiffening values (`σ`) which are then applied to the layer's input. +SigmaDiag: diagonal Σ in state space (m = length(x)), PSD with mode-indexed decay. """ -struct SigmaLayerNN{M} <: Lux.AbstractLuxLayer - net::M +struct SigmaDiag <: Lux.AbstractLuxLayer + kind::Symbol # :logistic or :exp +end +function initialparameters(::AbstractRNG, Σ::SigmaDiag) + if Σ.kind === :exp + return (β = 1.0f0,) + else + return (a = -3.0f0, b = 4.0f0) # logistic: σ_k = σ(a*k + b) + end +end +initialstates(::AbstractRNG, ::SigmaDiag) = NamedTuple() +function (Σ::SigmaDiag)(x::AbstractVector{T}, ps, st) where {T} + k = T.(1:length(x)) + σ = Σ.kind === :exp ? exp.(-ps.β .* k) : + Σ.kind === :logistic ? one(T) ./(one(T) .+ exp.(-(ps.a .* k .+ ps.b))) : + error("Unknown Σ.kind=$(Σ.kind)") + return σ .* x, st end """ - SigmaLayerNN(H::Int) - -Constructor for the `SigmaLayerNN`. It initializes the internal neural network. +StiffUVSigma: F(x) = U( Σ( V(x) ) ) with U,V: ℝ^m→ℝ^m and diagonal Σ. """ -function SigmaLayerNN(H::Int) - net = Dense(H => H, tanh) - return SigmaLayerNN(net) +struct StiffUVSigma{V,U,S} <: Lux.AbstractLuxLayer + V::V + Σ::S + U::U end - -# --- 2. Define How Lux Interacts with the Custom Layer --- - -# Explicitly tell Lux how to get the parameters for the inner network. -function initialparameters(rng::AbstractRNG, ℓ::SigmaLayerNN) - return (net = initialparameters(rng, ℓ.net),) +function initialparameters(rng::AbstractRNG, M::StiffUVSigma) + return (V = initialparameters(rng, M.V), + Σ = initialparameters(rng, M.Σ), + U = initialparameters(rng, M.U)) end - -# Explicitly tell Lux how to get the state for the inner network. -function initialstates(rng::AbstractRNG, ℓ::SigmaLayerNN) - return (net = initialstates(rng, ℓ.net),) +function initialstates(rng::AbstractRNG, M::StiffUVSigma) + return (V = initialstates(rng, M.V), + Σ = initialstates(rng, M.Σ), + U = initialstates(rng, M.U)) end - -# --- 3. Define the Layer's Forward Pass --- - -""" - (ℓ::SigmaLayerNN)(z, ps, st) - -The forward pass for the `SigmaLayerNN`. It takes an input `z`, passes it -through the internal net to get the stiffening values `σ`, and then applies -those values to `z`. -""" -function (ℓ::SigmaLayerNN)(z, ps, st) - # Get the raw output from the internal network - σ_raw, st_net = ℓ.net(z, ps.net, st.net) - - # Apply the sigmoid function to ensure stiffening values are positive - σ = 1.0f0 ./ (1.0f0 .+ exp.(-σ_raw)) - - # Apply the learned stiffening values, handling batch dimensions if present - if ndims(z) == 1 - z = z .* σ - else - z = z .* reshape(σ, :, 1) - end - - # Return the result and the updated state of the internal network - return z, (net = st_net,) +function (M::StiffUVSigma)(x, ps, st) + yV, stV = M.V(x, ps.V, st.V) + yΣ, stΣ = M.Σ(yV, ps.Σ, st.Σ) + yU, stU = M.U(yΣ, ps.U, st.U) + return yU, (V = stV, Σ = stΣ, U = stU) end -# --- 4. Build and Initialize the Full Model --- +# Model: U( Σ V(x) ) with m=2 state; H is width in U/V +const H = 16 +const STIFF_KIND = Symbol(get(ENV, "STIFF_KIND", "logistic")) # :logistic | :exp -# Create the full model by chaining the layers together -model = Chain( - Dense(2 => H, tanh), - SigmaLayerNN(H), - Dense(H => 2) -) +Random.seed!(1234) +Vnet = Chain(Dense(2 => H, tanh), Dense(H => 2)) +Σlay = SigmaDiag(STIFF_KIND) +Unet = Chain(Dense(2 => H, tanh), Dense(H => 2)) +model = StiffUVSigma(Vnet, Σlay, Unet) -# Initialize the model's parameters (ps0) and state (st0) -rng = Random.default_rng() +rng = Random.default_rng() ps0, st0 = Lux.setup(rng, model) - -# Define the constant state for training const ST = st0 - -# Create the initial flat parameter vector using ComponentArrays θ0 = ComponentArray(ps0) ``` @@ -235,94 +186,111 @@ const ST = st0 Here is the core of the UDE. The rhs_ude! function defines the hybrid dynamics. It explicitly calculates the diffusion term (the known physics) and calls the neural network to approximate the reaction term (the unknown physics). This function is then used to create an ODEProblem that can be solved and differentiated. ```@example stiff_bruss -# 4. Universal Differential Equation (UDE) - -# The UDE combines the known physics (diffusion) with the neural network. - -const COUT = 5.0f0 # Clamp NN output to prevent explosions during training - -""" - rhs_ude!(du, u, θ_vec, t) +const RESIDUAL = true +const COUT = 8.0 +const RESID_COUT = 0.7 +const A0 = 3.4 +const B0 = 1.0 -The right-hand side (RHS) function for the Universal Differential Equation (UDE). - -This function combines known physical laws (diffusion) with a neural network that learns -the unknown reaction dynamics. It operates over a 2D grid in a single loop. -""" function rhs_ude!(du, u, θ_vec, t) Tz = eltype(u) - loop_body = I -> begin - i,j = Tuple(I) - x = Float32(xyd[i]); y = Float32(xyd[j]) + @inbounds for I in CartesianIndices((N,N)) + i, j = Tuple(I); x = xyd[i]; y = xyd[j] u1 = u[i,j,1]; v1 = u[i,j,2] - lap_u = u[limit(i-1,N),j,1]+u[limit(i+1,N),j,1]+u[i,limit(j+1,N),1]+u[i,limit(j-1,N),1]-4f0*u1 - lap_v = u[limit(i-1,N),j,2]+u[limit(i+1,N),j,2]+u[i,limit(j+1,N),2]+u[i,limit(j-1,N),2]-4f0*v1 - x_in = Tz[u1, v1] - ŷ, _ = model(x_in, θ_vec, ST) - y1 = clamp(ŷ[1], -COUT, COUT) - y2 = clamp(ŷ[2], -COUT, COUT) - du[i,j,1] = αdx2*lap_u + y1 + brusselator_f(x,y,t) - du[i,j,2] = αdx2*lap_v + y2 + lap_u = u[limit(i-1,N),j,1] + u[limit(i+1,N),j,1] + + u[i,limit(j+1,N),1] + u[i,limit(j-1,N),1] - 4f0*u1 + lap_v = u[limit(i-1,N),j,2] + u[limit(i+1,N),j,2] + + u[i,limit(j+1,N),2] + u[i,limit(j-1,N),2] - 4f0*v1 + ŷ, _ = model(Tz[u1, v1], θ_vec, ST) + if RESIDUAL + r1 = B0 + u1^2*v1 - (A0 + 1f0)*u1 + r2 = A0*u1 - u1^2*v1 + δ1 = clamp(ŷ[1], -RESID_COUT, RESID_COUT) + δ2 = clamp(ŷ[2], -RESID_COUT, RESID_COUT) + du[i,j,1] = αdx2*lap_u + (r1 + δ1) + brusselator_f(x,y,t) + du[i,j,2] = αdx2*lap_v + (r2 + δ2) + else + y1 = clamp(ŷ[1], -COUT, COUT); y2 = clamp(ŷ[2], -COUT, COUT) + du[i,j,1] = αdx2*lap_u + y1 + brusselator_f(x,y,t) + du[i,j,2] = αdx2*lap_v + y2 + end end - @inbounds @threads for I in CartesianIndices((N,N)) - loop_body(I) - end - nothing + return nothing end -# Define the ODE problem for the UDE, passing the NN parameters `θ0` -prob_ude = ODEProblem(rhs_ude!, u0, (0.0, TEND), θ0) +prob_ude = ODEProblem(rhs_ude!, u0, (0.0f0, TEND), θ0) ``` ### 5. Training the UDE -With the UDE defined, we can now train it. The loss function solves the UDE with the current neural network parameters and computes the mean squared error against the reference data. We use Optimization.jl with the Adam optimizer to minimize this loss. SciMLSensitivity.jl provides the magic to efficiently compute gradients of the loss function with respect to the network parameters, even though the parameters are inside a differential equation solver. +With the UDE defined, we can now train it. The loss function solves the UDE with the current neural network parameters and computes the mean squared error against the reference data. We use Optimization.jl with the Adam optimizer to minimize this loss. SciMLSensitivity.jl provides the functionality to efficiently compute gradients of the loss function with respect to the network parameters, even though the parameters are inside a differential equation solver. ```@example stiff_bruss -# 5. Training the UDE +println("Stage 2/4: Set loss + optimizer …"); flush(stdout) +sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()) -println("\nStage 2/4: Setting up loss function and optimizer...") +const ALG_FINAL = TRBDF2(autodiff=false) +const FINAL_RTOL = 1f-4 +const FINAL_ATOL = 1f-4 +const FINAL_DTMAX = 0.1f0 -# Define the sensitivity algorithm for calculating gradients efficiently -sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()) +# ---- Zygote-safe constants used in loss ---- +const LAMBDA = 3e-3 -""" - loss(θ_vec) +# Curriculum in time for training only +const TRAIN_FRAC = 1.0 +const TRAIN_IDX = 1:clamp(ceil(Int, TRAIN_FRAC * length(ts)), 1, length(ts)) +const ts_train = ts[TRAIN_IDX] +const Yref_train = Yref[:,:,:,TRAIN_IDX] +const W_WEIGHTS = reshape(1 .+ collect(ts_train)./maximum(ts_train), 1,1,1,length(TRAIN_IDX)) -Computes the mean squared error between the UDE solution (using parameters `θ_vec`) -and the ground truth solution `Yref`. -""" function loss(θ_vec) - sol = solve(remake(prob_ude; p=θ_vec), KenCarp47(linsolve=LinearSolve.KrylovJL_GMRES()); - saveat=ts, reltol=1f-4, abstol=1f-4, save_everystep=false, sensealg=sensealg) + sol = solve(remake(prob_ude; p=θ_vec), TRBDF2(autodiff=false); + saveat=ts_train, reltol=1f-4, abstol=1f-4, + save_everystep=false, sensealg=sensealg) Y = Array(sol) - if size(Y) != size(Yref) + if size(Y) != size(Yref_train) return Inf32 end - sum(abs2, Y .- Yref) / length(Yref) + data_mse = sum(abs2, (Y .- Yref_train) .* W_WEIGHTS) / length(Yref_train) + reg = LAMBDA * sum(abs2, θ_vec) + return data_mse + reg end -# -- Setup the optimization problem -- -optf = OptimizationFunction((θ, _)->loss(θ), AutoReverseDiff()) +optf = OptimizationFunction((θ, _)->loss(θ), AutoZygote()) optprob = OptimizationProblem(optf, θ0) -# -- Define a callback to monitor training progress -- -println("Stage 3/4: Starting training...") +println("Stage 3/4: Training …"); flush(stdout) k_iter = 0 -function cb(θ, f_val) +function cb(state, L) global k_iter += 1 if k_iter % 5 == 0 - println(" Iter: $(k_iter) \t Loss: $(round(f_val, digits=6))") - flush(stdout) + println(" it=$k_iter loss=$(round(L; digits=6))"); flush(stdout) end - # Return false to continue optimization return false end -# -- Run the optimization -- -solopt = solve(optprob, Optimisers.Adam(1e-2); maxiters=MAXITERS, callback=cb) -θ★ = solopt.u # The optimal parameters +initial_lr = 0.0010 +decay_lr = 0.00035 +switch_it = 40 + +function train_with_schedule() + it1 = min(MAXITERS, switch_it) + opt1 = Optimisers.OptimiserChain(Optimisers.ClipNorm(1.0f0), Optimisers.Adam(initial_lr)) + stg1 = solve(optprob, opt1; maxiters=it1, callback=cb) + rem = MAXITERS - it1 + if rem > 0 + optprob2 = OptimizationProblem(optf, ComponentArray(stg1.u)) + opt2 = Optimisers.OptimiserChain(Optimisers.ClipNorm(1.0f0), Optimisers.Adam(decay_lr)) + stg2 = solve(optprob2, opt2; maxiters=rem, callback=cb) + return stg2 + else + return stg1 + end +end -println("Training finished.") +solopt = train_with_schedule() + +θ★ = ComponentArray(solopt.u) ``` ### 6. Evaluation and Visualization @@ -335,34 +303,26 @@ A time-series plot showing the evolution of the mean concentration over the enti If the training was successful, the UDE's output should closely match the true simulation. ```@example stiff_bruss -# 6. Evaluation and Visualization - -println("\nStage 4/4: Evaluating final model and generating plots...") +println("\nStage 4/4: Final evaluation …"); flush(stdout) +println(" Config: N=$(N), TEND=$(TEND), TRAIN_FRAC=$(TRAIN_FRAC)") +println(" Threads during final solve: DISABLED | Algorithm: ", typeof(ALG_FINAL)) +t_start = time() -# Solve the UDE one last time with the optimized parameters `θ★` -sol_ude = solve(remake(prob_ude; p = θ★), KenCarp47(linsolve = KrylovJL_GMRES()); - saveat = ts, reltol = 1e-6, abstol = 1e-6, save_everystep = false) +sol_ude_final = solve(remake(prob_ude; p=θ★), ALG_FINAL; + saveat=ts, reltol=FINAL_RTOL, abstol=FINAL_ATOL, + save_everystep=false, dtmax=FINAL_DTMAX) -# Calculate the final relative mean squared error -final_loss = sum(abs2, Array(sol_ude) .- Yref) / sum(abs2, Yref) -println("Done. Final relative MSE = ", final_loss) +println(" Final solve retcode = ", sol_ude_final.retcode) +println(" Saved points = ", length(sol_ude_final.t)) +println(" Final solve wall time = $(round(time()-t_start; digits=3)) s") -# -- Create comparison plots -- +rel_mse = sum(abs2, Array(sol_ude_final) .- Yref) / sum(abs2, Yref) +println("done. Final relative MSE = ", rel_mse) -# 1. Heatmap comparison of the final state -final_state_true = Yref[:,:,1,end] -final_state_ude = sol_ude_final.u[end][:, :, 1] - -p1 = heatmap(final_state_true, title="True Simulation (t=$(TEND))") -p2 = heatmap(final_state_ude, title="Final SNN-UDE (it=$(k_iter))") -comparison_plot = plot(p1, p2, layout=(1, 2), size=(900, 400)) -display(comparison_plot) - -# 2. Time series comparison of the mean concentration mean_ude = [mean(u[:,:,1]) for u in sol_ude_final.u] -metric_plot = plot(ts, mean_true, label="True Simulation", lw=2, xlabel="Time (t)", ylabel="Mean Concentration", title="Model Performance (Final)") -plot!(ts, mean_ude, label="SNN-UDE Prediction", lw=2, linestyle=:dash) -display(metric_plot) - -println("\nPlots are displayed.") +plt_final = plot(ts, mean_true, label="True Simulation", lw=2, + xlabel="Time (t)", ylabel="Mean Concentration", + title="Model Performance (Final)") +plot!(plt_final, ts, mean_ude, label=RESIDUAL ? "SNN Residual" : "SNN-UDE Prediction", + lw=2, linestyle=:dash) ```