@@ -542,7 +542,7 @@ Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
542542The arguments `dvs` and `ps` are used to set the order of the dependent
543543variable and parameter vectors, respectively.
544544"""
545- struct ODEFunctionExpr{iip} end
545+ struct ODEFunctionExpr{iip, specialize } end
546546
547547struct ODEFunctionClosure{O, I} <: Function
548548 f_oop:: O
551551(f:: ODEFunctionClosure )(u, p, t) = f. f_oop (u, p, t)
552552(f:: ODEFunctionClosure )(du, u, p, t) = f. f_iip (du, u, p, t)
553553
554- function ODEFunctionExpr {iip} (sys:: AbstractODESystem , dvs = unknowns (sys),
554+ function ODEFunctionExpr {iip, specialize } (sys:: AbstractODESystem , dvs = unknowns (sys),
555555 ps = parameters (sys), u0 = nothing ;
556556 version = nothing , tgrad = false ,
557557 jac = false , p = nothing ,
@@ -560,14 +560,12 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
560560 steady_state = false ,
561561 sparsity = false ,
562562 observedfun_exp = nothing ,
563- kwargs... ) where {iip}
563+ kwargs... ) where {iip, specialize }
564564 if ! iscomplete (sys)
565565 error (" A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunctionExpr`" )
566566 end
567567 f_oop, f_iip = generate_function (sys, dvs, ps; expression = Val{true }, kwargs... )
568568
569- dict = Dict ()
570-
571569 fsym = gensym (:f )
572570 _f = :($ fsym = $ ODEFunctionClosure ($ f_oop, $ f_iip))
573571 tgradsym = gensym (:tgrad )
@@ -590,30 +588,28 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
590588 _jac = :($ jacsym = nothing )
591589 end
592590
591+ Msym = gensym (:M )
593592 M = calculate_massmatrix (sys)
594-
595- _M = if sparse && ! (u0 === nothing || M === I)
596- SparseArrays. sparse (M)
593+ if sparse && ! (u0 === nothing || M === I)
594+ _M = :($ Msym = $ (SparseArrays. sparse (M)))
597595 elseif u0 === nothing || M === I
598- M
596+ _M = :( $ Msym = $ M)
599597 else
600- ArrayInterface. restructure (u0 .* u0' , M)
598+ _M = :( $ Msym = $ ( ArrayInterface. restructure (u0 .* u0' , M)) )
601599 end
602600
603601 jp_expr = sparse ? :($ similar ($ (get_jac (sys)[]), Float64)) : :nothing
604602 ex = quote
605- $ _f
606- $ _tgrad
607- $ _jac
608- M = $ _M
609- ODEFunction {$iip} ($ fsym,
610- sys = $ sys,
611- jac = $ jacsym,
612- tgrad = $ tgradsym,
613- mass_matrix = M,
614- jac_prototype = $ jp_expr,
615- sparsity = $ (sparsity ? jacobian_sparsity (sys) : nothing ),
616- observed = $ observedfun_exp)
603+ let $ _f, $ _tgrad, $ _jac, $ _M
604+ ODEFunction {$iip, $specialize} ($ fsym,
605+ sys = $ sys,
606+ jac = $ jacsym,
607+ tgrad = $ tgradsym,
608+ mass_matrix = $ Msym,
609+ jac_prototype = $ jp_expr,
610+ sparsity = $ (sparsity ? jacobian_sparsity (sys) : nothing ),
611+ observed = $ observedfun_exp)
612+ end
617613 end
618614 ! linenumbers ? Base. remove_linenums! (ex) : ex
619615end
@@ -622,6 +618,14 @@ function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
622618 ODEFunctionExpr {true} (sys, args... ; kwargs... )
623619end
624620
621+ function ODEFunctionExpr {true} (sys:: AbstractODESystem , args... ; kwargs... )
622+ return ODEFunctionExpr {true, SciMLBase.AutoSpecialize} (sys, args... ; kwargs... )
623+ end
624+
625+ function ODEFunctionExpr {false} (sys:: AbstractODESystem , args... ; kwargs... )
626+ return ODEFunctionExpr {false, SciMLBase.FullSpecialize} (sys, args... ; kwargs... )
627+ end
628+
625629"""
626630```julia
627631DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
0 commit comments