Skip to content

Commit 8668cde

Browse files
Merge pull request #3337 from hersle/fix_remake_dummy_derivative
Fix dual type promotion in remake with dummy derivatives
2 parents 009d8b8 + 60e4723 commit 8668cde

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,21 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
13731373
end
13741374

13751375
u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), todict(u0map))
1376+
1377+
# Replace dummy derivatives in u0map: D(x) -> x_t etc.
1378+
if has_schedule(sys)
1379+
schedule = get_schedule(sys)
1380+
if !isnothing(schedule)
1381+
for (var, val) in u0map
1382+
dvar = get(schedule.dummy_sub, var, var) # with dummy derivatives
1383+
if dvar !== var # then replace it
1384+
delete!(u0map, var)
1385+
push!(u0map, dvar => val)
1386+
end
1387+
end
1388+
end
1389+
end
1390+
13761391
fullmap = merge(u0map, parammap)
13771392
u0T = Union{}
13781393
for sym in unknowns(isys)

test/extensions/ad.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Zygote
44
using SymbolicIndexingInterface
55
using SciMLStructures
66
using OrdinaryDiffEq
7+
using NonlinearSolve
78
using SciMLSensitivity
89
using ForwardDiff
910
using ChainRulesCore
@@ -103,3 +104,23 @@ vals = (1.0f0, 3ones(Float32, 3))
103104
tangent = rand_tangent(ps)
104105
fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals)
105106
@inferred back(tangent)
107+
108+
@testset "Dual type promotion in remake with dummy derivatives" begin # https://github.com/SciML/ModelingToolkit.jl/issues/3336
109+
# Throw ball straight up into the air
110+
@variables y(t)
111+
eqs = [D(D(y)) ~ -9.81]
112+
initialization_eqs = [y^2 ~ 0] # initialize y = 0 in a way that builds an initialization problem
113+
@named sys = ODESystem(eqs, t; initialization_eqs)
114+
sys = structural_simplify(sys)
115+
116+
# Find initial throw velocity that reaches exactly 10 m after 1 s
117+
dprob0 = ODEProblem(sys, [D(y) => NaN], (0.0, 1.0), []; guesses = [y => 0.0])
118+
function f(ics, _)
119+
dprob = remake(dprob0, u0 = Dict(D(y) => ics[1]))
120+
dsol = solve(dprob, Tsit5())
121+
return [dsol[y][end] - 10.0]
122+
end
123+
nprob = NonlinearProblem(f, [1.0])
124+
nsol = solve(nprob, NewtonRaphson())
125+
@test nsol[1] 10.0 / 1.0 + 9.81 * 1.0 / 2 # anal free fall solution is y = v0*t - g*t^2/2 -> v0 = y/t + g*t/2
126+
end

0 commit comments

Comments
 (0)