Skip to content

Commit 52a6315

Browse files
Adapt setup for DiffEqGPU DAEs
1 parent d3e51b8 commit 52a6315

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

src/adapt.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,25 @@ function adapt_structure(to,
1616
adapt(to, prob.kwargs)...)
1717
end
1818

19-
function adapt_structure(to, f::ODEFunction{iip}) where {iip}
20-
if f.mass_matrix !== I && f.initialization_data !== nothing
21-
error("Adaptation to GPU failed: DAEs of ModelingToolkit currently not supported.")
22-
end
23-
ODEFunction{iip, FullSpecialize}(f.f, jac = f.jac, mass_matrix = f.mass_matrix)
19+
# Allow DAE adaptation for GPU kernels
20+
function adapt_structure(to, f::SciMLBase.ODEFunction{iip}) where {iip}
21+
# For GPU kernels, we now support DAEs with mass matrices and initialization
22+
SciMLBase.ODEFunction{iip, SciMLBase.FullSpecialize}(
23+
f.f,
24+
jac = f.jac,
25+
mass_matrix = f.mass_matrix,
26+
initialization_data = f.initialization_data
27+
)
28+
end
29+
30+
# Adapt OverrideInitData for GPU compatibility
31+
function adapt_structure(to, f::SciMLBase.OverrideInitData)
32+
SciMLBase.OverrideInitData(
33+
adapt(to, f.initializeprob), # Also adapt initializeprob
34+
f.update_initializeprob!,
35+
f.initializeprobmap,
36+
f.initializeprobpmap,
37+
nothing, # Set metadata to nothing for GPU compatibility
38+
f.is_update_oop
39+
)
2440
end

0 commit comments

Comments
 (0)