Skip to content

Commit dff202d

Browse files
Avoid GG via EvalFunc for standard diffeq usage
1 parent ea35575 commit dff202d

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,29 +144,29 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
144144
jac = false, Wfact = false,
145145
sparse = false,
146146
kwargs...) where {iip}
147-
f_oop,f_iip = generate_function(sys, dvs, ps; expression=Val{false}, kwargs...)
148147

148+
f_oop,f_iip = ModelingToolkit.eval.(generate_function(sys, dvs, ps; expression=Val{false}, kwargs...))
149149
f(u,p,t) = f_oop(u,p,t)
150150
f(du,u,p,t) = f_iip(du,u,p,t)
151151

152152
if tgrad
153-
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps; expression=Val{false}, kwargs...)
153+
tgrad_oop,tgrad_iip = ModelingToolkit.eval.(generate_tgrad(sys, dvs, ps; expression=Val{false}, kwargs...))
154154
_tgrad(u,p,t) = tgrad_oop(u,p,t)
155155
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
156156
else
157157
_tgrad = nothing
158158
end
159159

160160
if jac
161-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{false}, kwargs...)
161+
jac_oop,jac_iip = ModelingToolkit.eval.(generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{false}, kwargs...))
162162
_jac(u,p,t) = jac_oop(u,p,t)
163163
_jac(J,u,p,t) = jac_iip(J,u,p,t)
164164
else
165165
_jac = nothing
166166
end
167167

168168
if Wfact
169-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{false}, kwargs...)
169+
tmp_Wfact,tmp_Wfact_t = ModelingToolkit.eval.(generate_factorized_W(sys, dvs, ps; expression=Val{false}, kwargs...))
170170
Wfact_oop, Wfact_iip = tmp_Wfact
171171
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
172172
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
@@ -181,10 +181,11 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
181181

182182
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
183183

184-
ODEFunction{iip}(f,jac=_jac,
185-
tgrad = _tgrad,
186-
Wfact = _Wfact,
187-
Wfact_t = _Wfact_t,
184+
ODEFunction{iip}(DiffEqBase.EvalFunc(f),
185+
jac = _jac === nothing ? nothing : DiffEqBase.EvalFunc(_jac),
186+
tgrad = _tgrad === nothing ? nothing : DiffEqBase.EvalFunc(_tgrad),
187+
Wfact = _Wfact === nothing ? nothing : DiffEqBase.EvalFunc(_Wfact),
188+
Wfact_t = _Wfact_t === nothing ? nothing : DiffEqBase.EvalFunc(_Wfact_t),
188189
mass_matrix = _M,
189190
syms = Symbol.(states(sys)))
190191
end

0 commit comments

Comments
 (0)