diff --git a/docs/src/examples/pde/brusselator.md b/docs/src/examples/pde/brusselator.md index b57e2b001..48503b8eb 100644 --- a/docs/src/examples/pde/brusselator.md +++ b/docs/src/examples/pde/brusselator.md @@ -1,4 +1,4 @@ -# Learning Nonlinear Reaction Dynamics in the 2D Brusselator PDE Using Universal Differential Equations +# Learning Nonlinear Reaction Dynamics in the 2D Brusselator PDE Using Universal Differential Equations and Multiple Shooting ## Introduction @@ -8,71 +8,70 @@ The Brusselator is a mathematical model used to describe oscillating chemical re The Brusselator PDE is defined on a unit square periodic domain as follows: -$$ -\frac{\partial U}{\partial t} = B + U^2V - (A+1)U + \alpha \nabla^2 U + f(x, y, t) -$$ +```math +\frac{\partial U}{\partial t} = B + U^2V - (A+1)U + \alpha \nabla^2 U + f(x, y, t) +``` -$$ -\frac{\partial V}{\partial t} = AU - U^2V + \alpha \nabla^2 V -$$ +```math +\frac{\partial V}{\partial t} = AU - U^2V + \alpha \nabla^2 +``` where $A=3.4, B=1$ and the forcing term is: -$$ +```math f(x, y, t) = \begin{cases} 5 & \text{if } (x - 0.3)^2 + (y - 0.6)^2 \leq 0.1^2 \text{ and } t \geq 1.1 \\ 0 & \text{otherwise} \end{cases} -$$ +``` and the Laplacian operator is: -$$ +```math \nabla^2 = \frac{\partial^2}{\partial x^2} + \frac{\partial^2}{\partial y^2} -$$ +``` These equations are solved over the time interval: -$$ +```math t \in [0, 11.5] -$$ +``` with the initial conditions: -$$ -U(x, y, 0) = 22 \cdot \left( y(1 - y) \right)^{3/2} -$$ +```math +U(x, y, 0) = 22 \cdot \left( y(1 - y) \right)^{3/2} +``` -$$ +```math V(x, y, 0) = 27 \cdot \left( x(1 - x) \right)^{3/2} -$$ +``` and the periodic boundary conditions: -$$ -U(x + 1, y, t) = U(x, y, t) -$$ - -$$ +```math +U(x + 1, y, t) = U(x, y, t) +``` +```math V(x, y + 1, t) = V(x, y, t) -$$ +``` ## Numerical Discretization - +f To numerically solve this PDE, we discretize the unit square domain using $N$ grid points along each spatial dimension. The variables $U[i,j]$ and $V[i,j]$ then denote the concentrations at the grid point $(i, j)$ at a given time $t$. We represent the spatially discretized fields as: -$$ +```math U[i,j] = U(i \cdot \Delta x, j \cdot \Delta y), \quad V[i,j] = V(i \cdot \Delta x, j \cdot \Delta y), -$$ +``` where $\Delta x = \Delta y = \frac{1}{N}$ for a grid of size $N \times N$. To organize the simulation state efficiently, we store both $ U $ and $ V $ in a single 3D array: -$$ +```math u[i,j,1] = U[i,j], \quad u[i,j,2] = V[i,j], -$$ +``` giving us a field tensor of shape $(N, N, 2)$. This structure is flexible and extends naturally to systems with additional field variables. @@ -81,44 +80,46 @@ giving us a field tensor of shape $(N, N, 2)$. This structure is flexible and ex For spatial derivatives, we apply a second-order central difference scheme using a three-point stencil. The Laplacian is discretized as: -$$ +```math [\ 1,\ -2,\ 1\ ] -$$ +``` -in both the $ x $ and $ y $ directions, forming a tridiagonal structure in both the x and y directions; applying this 1D stencil (scaled appropriately by $\frac{1}{Δx^2}$ or $\frac{1}{Δy^2}$) along each axis and summing the contributions yields the standard 5-point stencil computation for the 2D Laplacian. Periodic boundary conditions are incorporated by wrapping the stencil at the domain edges, effectively connecting the boundaries. The nonlinear interaction terms are computed directly at each grid point, making the implementation straightforward and local in nature. +in both the $x$ and $y$ directions, forming a tridiagonal structure in both the x and y directions; applying this 1D stencil (scaled appropriately by $\frac{1}{Δx^2}$ or $\frac{1}{Δy^2}$) along each axis and summing the contributions yields the standard 5-point stencil computation for the 2D Laplacian. Periodic boundary conditions are incorporated by wrapping the stencil at the domain edges, effectively connecting the boundaries. The nonlinear interaction terms are computed directly at each grid point, making the implementation straightforward and local in nature. ## Generating Training Data This provides us with an `ODEProblem` that can be solved to obtain training data. ```@example bruss -using ComponentArrays, Random, Plots, OrdinaryDiffEq +using ComponentArrays, Random, Plots, OrdinaryDiffEq, Statistics +using Lux, Optimization, OptimizationOptimJL, SciMLSensitivity, Zygote, OptimizationOptimisers +# Grid and Time Setup N_GRID = 16 XYD = range(0f0, stop = 1f0, length = N_GRID) dx = step(XYD) T_FINAL = 11.5f0 SAVE_AT = 0.5f0 tspan = (0.0f0, T_FINAL) -t_points = range(tspan[1], stop=tspan[2], step=SAVE_AT) +t_points = collect(range(tspan[1], stop=tspan[2], step=SAVE_AT)) A, B, alpha = 3.4f0, 1.0f0, 10.0f0 -brusselator_f(x, y, t) = (((x - 0.3f0)^2 + (y - 0.6f0)^2) <= 0.01f0) * (t >= 1.1f0) * 5.0f0 +# Helper Functions limit(a, N) = a == 0 ? N : a == N+1 ? 1 : a +brusselator_f(x, y, t) = (((x - 0.3f0)^2 + (y - 0.6f0)^2) <= 0.01f0) * (t >= 1.1f0) * 5.0f0 + function init_brusselator(xyd) - println("[Init] Creating initial condition array...") u0 = zeros(Float32, N_GRID, N_GRID, 2) for I in CartesianIndices((N_GRID, N_GRID)) x, y = xyd[I[1]], xyd[I[2]] u0[I,1] = 22f0 * (y * (1f0 - y))^(3f0/2f0) u0[I,2] = 27f0 * (x * (1f0 - x))^(3f0/2f0) end - println("[Init] Done.") return u0 end -u0 = init_brusselator(XYD) +# Ground Truth PDE function pde_truth!(du, u, p, t) A, B, alpha, dx = p αdx = alpha / dx^2 @@ -135,8 +136,9 @@ function pde_truth!(du, u, p, t) end end +u0 = init_brusselator(XYD) p_tuple = (A, B, alpha, dx) -@time sol_truth = solve(ODEProblem(pde_truth!, u0, tspan, p_tuple), FBDF(), saveat=t_points) +sol_truth = solve(ODEProblem(pde_truth!, u0, tspan, p_tuple), FBDF(), saveat=t_points) u_true = Array(sol_truth) ``` @@ -144,8 +146,6 @@ u_true = Array(sol_truth) We can now use this code for training our UDE, and generating time-series plots of the concentrations of species of U and V using the code: ```@example bruss -using Plots, Statistics - # Compute average concentration at each timestep avg_U = [mean(snapshot[:, :, 1]) for snapshot in sol_truth.u] avg_V = [mean(snapshot[:, :, 2]) for snapshot in sol_truth.u] @@ -158,27 +158,20 @@ plot!(sol_truth.t, avg_V, label="Mean V", lw=2, linestyle=:dash) With the ground truth data generated and visualized, we are now ready to construct a Universal Differential Equation (UDE) by replacing the nonlinear term $U^2V$ with a neural network. The next section outlines how we define this hybrid model and train it to recover the reaction dynamics from data. -## Universal Differential Equation (UDE) Formulation - -In the original Brusselator model, the nonlinear reaction term \( U^2V \) governs key dynamic behavior. In our UDE approach, we replace this known term with a trainable neural network \( \mathcal{N}_\theta(U, V) \), where \( \theta \) are the learnable parameters. +## Universal Differential Equation (UDE) Formulation with Multiple Shooting -The resulting system becomes: - -$$ +```math \frac{\partial U}{\partial t} = 1 + \mathcal{N}_\theta(U, V) - 4.4U + \alpha \nabla^2 U + f(x, y, t) -$$ +``` -$$ +```math \frac{\partial V}{\partial t} = 3.4U - \mathcal{N}_\theta(U, V) + \alpha \nabla^2 V -$$ +``` Here, $\mathcal{N}_\theta(U, V)$ is trained to approximate the true interaction term $U^2V$ using simulation data. This hybrid formulation allows us to recover unknown or partially known physical processes while preserving the known structural components of the PDE. First, we have to define and configure the neural network that has to be used for the training. The implementation for that is as follows: - ```@example bruss -using Lux, Random, Optimization, OptimizationOptimJL, SciMLSensitivity, Zygote - model = Lux.Chain(Dense(2 => 16, tanh), Dense(16 => 1)) rng = Random.default_rng() ps_init, st = Lux.setup(rng, model) @@ -216,44 +209,91 @@ function pde_ude!(du, u, ps_nn, t) end prob_ude_template = ODEProblem(pde_ude!, u0, tspan, ps_init) ``` -## Loss Function and Optimization -To train the neural network -$\mathcal{N}_\theta(U, V)$ embedded in the UDE, we define a loss function that measures how closely the solution of the UDE matches the ground truth data generated earlier. -The loss is computed as the sum of squared errors between the predicted solution from the UDE and the true solution at each saved time point. If the solver fails (e.g., due to numerical instability or incorrect parameters), we return an infinite loss to discard that configuration during optimization. We use ```FBDF()``` as the solver due to the stiff nature of the brusselators euqation. Other solvers like ```KenCarp47()``` could also be used. +### Multiple Shooting +Traditional single-shooting training for stiff PDEs like the Brusselator often leads to instability or suboptimal learning due to long simulation horizons. Multiple shooting mitigates this by dividing the overall time span into shorter, manageable segments. This: + +* Prevents error accumulation, +* Encourages better generalization, +* And enforces continuity between segments. + +First, we have to conduct the time segmentation: +```@example bruss +segment_duration = 2.5f0 # 5 steps of SAVE_AT +n_segments = floor(Int, T_FINAL / segment_duration) # This will calculate n_segments = 4 + +# Create segments based on the duration, not a fixed number +segment_times = range(tspan[1], step=segment_duration, length=n_segments + 1) +segment_spans = [(segment_times[i], segment_times[i+1]) for i in 1:n_segments] + +# The rest of the code remains the same +segment_saves = [collect(range(t[1], stop=t[2], step=SAVE_AT)) for t in segment_spans] -To efficiently compute gradients of the loss with respect to the neural network parameters, we use an adjoint sensitivity method (`GaussAdjoint`), which performs high-accuracy quadrature-based integration of the adjoint equations. This approach enables scalable and memory-efficient training for stiff PDEs by avoiding full trajectory storage while maintaining accurate gradient estimates. +function match_time_indices(t_points, segment_saves) + return [map(ti -> findmin(abs.(t_points .- ti))[2], segment_saves[i]) for i in 1:length(segment_saves)] +end -The loss function and initial evaluation are implemented as follows: +segment_time_indices = match_time_indices(t_points, segment_saves) +``` +Then, we create an individual problem for each segment: ```@example bruss -println("[Loss] Defining loss function...") -function loss_fn(ps, _) - prob = remake(prob_ude_template, p=ps) - sol = solve(prob, FBDF(), saveat=t_points) - # Failed solve - if !SciMLBase.successful_retcode(sol) - return Inf32 - end - pred = Array(sol) - lval = sum(abs2, pred .- u_true) / length(u_true) - return lval +function get_segment_prob(ps, u0_seg, seg_idx) + remake(prob_ude_template, u0=u0_seg, tspan=segment_spans[seg_idx], p=ps) end ``` -Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's ```Optimization.jl``` tools, and gradients are computed via automatic differentiation using ```AutoZygote()``` from ```SciMLSensitivity```: +#### Loss Function and Optimization +To train the neural network +$\mathcal{N}_\theta(U, V)$ embedded in the UDE, we implement a multiple shooting loss function that segments the full simulation into smaller time intervals and enforces temporal consistency across them. +For each segment, the loss is computed as the sum of squared errors between the predicted solution and the ground truth data at saved time points. To ensure continuity across segments, we introduce a penalty ($\lambda$) that measures the difference between the final predicted state of one segment and the initial true state of the next. If any segment fails to solve (due to instability or divergence), an infinite loss is returned to discard that parameter configuration during optimization. + +Although adjoint sensitivity methods such as `GaussAdjoint` are often used in stiff problems to reduce memory load, multiple shooting naturally mitigates this need by shortening the integration window for each segment. Hence, we rely on `AutoZygote()` for automatic differentiation in our implementation. + +This approach improves training robustness by constraining long-term predictions and encouraging accurate short-term learning within each segment. The final optimization is carried out using the `ADAM` algorithm over all neural network parameters. + + +The loss function is defined below: ```@example bruss -println("[Training] Starting optimization...") -using OptimizationOptimisers -optf = OptimizationFunction(loss_fn, AutoZygote()) +λ = 10.0f0 +function loss_fn_multi(ps, _) + total_loss = 0f0 + u0_seg = copy(u0) + for i in 1:n_segments + prob_i = get_segment_prob(ps, u0_seg, i) + sol_i = solve(prob_i, FBDF(), saveat=segment_saves[i]) + if !SciMLBase.successful_retcode(sol_i) + return Inf32 + end + pred_i = Array(sol_i) + t_idxs = segment_time_indices[i] + println("Segment $i: matched indices = ", t_idxs) + if isempty(t_idxs) + error("No matching time points for segment $i — check SAVE_AT, t_points, or tolerance.") + end + true_i = u_true[:,:,:,t_idxs] + total_loss += sum(abs2, pred_i .- true_i) / length(true_i) + if i < n_segments + u0_seg = pred_i[:,:,:,end] + next_u0 = u_true[:,:,:,t_idxs[end]+1] + total_loss += λ * sum(abs2, u0_seg .- next_u0) / length(next_u0) + end + end + return total_loss +end +``` +Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's ```Optimization.jl``` tools, and gradients are computed via automatic differentiation using ```AutoZygote()``` from ```SciMLSensitivity```: +```@example bruss +optf = OptimizationFunction(loss_fn_multi, AutoZygote()) optprob = OptimizationProblem(optf, ps_init) loss_history = Float32[] - +epoch_counter = Ref(0) callback = (ps, l) -> begin + epoch_counter[] += 1 push!(loss_history, l) - println("Epoch $(length(loss_history)): Loss = $l") + println("Epoch $(epoch_counter[]): Loss = $l") false end ``` @@ -261,7 +301,7 @@ end Finally to run everything: ```@example bruss -res = solve(optprob, Optimisers.Adam(0.01), callback=callback, maxiters=100) +res = solve(optprob, Optimisers.Adam(0.01), callback=callback, maxiters=10000) ``` ```@example bruss @@ -269,9 +309,9 @@ res.objective ``` ```@example bruss -println("[Plot] Final U/V comparison plots...") center = N_GRID ÷ 2 -sol_final = solve(remake(prob_ude_template, p=res.u), FBDF(), saveat=t_points) +sol_final = solve(remake(prob_ude_template, p=res.u), FBDF(), saveat=t_points, abstol=1e-6, reltol=1e-6) + pred = Array(sol_final) p1 = plot(t_points, u_true[center,center,1,:], lw=2, label="U True") @@ -287,6 +327,6 @@ plot(p1, p2, layout=(1,2), size=(900,400)) ## Results and Conclusion -After training the Universal Differential Equation (UDE), we compared the predicted dynamics to the ground truth for both chemical species. +After training the Universal Differential Equation (UDE) using the multiple shooting strategy, we compared the predicted dynamics to the ground truth for both chemical species. -The low training loss shows us that the neural network in the UDE was able to understand the underlying dynamics, and it was able to learn the $U^2V$ term in the partial differential equation. +The low training loss across segments demonstrates that the neural network was able to accurately capture the underlying reaction dynamics. The model effectively learned the nonlinear $U^2V$ term through a segment-wise optimization process that enforces both data fidelity and inter-segment continuity. This confirms that multiple shooting not only stabilizes training but also enhances temporal consistency in learning complex spatiotemporal PDE systems. \ No newline at end of file