Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "6.193.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Expand Down Expand Up @@ -66,6 +67,7 @@ DiffEqBaseUnitfulExt = "Unitful"

[compat]
ArrayInterface = "7.8"
BracketingNonlinearSolve = "1.6.1"
CUDA = "5"
ChainRulesCore = "1"
ConcreteStructs = "0.2.3"
Expand Down
66 changes: 1 addition & 65 deletions ext/DiffEqBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag,
AbstractTimeseriesSolution,
RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin,
promote_tspan, ODE_DEFAULT_NORM,
InternalITP, nextfloat_tdir
promote_tspan, ODE_DEFAULT_NORM
import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum

const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
Expand Down Expand Up @@ -149,67 +148,4 @@ if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual})
end
end

# bisection(f, tup::Tuple{T,T}, t_forward::Bool) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi())

# Differentiation of internal solver

function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...)
f = prob.f
p = value(prob.p)

if prob isa IntervalNonlinearProblem
tspan = value(prob.tspan)
newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...)
else
u0 = value(prob.u0)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
end

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
if p isa Number
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
else
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
end

f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
pp = prob.p
sumfun = let f_x′ = -f_x
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
end
partials = sum(sumfun, zip(f_p, pp))
return sol, partials
end

function SciMLBase.solve(
prob::IntervalNonlinearProblem{uType, iip,
<:ForwardDiff.Dual{T, V, P}},
alg::InternalITP, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
left = ForwardDiff.Dual{T, V, P}(sol.left, partials),
right = ForwardDiff.Dual{T, V, P}(sol.right, partials))
end

function SciMLBase.solve(
prob::IntervalNonlinearProblem{uType, iip,
<:AbstractArray{
<:ForwardDiff.Dual{T,
V,
P},
}},
alg::InternalITP, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)

return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
left = ForwardDiff.Dual{T, V, P}(sol.left, partials),
right = ForwardDiff.Dual{T, V, P}(sol.right, partials))
end

end
4 changes: 3 additions & 1 deletion src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ Reexport.@reexport using SciMLBase

SciMLBase.isfunctionwrapper(x::FunctionWrapper) = true

# Rootfinder for callbacks
using BracketingNonlinearSolve: ITP

import SymbolicIndexingInterface as SII

## Extension Functions
Expand Down Expand Up @@ -140,7 +143,6 @@ include("utils.jl")
include("stats.jl")
include("calculate_residuals.jl")
include("tableaus.jl")
include("internal_itp.jl")
include("dae_initialization.jl")

include("callbacks.jl")
Expand Down
30 changes: 16 additions & 14 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ has_continuous_callback(cb::VectorContinuousCallback) = true
has_continuous_callback(cb::CallbackSet) = !isempty(cb.continuous_callbacks)
has_continuous_callback(cb::Nothing) = false

isforward(integrator::DEIntegrator) = isone(integrator.tdir)

# Callback handling

function get_tmp(integrator::DEIntegrator, callback)
Expand Down Expand Up @@ -359,17 +361,17 @@ end
# always ensures that if r = bisection(f, (x0, x1))
# then either f(nextfloat(r)) == 0 or f(nextfloat(r)) * f(r) < 0
# note: not really using bisection - uses the ITP method
function bisection(
f, tup, t_forward::Bool, rootfind::SciMLBase.RootfindOpt, abstol, reltol;
maxiters = 1000)
if rootfind == SciMLBase.LeftRootFind
solve(IntervalNonlinearProblem{false}(f, tup),
InternalITP(), abstol = abstol,
reltol = reltol).left
function find_root(f, tup, t_forward::Bool,
rootfind::SciMLBase.RootfindOpt, abstol, reltol)
sol = solve(IntervalNonlinearProblem{false}(f, tup),
ITP(), abstol = 0.0, reltol = 0.0)
# ODE solver convention: right is toward integration direction
# ITP solver convention: right is toward increasing t
# Note: different non-linear solvers may have different convention
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They shouldn't

