Skip to content

Commit 39506eb

Browse files
committed
Fix calculate_tgrad
1 parent addfdf6 commit 39506eb

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@ function calculate_tgrad(sys::AbstractODESystem;
77
# t + u(t)`.
88
rhs = [detime_dvs(eq.rhs) for eq equations(sys)]
99
iv = sys.iv
10-
notime_tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r)) for r in rhs]
11-
if simplify
12-
tgrad = ModelingToolkit.simplify.(notime_tgrad)
13-
end
1410
xs = states(sys)
15-
rule = Dict(map((x, xt) -> x=>xt, detime_dvs.(xs), xs))
16-
tgrad = substitute.(tgrad, Ref(rule))
11+
rule = Dict(map((x, xt) -> xt=>x, detime_dvs.(xs), xs))
12+
rhs = substitute.(rhs, Ref(rule))
13+
tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r), simplify) for r in rhs]
14+
reverse_rule = Dict(map((x, xt) -> x=>xt, detime_dvs.(xs), xs))
15+
tgrad = Num.(substitute.(tgrad, Ref(reverse_rule)))
1716
sys.tgrad[] = tgrad
1817
return tgrad
1918
end

0 commit comments

Comments
 (0)