|
| 1 | +# Learning Non-Linear Reaction Dynamics for the Gray–Scott Reaction–Diffusion Model using Universal Differential Equations |
| 2 | + |
| 3 | +## Introduction |
| 4 | +The Gray–Scott model is a prototypical reaction–diffusion system known for generating a wide variety of spatial patterns — from spots and stripes to labyrinthine structures — driven purely by simple chemical kinetics and diffusion. |
| 5 | + |
| 6 | +In this tutorial, we’ll employ a Universal Differential Equation (UDE) framework: embedding a small neural network within the PDE’s reaction term to learn unknown dynamics from data, while retaining the known diffusion physics. |
| 7 | + |
| 8 | + |
| 9 | +## Equations of the Gray-Scott Model |
| 10 | +The system is governed by the coupled PDEs: |
| 11 | +```math |
| 12 | +\frac{\partial u}{\partial t} = D_1\,\nabla^2 u + \frac{a\,u^2}{v} + \bar{u} - \alpha |
| 13 | +``` |
| 14 | + |
| 15 | +```math |
| 16 | +\frac{\partial v}{\partial t} = D_2\,\nabla^2 v + a\,u^2 + \beta\,v |
| 17 | +``` |
| 18 | + |
| 19 | +where $u$ and $v$ are the two chemical concentrations, $D_1$ and $D_2$ are diffusion coefficients, and $a$, $\bar{u}$, $\alpha$, $\beta$ are reaction parameters. |
| 20 | + |
| 21 | +In its spatially discretized form (using Neumann boundary conditions and the tridiagonal stencil $[1, -2, 1]$), the Gray–Scott PDE reduces to: |
| 22 | + |
| 23 | +```math |
| 24 | +du = D1 * (A_y u + u A_x) + \frac{a u^2}{v} + \bar{u} - \alpha u |
| 25 | +``` |
| 26 | + |
| 27 | +```math |
| 28 | +dv = D2 (A_y v + v A_x) + a u^2 + \beta v |
| 29 | +```` |
| 30 | +Here $A_x$ and $A_y$ are the 1D Laplacian matrices for x- and y-directions, respectively. |
| 31 | +
|
| 32 | +Now we will dive into the implementation of the UDE. |
| 33 | +
|
| 34 | +## Ground-truth data generation. |
| 35 | +```@example gray_scott |
| 36 | +using DifferentialEquations |
| 37 | +using Lux, Random, Optim, ComponentArrays, Statistics |
| 38 | +using LinearAlgebra |
| 39 | +using Plots, SciMLBase |
| 40 | +using Optimization, OptimizationOptimisers |
| 41 | +using SciMLSensitivity |
| 42 | +
|
| 43 | +const N = 32 # Smaller grid for faster training |
| 44 | +
|
| 45 | +# Constants for the "true" model |
| 46 | +p_true = (a=1.0, α=1.0, ubar=1.0, β=10.0, D1=0.001, D2=0.1) |
| 47 | +
|
| 48 | +# Laplacian operator with Neumann (zero-flux) boundary conditions |
| 49 | +Ax = Array(Tridiagonal([1.0 for i in 1:(N - 1)], [-2.0 for i in 1:N], [1.0 for i in 1:(N - 1)])) |
| 50 | +Ay = copy(Ax) |
| 51 | +Ax[1, 2] = 2.0 |
| 52 | +Ax[end, end - 1] = 2.0 |
| 53 | +Ay[1, 2] = 2.0 |
| 54 | +Ay[end, end - 1] = 2.0 |
| 55 | +
|
| 56 | +# Initial condition |
| 57 | +const uss = (p_true.ubar + p_true.β) / p_true.α |
| 58 | +const vss = (p_true.a / p_true.β) * uss^2 |
| 59 | +const r0 = zeros(N, N, 2) |
| 60 | +r0[:, :, 1] .= uss |
| 61 | +r0[:, :, 2] .= vss |
| 62 | +r0[div(N,2)-5:div(N,2)+5, div(N,2)-5:div(N,2)+5, 1] .+= 0.1 .* rand.() |
| 63 | +r0[div(N,2)-5:div(N,2)+5, div(N,2)-5:div(N,2)+5, 2] .+= 0.1 .* rand.() |
| 64 | +``` |
| 65 | + |
| 66 | +Having set up the grid, parameters, and initial condition, we now generate “ground truth” data by integrating the pure physics Gray–Scott model. This dataset will serve as the target that our UDE aims to learn. |
| 67 | + |
| 68 | +```@example gray_scott |
| 69 | +function true_model!(dr, r, p, t) |
| 70 | + a, α, ubar, β, D1, D2 = p |
| 71 | + u = @view r[:, :, 1] |
| 72 | + v = @view r[:, :, 2] |
| 73 | + Du = D1 .* (Ay * u + u * Ax) |
| 74 | + Dv = D2 .* (Ay * v + v * Ax) |
| 75 | + react_u = a .* u .* u ./ v .+ ubar .- α * u |
| 76 | + react_v = a .* u .* u .- β * v |
| 77 | + @. dr[:, :, 1] = Du + react_u |
| 78 | + @. dr[:, :, 2] = Dv + react_v |
| 79 | +end |
| 80 | +
|
| 81 | +tspan = (0.0, 0.1) |
| 82 | +tsteps = 0.0:0.01:0.1 |
| 83 | +prob_true = ODEProblem(true_model!, r0, tspan, p_true) |
| 84 | +solution_true = solve(prob_true, Tsit5(), saveat=tsteps) |
| 85 | +data_to_train = Array(solution_true) |
| 86 | +``` |
| 87 | + |
| 88 | +With the ground-truth solutions computed, we can now proceed to visualize the spatiotemporal evolution of $u$ and $v$ on the grid. |
| 89 | + |
| 90 | +```@example gray_scott |
| 91 | +println("Plotting ground truth concentrations at the grid center...") |
| 92 | +center = N ÷ 2 |
| 93 | +p1 = plot(tsteps, solution_true[center,center,1,:], lw=2, label="U True", color=:blue) |
| 94 | +title!(p1, "Ground Truth: Center U") |
| 95 | +xlabel!("Time") |
| 96 | +ylabel!("Concentration") |
| 97 | +
|
| 98 | +p2 = plot(tsteps, solution_true[center,center,2,:], lw=2, label="V True", color=:red) |
| 99 | +title!(p2, "Ground Truth: Center V") |
| 100 | +xlabel!("Time") |
| 101 | +
|
| 102 | +p_center_combined = plot(p1, p2, layout=(1,2), size=(900,350)) |
| 103 | +display(p_center_combined) |
| 104 | +``` |
| 105 | + |
| 106 | +## Defining the UDE |
| 107 | +Now that we have an understanding of the data and its visualization, we can define the neural network and the UDE structure. We replace the $\frac{a u^2}{v} + \bar{u} - \alpha u$ term with a neural network, giving the resultant ODEs. |
| 108 | + |
| 109 | +```math |
| 110 | +du = D1 * (A_y u + u A_x) + \mathcal{N}_\theta(u,v) |
| 111 | +``` |
| 112 | + |
| 113 | +```math |
| 114 | +dv = D2 (A_y v + v A_x) + a u^2 + \beta v |
| 115 | +```` |
| 116 | +
|
| 117 | +The first step is to do data normalization, or compute the statistical properties of the dataset. We normalize the inputs to the neural network to make the training process more stable and efficient. Neural networks learn best when their input data is scaled to a consistent range, typically centered around zero, which helps prevent the gradients from becoming too large or too small during training. This leads to faster convergence and allows the optimizer to find a good solution more reliably. |
| 118 | +
|
| 119 | +```@example gray_scott |
| 120 | +# Calculate mean and std dev for u and v across all space and time |
| 121 | +u_data = @view data_to_train[:, :, 1, :] |
| 122 | +v_data = @view data_to_train[:, :, 2, :] |
| 123 | +
|
| 124 | +u_mean = mean(u_data) |
| 125 | +u_std = std(u_data) |
| 126 | +v_mean = mean(v_data) |
| 127 | +v_std = std(v_data) |
| 128 | +
|
| 129 | +norm_stats = (u_mean=u_mean, u_std=u_std, v_mean=v_mean, v_std=v_std) |
| 130 | +``` |
| 131 | + |
| 132 | +Next, we define the neural network structure. |
| 133 | +```@example gray_scott |
| 134 | +rng = Random.default_rng() |
| 135 | +nn = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1)) |
| 136 | +p_nn, st_nn = Lux.setup(rng, nn) |
| 137 | +
|
| 138 | +# Add the normalization stats to the non-trainable parameters |
| 139 | +p_ude = ComponentArray( |
| 140 | + p_physics=(β=p_true.β, D1=p_true.D1, D2=p_true.D2, a=p_true.a), |
| 141 | + p_nn=p_nn, |
| 142 | + p_norm=norm_stats # Add normalization stats here |
| 143 | +) |
| 144 | +``` |
| 145 | + |
| 146 | +The following code describes the UDE formulation with the neural network predictions embedded. |
| 147 | + |
| 148 | +```@example gray_scott |
| 149 | +function create_ude_model(nn_model, nn_state) |
| 150 | + function ude_model!(dr, r, p, t) |
| 151 | + β, D1, D2, a = p.p_physics |
| 152 | + u_mean, u_std, v_mean, v_std = p.p_norm |
| 153 | + |
| 154 | + u = @view r[:, :, 1] |
| 155 | + v = @view r[:, :, 2] |
| 156 | + Du = D1 .* (Ay * u + u * Ax) |
| 157 | + Dv = D2 .* (Ay * v + v * Ax) |
| 158 | + react_v = a .* u .* u .- β * v |
| 159 | + @. dr[:, :, 2] = Dv + react_v |
| 160 | + |
| 161 | + nn_reaction_u = similar(u) |
| 162 | + for i in 1:N, j in 1:N |
| 163 | + # Normalize the input to the NN |
| 164 | + u_norm = (u[i, j] - u_mean) / (u_std + 1f-8) # Add epsilon for stability |
| 165 | + v_norm = (v[i, j] - v_mean) / (v_std + 1f-8) |
| 166 | + input = [u_norm, v_norm] |
| 167 | + |
| 168 | + # The NN receives normalized data |
| 169 | + nn_reaction_u[i, j] = nn_model(input, p.p_nn, nn_state)[1][1] |
| 170 | + end |
| 171 | + @. dr[:, :, 1] = Du + nn_reaction_u |
| 172 | + end |
| 173 | + return ude_model! |
| 174 | +end |
| 175 | +
|
| 176 | +ude_model! = create_ude_model(nn, st_nn) |
| 177 | +prob_ude = ODEProblem(ude_model!, r0, tspan, p_ude) |
| 178 | +``` |
| 179 | + |
| 180 | +## Loss Function and Optimization |
| 181 | +The loss function is defined as: |
| 182 | +```@example gray_scott |
| 183 | +function loss(params_to_train) |
| 184 | + prediction = predict(params_to_train) |
| 185 | + if prediction.retcode != SciMLBase.ReturnCode.Success |
| 186 | + return Inf |
| 187 | + end |
| 188 | + return sum(abs2, Array(prediction) .- data_to_train) |
| 189 | +end |
| 190 | +``` |
| 191 | +The function used to predict is defined as below. We have explicitly defined the sensealg here to ensure the correct AD is used, which is the `QuadratureAdjoint` here. It is a method that provides a correct gradient without errors is infinitely better than a slightly faster one that fails. It represents a pragmatic engineering choice to ensure the optimization can proceed reliably. |
| 192 | + |
| 193 | +```@example gray_scott |
| 194 | +function predict(params_to_train) |
| 195 | + return solve(prob_ude, Tsit5(), p=params_to_train, saveat=solution_true.t, |
| 196 | + sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))) |
| 197 | +end |
| 198 | +``` |
| 199 | + |
| 200 | +## Training |
| 201 | +Next, we start the training loop. |
| 202 | +```@example gray_scott |
| 203 | +loss_history = [] |
| 204 | +callback = function (p, l) |
| 205 | + push!(loss_history, l) |
| 206 | + println("Current loss: ", l) |
| 207 | + return false |
| 208 | +end |
| 209 | +``` |
| 210 | +Here, two training stages were used. The first stage uses a high learning rate to quickly move the model's parameters across the loss landscape towards the general area of a good solution. The second stage then uses a much lower learning rate to carefully fine-tune the parameters, allowing the model to settle precisely into a deep minimum without the risk of overshooting it. This two-phase approach combines the benefits of rapid initial progress with a stable and accurate final convergence. |
| 211 | + |
| 212 | +```@example gray_scott |
| 213 | +adtype = Optimization.AutoZygote() |
| 214 | +optf = Optimization.OptimizationFunction((p_train, p) -> loss(p_train), adtype) |
| 215 | +optprob = Optimization.OptimizationProblem(optf, p_ude) |
| 216 | +
|
| 217 | +println("Phase 1: Training with ADAM (learning rate 0.01)...") |
| 218 | +result = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters=100) |
| 219 | +
|
| 220 | +println("\nPhase 2: Refining with ADAM (learning rate 0.001)...") |
| 221 | +optprob2 = Optimization.OptimizationProblem(optf, result.u) |
| 222 | +result2 = Optimization.solve(optprob2, ADAM(0.001), callback=callback, maxiters=100) |
| 223 | +println("Training complete.") |
| 224 | +``` |
| 225 | + |
| 226 | +## Results and conclusion |
| 227 | +We can visualize the final results of the UDE and the ground-truth data with the following code. |
| 228 | + |
| 229 | +```@example gray_scott |
| 230 | +p_trained = result2.u |
| 231 | +final_prediction = predict(p_trained) |
| 232 | +
|
| 233 | +avg_u_true = [mean(data_to_train[:, :, 1, i]) for i in 1:length(tsteps)] |
| 234 | +avg_v_true = [mean(data_to_train[:, :, 2, i]) for i in 1:length(tsteps)] |
| 235 | +avg_u_pred = [mean(final_prediction[i][:, :, 1]) for i in 1:length(tsteps)] |
| 236 | +avg_v_pred = [mean(final_prediction[i][:, :, 2]) for i in 1:length(tsteps)] |
| 237 | +
|
| 238 | +p_comp_u = plot(tsteps, avg_u_true, label="True", color=:blue, lw=2, title="Comparison: Avg. U") |
| 239 | +plot!(tsteps, avg_u_pred, label="UDE", color=:black, linestyle=:dash, lw=2) |
| 240 | +xlabel!("Time") |
| 241 | +ylabel!("Avg. Concentration") |
| 242 | +
|
| 243 | +p_comp_v = plot(tsteps, avg_v_true, label="True", color=:red, lw=2, title="Comparison: Avg. V") |
| 244 | +plot!(tsteps, avg_v_pred, label="UDE", color=:black, linestyle=:dash, lw=2) |
| 245 | + xlabel!("Time") |
| 246 | +
|
| 247 | +p_comp_combined = plot(p_comp_u, p_comp_v, layout=(1, 2), size=(900, 350)) |
| 248 | +display(p_comp_combined) |
| 249 | +``` |
| 250 | + |
| 251 | +Now, as we can see, the UDE predictions match the ground-truth data very well, indicating the model has successfully learned the non-linear term. |
0 commit comments