if xor(rootfind == SciMLBase.LeftRootFind, !t_forward)
sol.left
else
solve(IntervalNonlinearProblem{false}(f, tup),
InternalITP(), abstol = abstol,
reltol = reltol).right
sol.right
end
end

Expand Down Expand Up @@ -431,7 +433,7 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter)
sign(zero_func(bottom_t)) * sign_top >= zero(sign_top) &&
error("Double callback crossing floating pointer reducer errored. Report this issue.")
end
Θ = bisection(zero_func, (bottom_t, top_t), isone(integrator.tdir),
Θ = find_root(zero_func, (bottom_t, top_t), isforward(integrator),
callback.rootfind, callback.abstol, callback.reltol)
integrator.last_event_error = DiffEqBase.value(ODE_DEFAULT_NORM(
zero_func(Θ), Θ))
Expand Down Expand Up @@ -478,7 +480,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun
bottom_t = integrator.tprev
end
if callback.rootfind != SciMLBase.NoRootFind && !isdiscrete(integrator.alg)
min_t = isone(integrator.tdir) ? nextfloat(top_t) : prevfloat(top_t)
min_t = isforward(integrator) ? nextfloat(top_t) : prevfloat(top_t)
min_event_idx = -1
for idx in 1:length(event_idx)
if ArrayInterface.allowed_getindex(event_idx, idx) != 0
Expand Down Expand Up @@ -506,8 +508,8 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun
error("Double callback crossing floating pointer reducer errored. Report this issue.")
end

Θ = bisection(zero_func, (bottom_t, top_t),
isone(integrator.tdir), callback.rootfind,
Θ = find_root(zero_func, (bottom_t, top_t),
isforward(integrator), callback.rootfind,
callback.abstol, callback.reltol)
if integrator.tdir * Θ < integrator.tdir * min_t
integrator.last_event_error = DiffEqBase.value(ODE_DEFAULT_NORM(
Expand Down
87 changes: 0 additions & 87 deletions src/internal_itp.jl

This file was deleted.

24 changes: 24 additions & 0 deletions test/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,27 @@ function test_find_first_callback(callbacks, int)
end
test_find_first_callback(callbacks, find_first_integrator);
@test test_find_first_callback(callbacks, find_first_integrator).bytes == 0

# https://github.com/SciML/DiffEqBase.jl/issues/1233
@testset "Inexact rootfinding" begin
# function with irrational root (sqrt(2))
irrational_f(x, p=0.0) = x^2 - 2

# Forward integration
is_forward = true
tspan = (1.0, 2.0)
before = DiffEqBase.find_root(irrational_f, tspan, is_forward, SciMLBase.LeftRootFind, 0.0, 1e-14)
after = DiffEqBase.find_root(irrational_f, tspan, is_forward, SciMLBase.RightRootFind, 0.0, 1e-14)
@test irrational_f(before) < 0.0
@test irrational_f(after) > 0.0
@test nextfloat(before) == after

# Backward integration
is_forward = false
tspan = (2.0, 1.0)
before = DiffEqBase.find_root(irrational_f, tspan, is_forward, SciMLBase.LeftRootFind, 0.0, 1e-14)
after = DiffEqBase.find_root(irrational_f, tspan, is_forward, SciMLBase.RightRootFind, 0.0, 1e-14)
@test irrational_f(before) > 0.0
@test irrational_f(after) < 0.0
@test nextfloat(after) == before
end
49 changes: 0 additions & 49 deletions test/internal_rootfinder.jl

This file was deleted.

1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ end
@time begin
if GROUP == "All" || GROUP == "Core"
@time @safetestset "Callbacks" include("callbacks.jl")
@time @safetestset "Internal Rootfinders" include("internal_rootfinder.jl")
@time @safetestset "Plot Vars" include("plot_vars.jl")
@time @safetestset "Problem Creation Tests" include("problem_creation_tests.jl")
@time @safetestset "Export tests" include("export_tests.jl")
Expand Down
Loading