Skip to content

Commit 54ba222

Browse files
Drop initialization and get tutorial finished
1 parent d2db6f5 commit 54ba222

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

docs/src/tutorials/modelingtoolkit.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ the simplest set of equations to solve and exploiting things that normally canno
77
by hand. Those exact features are also potentially useful for GPU computing, and thus this
88
tutorial showcases how to effectively use MTK with DiffEqGPU.jl.
99

10+
!!! warn
11+
This tutorial currently only works for ODEs defined by ModelingToolkit. More work
12+
will be required to support DAEs in full. This is work that is ongoing and expected
13+
to be completed by the summer of 2025.
14+
1015
The core aspect to doing this right is two things. First of all, MTK respects the types
1116
chosen by the user, and thus in order for GPU kernel generation to work the user needs
1217
to ensure that the problem that is built uses static structures. For example this means
@@ -24,7 +29,6 @@ eqs = [D(D(x)) ~ σ * (y - x),
2429
D(z) ~ x * y - β * z]
2530
2631
@mtkbuild sys = ODESystem(eqs, t)
27-
2832
u0 = SA[D(x) => 2f0,
2933
x => 1f0,
3034
y => 0f0,
@@ -72,6 +76,7 @@ function prob_func2(prob, i, repeat)
7276
u0, p = sym_setter(prob,@SVector(rand(Float32,3)))
7377
remake(prob, u0 = u0, p = p)
7478
end
79+
7580
monteprob = EnsembleProblem(prob, prob_func = prob_func2, safetycopy = false)
7681
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()),
7782
trajectories = 10_000)

src/DiffEqGPU.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ include("ensemblegpuarray/kernels.jl")
3737
include("ensemblegpuarray/problem_generation.jl")
3838
include("ensemblegpuarray/lowerlevel_solve.jl")
3939

40+
include("ensemblegpukernel/problems/ode_problems.jl")
4041
include("ensemblegpukernel/callbacks.jl")
4142
include("ensemblegpukernel/lowerlevel_solve.jl")
4243
include("ensemblegpukernel/gpukernel_algorithms.jl")
@@ -67,7 +68,7 @@ include("ensemblegpukernel/tableaus/verner_tableaus.jl")
6768
include("ensemblegpukernel/tableaus/rodas_tableaus.jl")
6869
include("ensemblegpukernel/tableaus/kvaerno_tableaus.jl")
6970

70-
include("ensemblegpukernel/problems/ode_problems.jl")
71+
7172

7273
include("utils.jl")
7374
include("algorithms.jl")

src/ensemblegpukernel/integrators/nonstiff/types.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
## Fixed TimeStep Integrator
22

3-
function Adapt.adapt_structure(to, prob::ODEProblem{<:Any, <:Any, iip}) where {iip}
4-
ODEProblem{iip, true}(adapt(to, prob.f),
3+
function Adapt.adapt_structure(to, prob::Union{ODEProblem{<:Any, <:Any, iip}, ImmutableODEProblem{<:Any, <:Any, iip}}) where {iip}
4+
ImmutableODEProblem{iip, true}(adapt(to, prob.f),
55
adapt(to, prob.u0),
66
adapt(to, prob.tspan),
77
adapt(to, prob.p);
88
adapt(to, prob.kwargs)...)
99
end
1010

11+
function Adapt.adapt_structure(to, f::ODEFunction{iip}) where {iip}
12+
if f.mass_matrix !== I && f.initialization_data !== nothing
13+
error("Adaptation to GPU failed: DAEs of ModelingToolkit currently not supported.")
14+
end
15+
ODEFunction{iip, SciMLBase.FullSpecialize}(f.f, jac = f.jac, mass_matrix = f.mass_matrix,
16+
initializeprobmap = f.initializeprobmap,
17+
initializeprobpmap = f.initializeprobpmap,
18+
update_initializeprob! = f.update_initializeprob!,
19+
initialization_data = nothing, initializeprob = nothing)
20+
end
21+
1122
mutable struct GPUTsit5Integrator{IIP, S, T, ST, P, F, TS, CB, AlgType} <:
1223
DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}
1324
alg::AlgType

src/solve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ function batch_solve_up_kernel(ensembleprob, probs, alg, ensemblealg, I, adaptiv
273273
_callback.continuous_callbacks)...)
274274

275275
dev = ensemblealg.dev
276-
probs = adapt(dev, probs)
276+
probs = adapt(dev,adapt.((dev,), probs))
277277

278278
#Adaptive version only works with saveat
279279
if adaptive

0 commit comments

Comments
 (0)