@@ -209,15 +209,56 @@ Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `p
209
209
are used to set the order of the dependent variable and parameter vectors,
210
210
respectively.
211
211
"""
212
- function DiffEqBase. ODEFunction {iip} (sys:: ODESystem , dvs, ps;
212
+ function DiffEqBase. ODEFunction {iip} (sys:: ODESystem , dvs, ps,
213
+ safe = Val{true };
213
214
version = nothing ,
214
- jac = false , Wfact = false ) where iip
215
- expr = eval (generate_function (sys, dvs, ps))
216
- jac_expr = jac ? nothing : eval (generate_jacobian (sys, dvs, ps))
217
- Wfact_expr,Wfact_t_expr = Wfact ? (nothing ,nothing ) : eval .(generate_factorized_W (sys, dvs, ps))
218
- ODEFunction {iip} (eval (expr),jac= jac_expr,
219
- Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
215
+ jac = false , Wfact = false ) where {iip}
216
+ _f = eval (generate_function (sys, dvs, ps))
217
+ out_f_safe (u,p,t) = ModelingToolkit. fast_invokelatest (_f,typeof (u),u,p,t)
218
+ out_f_safe (du,u,p,t) = ModelingToolkit. fast_invokelatest (_f,Nothing,du,u,p,t)
219
+ out_f (u,p,t) = _f (u,p,t)
220
+ out_f (du,u,p,t) = _f (du,u,p,t)
221
+
222
+ if jac
223
+ @show generate_jacobian (sys, dvs, ps)
224
+ _jac = eval (generate_jacobian (sys, dvs, ps))
225
+ jac_f_safe (u,p,t) = ModelingToolkit. fast_invokelatest (_jac,Matrix{eltype (u)},u,p,t)
226
+ jac_f_safe (J,u,p,t) = ModelingToolkit. fast_invokelatest (_jac,Nothing,J,u,p,t)
227
+ jac_f (u,p,t) = _jac (u,p,t)
228
+ jac_f (J,u,p,t) = _jac (J,u,p,t)
229
+ else
230
+ jac_f_safe = nothing
231
+ jac_f = nothing
232
+ end
233
+
234
+ if Wfact
235
+ _Wfact,_Wfact_t = eval .(generate_factorized_W (sys, dvs, ps))
236
+ Wfact_f_safe (u,p,t) = ModelingToolkit. fast_invokelatest (_Wfact,Matrix{eltype (u)},u,p,t)
237
+ Wfact_f_safe (J,u,p,t) = ModelingToolkit. fast_invokelatest (_Wfact,Nothing,J,u,p,t)
238
+ Wfact_f_t_safe (u,p,t) = ModelingToolkit. fast_invokelatest (_Wfact,Matrix{eltype (u)},u,p,t)
239
+ Wfact_f_t_safe (J,u,p,t) = ModelingToolkit. fast_invokelatest (_Wfact,Nothing,J,u,p,t)
240
+ Wfact_f (u,p,t) = _Wfact (u,p,t)
241
+ Wfact_f (J,u,p,t) = _Wfact (J,u,p,t)
242
+ Wfact_f_t (u,p,t) = _Wfact_t (u,p,t)
243
+ Wfact_f_t (J,u,p,t) = _Wfact_t (J,u,p,t)
244
+ else
245
+ Wfact_f_safe = nothing
246
+ Wfact_f_t_safe = nothing
247
+ Wfact_f = nothing
248
+ Wfact_t_f = nothing
249
+ end
250
+
251
+ if safe === Val{true }
252
+ ODEFunction {iip} (out_f_safe,jac= jac_f_safe,
253
+ Wfact = Wfact_f_safe,
254
+ Wfact_t = Wfact_f_t_safe)
255
+ else
256
+ ODEFunction {iip} (out_f,jac= jac_f,
257
+ Wfact = Wfact_f,
258
+ Wfact_t = Wfact_t_f)
259
+ end
220
260
end
261
+
221
262
function DiffEqBase. ODEFunction (sys:: ODESystem , args... ; kwargs... )
222
263
ODEFunction {true} (sys, args... ; kwargs... )
223
264
end
0 commit comments