Skip to content

Commit 8493ed3

Browse files
safe ODEFunction
1 parent c332a17 commit 8493ed3

File tree

1 file changed

+48
-7
lines changed

1 file changed

+48
-7
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,56 @@ Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `p
209209
are used to set the order of the dependent variable and parameter vectors,
210210
respectively.
211211
"""
212-
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps;
212+
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps,
213+
safe = Val{true};
213214
version = nothing,
214-
jac = false, Wfact = false) where iip
215-
expr = eval(generate_function(sys, dvs, ps))
216-
jac_expr = jac ? nothing : eval(generate_jacobian(sys, dvs, ps))
217-
Wfact_expr,Wfact_t_expr = Wfact ? (nothing,nothing) : eval.(generate_factorized_W(sys, dvs, ps))
218-
ODEFunction{iip}(eval(expr),jac=jac_expr,
219-
Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
215+
jac = false, Wfact = false) where {iip}
216+
_f = eval(generate_function(sys, dvs, ps))
217+
out_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u,p,t)
218+
out_f_safe(du,u,p,t) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u,p,t)
219+
out_f(u,p,t) = _f(u,p,t)
220+
out_f(du,u,p,t) = _f(du,u,p,t)
221+
222+
if jac
223+
@show generate_jacobian(sys, dvs, ps)
224+
_jac = eval(generate_jacobian(sys, dvs, ps))
225+
jac_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_jac,Matrix{eltype(u)},u,p,t)
226+
jac_f_safe(J,u,p,t) = ModelingToolkit.fast_invokelatest(_jac,Nothing,J,u,p,t)
227+
jac_f(u,p,t) = _jac(u,p,t)
228+
jac_f(J,u,p,t) = _jac(J,u,p,t)
229+
else
230+
jac_f_safe = nothing
231+
jac_f = nothing
232+
end
233+
234+
if Wfact
235+
_Wfact,_Wfact_t = eval.(generate_factorized_W(sys, dvs, ps))
236+
Wfact_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_Wfact,Matrix{eltype(u)},u,p,t)
237+
Wfact_f_safe(J,u,p,t) = ModelingToolkit.fast_invokelatest(_Wfact,Nothing,J,u,p,t)
238+
Wfact_f_t_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_Wfact,Matrix{eltype(u)},u,p,t)
239+
Wfact_f_t_safe(J,u,p,t) = ModelingToolkit.fast_invokelatest(_Wfact,Nothing,J,u,p,t)
240+
Wfact_f(u,p,t) = _Wfact(u,p,t)
241+
Wfact_f(J,u,p,t) = _Wfact(J,u,p,t)
242+
Wfact_f_t(u,p,t) = _Wfact_t(u,p,t)
243+
Wfact_f_t(J,u,p,t) = _Wfact_t(J,u,p,t)
244+
else
245+
Wfact_f_safe = nothing
246+
Wfact_f_t_safe = nothing
247+
Wfact_f = nothing
248+
Wfact_t_f = nothing
249+
end
250+
251+
if safe === Val{true}
252+
ODEFunction{iip}(out_f_safe,jac=jac_f_safe,
253+
Wfact = Wfact_f_safe,
254+
Wfact_t = Wfact_f_t_safe)
255+
else
256+
ODEFunction{iip}(out_f,jac=jac_f,
257+
Wfact = Wfact_f,
258+
Wfact_t = Wfact_t_f)
259+
end
220260
end
261+
221262
function DiffEqBase.ODEFunction(sys::ODESystem, args...; kwargs...)
222263
ODEFunction{true}(sys, args...; kwargs...)
223264
end

0 commit comments

Comments
 (0)