Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
end

u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), todict(u0map))
u0map = Dict(diff2term(var) => val for (var, val) in u0map) # replace D(x) -> x_t etc.
fullmap = merge(u0map, parammap)
u0T = Union{}
for sym in unknowns(isys)
Expand Down
20 changes: 20 additions & 0 deletions test/extensions/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Zygote
using SymbolicIndexingInterface
using SciMLStructures
using OrdinaryDiffEq
using NonlinearSolve
using SciMLSensitivity
using ForwardDiff
using ChainRulesCore
Expand Down Expand Up @@ -103,3 +104,22 @@ vals = (1.0f0, 3ones(Float32, 3))
tangent = rand_tangent(ps)
fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals)
@inferred back(tangent)

@testset "Dual type promotion in remake with dummy derivatives" begin # https://github.com/SciML/ModelingToolkit.jl/issues/3336
# Throw ball straight up into the air
@variables y(t)
eqs = [D(D(y)) ~ -9.81]
initialization_eqs = [y^2 ~ 0] # initialize y = 0 in a way that builds an initialization problem
@named sys = ODESystem(eqs, t; initialization_eqs)
sys = structural_simplify(sys)

# Find initial throw velocity that reaches exactly 10 m after 1 s
dprob0 = ODEProblem(sys, [D(y) => NaN], (0.0, 1.0), []; guesses = [y => 0.0])
nprob = NonlinearProblem((ics, _) -> begin
dprob = remake(dprob0, u0 = Dict(D(y) => ics[1]))
dsol = solve(dprob, Tsit5())
return [dsol[y][end] - 10.0]
end, [1.0])
nsol = solve(nprob, NewtonRaphson())
@test nsol[1] ≈ 10.0/1.0 + 9.81*1.0/2 # anal free fall solution is y = v0*t - g*t^2/2 -> v0 = y/t + g*t/2
end
Loading