From 36199712d7f6a3a86b1b8cc9269a7d0301c92e54 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 22 Jan 2025 18:29:29 +0100 Subject: [PATCH] Use ODEFunctionExpr and simplify code --- src/ode_def_opts.jl | 124 +------------------------------------------- 1 file changed, 1 insertion(+), 123 deletions(-) diff --git a/src/ode_def_opts.jl b/src/ode_def_opts.jl index 5603fbc..ca1b916 100644 --- a/src/ode_def_opts.jl +++ b/src/ode_def_opts.jl @@ -78,127 +78,5 @@ function ode_def_opts(name::Symbol, opts::Dict{Symbol, Bool}, curmod, ex::Expr, mtk_diffeqs = [D(vars[i]) ~ mtk_ops[i] for i in 1:length(vars)] sys = ODESystem(mtk_diffeqs, t, vars, params, name = gensym(:Parameterized)) - - f_ex_oop, f_ex_iip = ModelingToolkit.generate_function(sys, vars, params) - - if opts[:build_tgrad] - try - tgrad_ex_oop, tgrad_ex_iip = ModelingToolkit.generate_tgrad(sys, vars, params) - catch - @warn "tgrad construction failed" - tgrad_ex_oop, tgrad_ex_iip = nothing, nothing - end - else - tgrad_ex_oop, tgrad_ex_iip = nothing, nothing - end - - if opts[:build_jac] - try - J_ex_oop, J_ex_iip = ModelingToolkit.generate_jacobian(sys, vars, params) - catch - @warn "Jacobian construction failed" - J_ex_oop, J_ex_iip = nothing, nothing - end - else - J_ex_oop, J_ex_iip = nothing, nothing - end - - if opts[:build_invW] && length(mtk_diffeqs) < 4 - try - W_exs = ModelingToolkit.generate_factorized_W(sys, vars, params, false) - W_ex_oop, W_ex_iip = W_exs[1] - W_t_ex_oop, W_t_ex_iip = W_exs[2] - catch - @warn "W-expression construction failed" - W_ex_oop, W_ex_iip = (nothing, nothing) - W_t_ex_oop, W_t_ex_iip = (nothing, nothing) - end - else - W_ex_oop, W_ex_iip = (nothing, nothing) - W_t_ex_oop, W_t_ex_iip = (nothing, nothing) - end - - fname = gensym(:ParameterizedDiffEqFunction) - tname = gensym(:ParameterizedTGradFunction) - jname = gensym(:ParameterizedJacobianFunction) - Wname = gensym(:ParameterizedWFactFunction) - W_tname = gensym(:ParameterizedW_tFactFunction) - funcname = gensym(:ParameterizedODEFunction) - - if tgrad_ex_oop !== nothing - full_tex = quote - $tname($(tgrad_ex_oop.args[1].args...)) = $(tgrad_ex_oop.args[2]) - $tname($(tgrad_ex_iip.args[1].args...)) = $(tgrad_ex_iip.args[2]) - end - else - full_tex = quote - $tname = nothing - end - end - - if J_ex_oop !== nothing - full_jex = quote - $jname($(J_ex_oop.args[1].args...)) = $(J_ex_oop.args[2]) - $jname($(J_ex_iip.args[1].args...)) = $(J_ex_iip.args[2]) - end - else - full_jex = quote - $jname = nothing - end - end - - if W_ex_oop !== nothing - full_wex = quote - $Wname($(W_ex_oop.args[1].args...)) = $(W_ex_oop.args[2]) - $Wname($(W_ex_iip.args[1].args...)) = $(W_ex_iip.args[2]) - $W_tname($(W_t_ex_oop.args[1].args...)) = $(W_t_ex_oop.args[2]) - $W_tname($(W_t_ex_iip.args[1].args...)) = $(W_t_ex_iip.args[2]) - end - else - full_wex = quote - $Wname = nothing - $W_tname = nothing - end - end - - quote - struct $name{F, TG, TJ, TW, TWt, S} <: - ParameterizedFunctions.DiffEqBase.AbstractParameterizedFunction{true} - f::F - mass_matrix::ParameterizedFunctions.LinearAlgebra.UniformScaling{Bool} - analytic::Nothing - tgrad::TG - jac::TJ - jvp::Nothing - vjp::Nothing - jac_prototype::Nothing - sparsity::Nothing - Wfact::TW - Wfact_t::TWt - paramjac::Nothing - syms::Vector{Symbol} - indepvar::Symbol - colorvec::Nothing - sys::S - initialization_data::Nothing - nlprob_data::Nothing - end - - (f::$name)(args...) = f.f(args...) - - function ParameterizedFunctions.SciMLBase.remake(func::$name; kwargs...) - return func - end - - $fname($(f_ex_oop.args[1].args...)) = $(f_ex_oop.args[2]) - $fname($(f_ex_iip.args[1].args...)) = $(f_ex_iip.args[2]) - $full_tex - $full_jex - $full_wex - - $name($fname, ParameterizedFunctions.LinearAlgebra.I, nothing, $tname, $jname, - nothing, nothing, - nothing, nothing, $Wname, $W_tname, nothing, $syms, $(Meta.quot(depvar)), - nothing, $sys, nothing, nothing) - end |> esc + ODEFunctionExpr(sys, tgrad = opts[:build_tgrad], jac = opts[:build_jac]) end