Skip to content

Commit 3619971

Browse files
Use ODEFunctionExpr and simplify code
1 parent 8e63516 commit 3619971

File tree

1 file changed

+1
-123
lines changed

1 file changed

+1
-123
lines changed

src/ode_def_opts.jl

Lines changed: 1 addition & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -78,127 +78,5 @@ function ode_def_opts(name::Symbol, opts::Dict{Symbol, Bool}, curmod, ex::Expr,
7878
mtk_diffeqs = [D(vars[i]) ~ mtk_ops[i] for i in 1:length(vars)]
7979

8080
sys = ODESystem(mtk_diffeqs, t, vars, params, name = gensym(:Parameterized))
81-
82-
f_ex_oop, f_ex_iip = ModelingToolkit.generate_function(sys, vars, params)
83-
84-
if opts[:build_tgrad]
85-
try
86-
tgrad_ex_oop, tgrad_ex_iip = ModelingToolkit.generate_tgrad(sys, vars, params)
87-
catch
88-
@warn "tgrad construction failed"
89-
tgrad_ex_oop, tgrad_ex_iip = nothing, nothing
90-
end
91-
else
92-
tgrad_ex_oop, tgrad_ex_iip = nothing, nothing
93-
end
94-
95-
if opts[:build_jac]
96-
try
97-
J_ex_oop, J_ex_iip = ModelingToolkit.generate_jacobian(sys, vars, params)
98-
catch
99-
@warn "Jacobian construction failed"
100-
J_ex_oop, J_ex_iip = nothing, nothing
101-
end
102-
else
103-
J_ex_oop, J_ex_iip = nothing, nothing
104-
end
105-
106-
if opts[:build_invW] && length(mtk_diffeqs) < 4
107-
try
108-
W_exs = ModelingToolkit.generate_factorized_W(sys, vars, params, false)
109-
W_ex_oop, W_ex_iip = W_exs[1]
110-
W_t_ex_oop, W_t_ex_iip = W_exs[2]
111-
catch
112-
@warn "W-expression construction failed"
113-
W_ex_oop, W_ex_iip = (nothing, nothing)
114-
W_t_ex_oop, W_t_ex_iip = (nothing, nothing)
115-
end
116-
else
117-
W_ex_oop, W_ex_iip = (nothing, nothing)
118-
W_t_ex_oop, W_t_ex_iip = (nothing, nothing)
119-
end
120-
121-
fname = gensym(:ParameterizedDiffEqFunction)
122-
tname = gensym(:ParameterizedTGradFunction)
123-
jname = gensym(:ParameterizedJacobianFunction)
124-
Wname = gensym(:ParameterizedWFactFunction)
125-
W_tname = gensym(:ParameterizedW_tFactFunction)
126-
funcname = gensym(:ParameterizedODEFunction)
127-
128-
if tgrad_ex_oop !== nothing
129-
full_tex = quote
130-
$tname($(tgrad_ex_oop.args[1].args...)) = $(tgrad_ex_oop.args[2])
131-
$tname($(tgrad_ex_iip.args[1].args...)) = $(tgrad_ex_iip.args[2])
132-
end
133-
else
134-
full_tex = quote
135-
$tname = nothing
136-
end
137-
end
138-
139-
if J_ex_oop !== nothing
140-
full_jex = quote
141-
$jname($(J_ex_oop.args[1].args...)) = $(J_ex_oop.args[2])
142-
$jname($(J_ex_iip.args[1].args...)) = $(J_ex_iip.args[2])
143-
end
144-
else
145-
full_jex = quote
146-
$jname = nothing
147-
end
148-
end
149-
150-
if W_ex_oop !== nothing
151-
full_wex = quote
152-
$Wname($(W_ex_oop.args[1].args...)) = $(W_ex_oop.args[2])
153-
$Wname($(W_ex_iip.args[1].args...)) = $(W_ex_iip.args[2])
154-
$W_tname($(W_t_ex_oop.args[1].args...)) = $(W_t_ex_oop.args[2])
155-
$W_tname($(W_t_ex_iip.args[1].args...)) = $(W_t_ex_iip.args[2])
156-
end
157-
else
158-
full_wex = quote
159-
$Wname = nothing
160-
$W_tname = nothing
161-
end
162-
end
163-
164-
quote
165-
struct $name{F, TG, TJ, TW, TWt, S} <:
166-
ParameterizedFunctions.DiffEqBase.AbstractParameterizedFunction{true}
167-
f::F
168-
mass_matrix::ParameterizedFunctions.LinearAlgebra.UniformScaling{Bool}
169-
analytic::Nothing
170-
tgrad::TG
171-
jac::TJ
172-
jvp::Nothing
173-
vjp::Nothing
174-
jac_prototype::Nothing
175-
sparsity::Nothing
176-
Wfact::TW
177-
Wfact_t::TWt
178-
paramjac::Nothing
179-
syms::Vector{Symbol}
180-
indepvar::Symbol
181-
colorvec::Nothing
182-
sys::S
183-
initialization_data::Nothing
184-
nlprob_data::Nothing
185-
end
186-
187-
(f::$name)(args...) = f.f(args...)
188-
189-
function ParameterizedFunctions.SciMLBase.remake(func::$name; kwargs...)
190-
return func
191-
end
192-
193-
$fname($(f_ex_oop.args[1].args...)) = $(f_ex_oop.args[2])
194-
$fname($(f_ex_iip.args[1].args...)) = $(f_ex_iip.args[2])
195-
$full_tex
196-
$full_jex
197-
$full_wex
198-
199-
$name($fname, ParameterizedFunctions.LinearAlgebra.I, nothing, $tname, $jname,
200-
nothing, nothing,
201-
nothing, nothing, $Wname, $W_tname, nothing, $syms, $(Meta.quot(depvar)),
202-
nothing, $sys, nothing, nothing)
203-
end |> esc
81+
ODEFunctionExpr(sys, tgrad = opts[:build_tgrad], jac = opts[:build_jac])
20482
end

0 commit comments

Comments
 (0)