Skip to content

Commit 3b8e510

Browse files
authored
Merge pull request #1233 from SciML/myb/pre
Don't use `get_postprocess_fbody` in `generate_function` when `has_di…
2 parents ae75bd8 + 4d0dde7 commit 3b8e510

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ function generate_function(
8484
sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
8585
implicit_dae=false,
8686
ddvs=implicit_dae ? map(Differential(get_iv(sys)), dvs) : nothing,
87+
has_difference=false,
8788
kwargs...
8889
)
8990
# optimization
@@ -102,7 +103,7 @@ function generate_function(
102103
p = map(x->time_varying_as_func(value(x), sys), ps)
103104
t = get_iv(sys)
104105

105-
pre = get_postprocess_fbody(sys)
106+
pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
106107

107108
if implicit_dae
108109
build_function(rhss, ddvs, u, p, t; postprocess_fbody=pre, kwargs...)
@@ -578,8 +579,9 @@ symbolically calculating numerical enhancements.
578579
"""
579580
function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
580581
parammap=DiffEqBase.NullParameters();kwargs...) where iip
581-
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; kwargs...)
582-
if any(isdifferenceeq, equations(sys))
582+
has_difference = any(isdifferenceeq, equations(sys))
583+
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; has_difference=has_difference, kwargs...)
584+
if has_difference
583585
ODEProblem{iip}(f,u0,tspan,p;difference_cb=generate_difference_cb(sys;kwargs...),kwargs...)
584586
else
585587
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
@@ -603,14 +605,15 @@ symbolically calculating numerical enhancements.
603605
"""
604606
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,du0map,u0map,tspan,
605607
parammap=DiffEqBase.NullParameters();kwargs...) where iip
608+
has_difference = any(isdifferenceeq, equations(sys))
606609
f, du0, u0, p = process_DEProblem(
607610
DAEFunction{iip}, sys, u0map, parammap;
608-
implicit_dae=true, du0map=du0map, kwargs...
611+
implicit_dae=true, du0map=du0map, has_difference=has_difference, kwargs...
609612
)
610613
diffvars = collect_differential_variables(sys)
611614
sts = states(sys)
612615
differential_vars = map(Base.Fix2(in, diffvars), sts)
613-
if any(isdifferenceeq, equations(sys))
616+
if has_difference
614617
DAEProblem{iip}(f,du0,u0,tspan,p;difference_cb=generate_difference_cb(sys; kwargs...),differential_vars=differential_vars,kwargs...)
615618
else
616619
DAEProblem{iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,kwargs...)

0 commit comments

Comments
 (0)