Skip to content

Commit cb99be4

Browse files
Merge pull request #349 from SciML/mtktutorial
Add a tutorial which showcases ModelingToolkit and SII
2 parents 8807f1c + 54ba222 commit cb99be4

File tree

5 files changed

+106
-4
lines changed

5 files changed

+106
-4
lines changed

docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pages = ["index.md",
55
"Tutorials" => Any[
66
"GPU Ensembles" => Any["tutorials/gpu_ensemble_basic.md",
77
"tutorials/parallel_callbacks.md",
8+
"tutorials/modelingtoolkit.md",
89
"tutorials/multigpu.md",
910
"tutorials/lower_level_api.md",
1011
"tutorials/weak_order_conv_sde.md"],
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Symbolic-Numeric GPU Acceleration with ModelingToolkit
2+
3+
[ModelingToolkit.jl](https://docs.sciml.ai/ModelingToolkit/stable/) is a symbolic-numeric
4+
computing system which allows for using symbolic transformations of equations before
5+
code generation. The goal is to improve numerical simulations by first turning them into
6+
the simplest set of equations to solve and exploiting things that normally cannot be done
7+
by hand. Those exact features are also potentially useful for GPU computing, and thus this
8+
tutorial showcases how to effectively use MTK with DiffEqGPU.jl.
9+
10+
!!! warn
11+
This tutorial currently only works for ODEs defined by ModelingToolkit. More work
12+
will be required to support DAEs in full. This is work that is ongoing and expected
13+
to be completed by the summer of 2025.
14+
15+
The core aspect to doing this right is two things. First of all, MTK respects the types
16+
chosen by the user, and thus in order for GPU kernel generation to work the user needs
17+
to ensure that the problem that is built uses static structures. For example this means
18+
that the `u0` and `p` specifications should use static arrays. This looks as follows:
19+
20+
```@example mtk
21+
using OrdinaryDiffEqTsit5, ModelingToolkit, StaticArrays
22+
using ModelingToolkit: t_nounits as t, D_nounits as D
23+
24+
@parameters σ ρ β
25+
@variables x(t) y(t) z(t)
26+
27+
eqs = [D(D(x)) ~ σ * (y - x),
28+
D(y) ~ x * (ρ - z) - y,
29+
D(z) ~ x * y - β * z]
30+
31+
@mtkbuild sys = ODESystem(eqs, t)
32+
u0 = SA[D(x) => 2f0,
33+
x => 1f0,
34+
y => 0f0,
35+
z => 0f0]
36+
37+
p = SA[σ => 28f0,
38+
ρ => 10f0,
39+
β => 8f0 / 3f0]
40+
41+
tspan = (0f0, 100f0)
42+
prob = ODEProblem{false}(sys, u0, tspan, p)
43+
sol = solve(prob, Tsit5())
44+
```
45+
46+
with the core aspect to notice are the `SA`
47+
[StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) annotations on the parts
48+
for the problem construction, along with the use of `Float32`.
49+
50+
Now one of the difficulties for building an ensemble problem is that we must make a kernel
51+
for how to construct the problems, but the use of symbolics is inherently dynamic. As such,
52+
we need to make sure that any changes to `u0` and `p` are done on the CPU and that we
53+
compile an optimized function to run on the GPU. This can be done using the
54+
[SymbolicIndexingInterface.jl](https://docs.sciml.ai/SymbolicIndexingInterface/stable/).
55+
For example, let's define a problem which randomizes the choice of `(σ, ρ, β)`. We do this
56+
by first constructing the function that will change a `prob.p` object into the updated
57+
form by changing those 3 values by using the `setsym_oop` as follows:
58+
59+
```@example mtk
60+
using SymbolicIndexingInterface
61+
sym_setter = setsym_oop(sys, [σ, ρ, β])
62+
```
63+
64+
The return `sym_setter` is our optimized function, let's see it in action:
65+
66+
```@example mtk
67+
u0, p = sym_setter(prob,@SVector(rand(Float32,3)))
68+
```
69+
70+
Notice it takes in the vector of values for `[σ, ρ, β]` and spits out the new `u0, p`. So
71+
we can build and solve an MTK generated ODE on the GPU using the following:
72+
73+
```@example mtk
74+
using DiffEqGPU, CUDA
75+
function prob_func2(prob, i, repeat)
76+
u0, p = sym_setter(prob,@SVector(rand(Float32,3)))
77+
remake(prob, u0 = u0, p = p)
78+
end
79+
80+
monteprob = EnsembleProblem(prob, prob_func = prob_func2, safetycopy = false)
81+
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()),
82+
trajectories = 10_000)
83+
```
84+
85+
We can then using symbolic indexing on the result to inspect it:
86+
87+
```@example mtk
88+
[sol.u[i][y] for i in 1:length(sol.u)]
89+
```

src/DiffEqGPU.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ include("ensemblegpuarray/kernels.jl")
3737
include("ensemblegpuarray/problem_generation.jl")
3838
include("ensemblegpuarray/lowerlevel_solve.jl")
3939

40+
include("ensemblegpukernel/problems/ode_problems.jl")
4041
include("ensemblegpukernel/callbacks.jl")
4142
include("ensemblegpukernel/lowerlevel_solve.jl")
4243
include("ensemblegpukernel/gpukernel_algorithms.jl")
@@ -67,7 +68,7 @@ include("ensemblegpukernel/tableaus/verner_tableaus.jl")
6768
include("ensemblegpukernel/tableaus/rodas_tableaus.jl")
6869
include("ensemblegpukernel/tableaus/kvaerno_tableaus.jl")
6970

70-
include("ensemblegpukernel/problems/ode_problems.jl")
71+
7172

7273
include("utils.jl")
7374
include("algorithms.jl")

src/ensemblegpukernel/integrators/nonstiff/types.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
## Fixed TimeStep Integrator
22

3-
function Adapt.adapt_structure(to, prob::ODEProblem{<:Any, <:Any, iip}) where {iip}
4-
ODEProblem{iip, true}(adapt(to, prob.f),
3+
function Adapt.adapt_structure(to, prob::Union{ODEProblem{<:Any, <:Any, iip}, ImmutableODEProblem{<:Any, <:Any, iip}}) where {iip}
4+
ImmutableODEProblem{iip, true}(adapt(to, prob.f),
55
adapt(to, prob.u0),
66
adapt(to, prob.tspan),
77
adapt(to, prob.p);
88
adapt(to, prob.kwargs)...)
99
end
1010

11+
function Adapt.adapt_structure(to, f::ODEFunction{iip}) where {iip}
12+
if f.mass_matrix !== I && f.initialization_data !== nothing
13+
error("Adaptation to GPU failed: DAEs of ModelingToolkit currently not supported.")
14+
end
15+
ODEFunction{iip, SciMLBase.FullSpecialize}(f.f, jac = f.jac, mass_matrix = f.mass_matrix,
16+
initializeprobmap = f.initializeprobmap,
17+
initializeprobpmap = f.initializeprobpmap,
18+
update_initializeprob! = f.update_initializeprob!,
19+
initialization_data = nothing, initializeprob = nothing)
20+
end
21+
1122
mutable struct GPUTsit5Integrator{IIP, S, T, ST, P, F, TS, CB, AlgType} <:
1223
DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}
1324
alg::AlgType

src/solve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ function batch_solve_up_kernel(ensembleprob, probs, alg, ensemblealg, I, adaptiv
273273
_callback.continuous_callbacks)...)
274274

275275
dev = ensemblealg.dev
276-
probs = adapt(dev, probs)
276+
probs = adapt(dev,adapt.((dev,), probs))
277277

278278
#Adaptive version only works with saveat
279279
if adaptive

0 commit comments

Comments
 (0)