`docs/src/examples/neural_ode/simplechains.md`
First, we'll need data for training the NeuralODE, which can be obtained by solving the ODE `u' = f(u,p,t)` numerically using the SciML ecosystem in Julia.
```@example sc_neuralode
import SimpleChains, OrdinaryDiffEq as ODE, SciMLSensitivity as SMS, Optimization as OPT,
       OptimizationOptimisers as OPO, Plots
```
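The data-generation code itself falls outside this diff hunk. As a minimal sketch, assuming the cubic spiral ODE commonly used in NeuralODE tutorials (the names `trueODEfunc` and `ode_data` and the specific coefficients here are illustrative, not taken from the file):

```julia
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2]; length = 30)

# Hypothetical ground-truth dynamics: a cubic spiral, du = A * u.^3
function trueODEfunc(du, u, p, t)
    true_A = Float32[-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)' * true_A)'
end

# Solve numerically and store the trajectory as training data
prob_trueode = ODE.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(ODE.solve(prob_trueode, ODE.Tsit5(); saveat = tsteps))
```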
With the ground truth data generated and visualized, we are now ready to construct a Universal Differential Equation (UDE) by replacing the nonlinear term $U^2V$ with a neural network. The next section outlines how we define this hybrid model and train it to recover the reaction dynamics from data.
Here, $\mathcal{N}_\theta(U, V)$ is trained to approximate the true interaction term $U^2V$.
First, we define and configure the neural network to be used for training. The implementation is as follows:
```@example bruss
import Lux, Random, Optimization as OPT, OptimizationOptimJL as OOJ,
       SciMLSensitivity as SMS, Zygote

model = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)  # standard Lux initialization (assumed; the hunk ends before this line)
```
We use a simple fully connected neural network with one hidden layer of 16 tanh-activated neurons.
To ensure consistency between the ground truth simulation and the learned Universal Differential Equation (UDE) model, we preserve the same spatial discretization scheme used in the original ODEProblem. This includes:
- the finite difference Laplacian,
- periodic boundary conditions, and
- the external forcing function.
The only change lies in the replacement of the known nonlinear term $U^2V$ with a neural network approximation $\mathcal{N}_\theta(U, V)$. This design enables the UDE to learn complex or unknown dynamics from data while maintaining the underlying physical structure of the system.
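Concretely, assuming the standard Brusselator reaction terms (a sketch; the tutorial's actual constants $A$, $B$, $\alpha$, and forcing $f$ are defined where the ground-truth problem is set up), the UDE system reads:

```math
\begin{aligned}
\frac{\partial U}{\partial t} &= B + \mathcal{N}_\theta(U, V) - (A + 1)\,U + \alpha \nabla^2 U + f(x, y, t), \\
\frac{\partial V}{\partial t} &= A\,U - \mathcal{N}_\theta(U, V) + \alpha \nabla^2 V.
\end{aligned}
```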
The function below implements this hybrid formulation:

```@example bruss
function pde_ude!(du, u, ps_nn, t)
    αdx = alpha / dx^2
    # ... (remainder of the function lies outside this diff hunk)
end
```
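Since the hunk truncates `pde_ude!`, here is a minimal sketch of what such a hybrid right-hand side typically looks like. It assumes an `N × N` periodic grid with `u[i, j, 1] = U` and `u[i, j, 2] = V`, standard Brusselator constants `A` and `B`, grid coordinates `x` and `y`, a forcing function `f_forcing`, and the Lux calling convention `model(input, ps, st)`; all of these names are illustrative rather than taken from the file:

```julia
function pde_ude_sketch!(du, u, ps_nn, t)
    αdx = alpha / dx^2
    for j in 1:N, i in 1:N
        # Periodic boundary conditions via wrap-around indexing
        ip1 = i == N ? 1 : i + 1
        im1 = i == 1 ? N : i - 1
        jp1 = j == N ? 1 : j + 1
        jm1 = j == 1 ? N : j - 1
        U = u[i, j, 1]
        V = u[i, j, 2]
        # Five-point finite-difference Laplacian
        lapU = u[im1, j, 1] + u[ip1, j, 1] + u[i, jm1, 1] + u[i, jp1, 1] - 4U
        lapV = u[im1, j, 2] + u[ip1, j, 2] + u[i, jm1, 2] + u[i, jp1, 2] - 4V
        # Neural surrogate in place of the known U^2 * V reaction term
        nn = model([U, V], ps_nn, st)[1][1]
        du[i, j, 1] = αdx * lapU + B + nn - (A + 1) * U + f_forcing(x[i], y[j], t)
        du[i, j, 2] = αdx * lapV + A * U - nn
    end
end
```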
To train the neural network $\mathcal{N}_\theta(U, V)$ embedded in the UDE, we define a loss function that measures how closely the solution of the UDE matches the ground truth data generated earlier.
The loss is computed as the sum of squared errors between the predicted solution from the UDE and the true solution at each saved time point. If the solver fails (e.g., due to numerical instability or incorrect parameters), we return an infinite loss to discard that configuration during optimization. We use `FBDF()` as the solver due to the stiff nature of the Brusselator equation. Other solvers like `KenCarp47()` could also be used.
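In symbols, with $u_\theta(t_i)$ the UDE prediction and $u^{\text{data}}(t_i)$ the saved ground truth at each time point:

```math
L(\theta) = \sum_{i} \left\lVert u_\theta(t_i) - u^{\text{data}}(t_i) \right\rVert^2
```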
To efficiently compute gradients of the loss with respect to the neural network parameters, we use an adjoint sensitivity method (`GaussAdjoint`), which performs high-accuracy quadrature-based integration of the adjoint equations. This approach enables scalable and memory-efficient training for stiff PDEs by avoiding full trajectory storage while maintaining accurate gradient estimates.
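The place where the adjoint is selected is not shown in this hunk. As a sketch, `GaussAdjoint` is typically passed through the `sensealg` keyword of the `solve` call inside the loss; the file's actual call may differ:

```julia
# Illustrative only: select the quadrature-based adjoint for gradient computation
sol = ODE.solve(prob, ODE.FBDF(); saveat = t_points,
                sensealg = SMS.GaussAdjoint(autojacvec = SMS.ZygoteVJP()))
```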
The loss function and initial evaluation are implemented as follows:
```@example bruss
println("[Loss] Defining loss function...")
function loss_fn(ps, _)
    prob = ODE.remake(prob_ude_template, p = ps)
    sol = ODE.solve(prob, ODE.FBDF(), saveat = t_points)
    # Failed solve
    if !SciMLBase.successful_retcode(sol)
        return Inf32
    end
    # ... (sum-of-squared-errors accumulation lies outside this diff hunk)
end
```
Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's `Optimization.jl` tools, and gradients are computed via automatic differentiation using `AutoZygote()` from `SciMLSensitivity`:
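The training block itself is unchanged in this diff, so it is not shown. A minimal sketch of the setup implied by the paragraph above, assuming a hypothetical `ps_init` for the initial network parameters and an `OptimizationOptimisers` import for the Adam rule (the file imports `OptimizationOptimJL as OOJ`, which may supply the variant actually used):

```julia
import OptimizationOptimisers as OPO  # assumed source of the Adam rule for this sketch

# Wrap the loss with an AD backend, build the problem, and train
optf = OPT.OptimizationFunction(loss_fn, OPT.AutoZygote())
optprob = OPT.OptimizationProblem(optf, ps_init)  # ps_init: hypothetical initial parameters
result = OPT.solve(optprob, OPO.Adam(1e-2); maxiters = 200)
```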
After training the Universal Differential Equation (UDE), we compare the predicted dynamics to the ground truth for both chemical species.
The low training loss shows that the neural network in the UDE captured the underlying dynamics and learned the $U^2V$ term in the partial differential equation.