Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions ext/DiffEqBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module DiffEqBaseChainRulesCoreExt

using DiffEqBase
using DiffEqBase.SciMLBase
import DiffEqBase: numargs, AbstractSensitivityAlgorithm, AbstractDEProblem, set_mooncakeoriginator_if_mooncake
import DiffEqBase: numargs, AbstractSensitivityAlgorithm, AbstractDEProblem,
set_mooncakeoriginator_if_mooncake

import ChainRulesCore
import ChainRulesCore: NoTangent
Expand All @@ -15,7 +16,8 @@ function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob,
u0, p, args...;
kwargs...)
DiffEqBase._solve_forward(
prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
prob, sensealg, u0, p,
set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
kwargs...)
end

Expand All @@ -24,7 +26,8 @@ function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEPro
u0, p, args...;
kwargs...)
DiffEqBase._solve_adjoint(
prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
prob, sensealg, u0, p,
set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
kwargs...)
end

Expand Down
18 changes: 7 additions & 11 deletions ext/DiffEqBaseMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,29 @@ module DiffEqBaseMooncakeExt
using DiffEqBase, Mooncake
using DiffEqBase: SciMLBase
using SciMLBase: ADOriginator, MooncakeOriginator
Mooncake.@from_rrule(
Mooncake.MinimalCtx,
Mooncake.@from_rrule(Mooncake.MinimalCtx,
Tuple{
typeof(DiffEqBase.solve_up),
DiffEqBase.AbstractDEProblem,
Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm},
Any,
Any,
Any,
Any
},
true,
)
true,)

# Dispatch for auto-alg
Mooncake.@from_rrule(
Mooncake.MinimalCtx,
Mooncake.@from_rrule(Mooncake.MinimalCtx,
Tuple{
typeof(DiffEqBase.solve_up),
DiffEqBase.AbstractDEProblem,
Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm},
Any,
Any,
Any
},
true,
)
true,)

Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ADOriginator) = MooncakeOriginator

end
end
12 changes: 7 additions & 5 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ const allowedkeywords = (:dense,
# Termination condition for solvers
:termination_condition,
# For AbstractAliasSpecifier
:alias)
:alias,
# Parameter estimation with BVP
:fit_parameters)

const KWARGWARN_MESSAGE = """
Unrecognized keyword arguments found.
Expand Down Expand Up @@ -541,7 +543,8 @@ end
Get the innermost index provider using `SII.symbolic_container`.
"""
function _get_root_indp(indp)
if hasmethod(SII.symbolic_container, Tuple{typeof(indp)}) && (sc = SII.symbolic_container(indp)) !== indp
if hasmethod(SII.symbolic_container, Tuple{typeof(indp)}) &&
(sc = SII.symbolic_container(indp)) !== indp
return _get_root_indp(sc)
end
return indp
Expand Down Expand Up @@ -748,7 +751,7 @@ function build_null_solution(prob::AbstractDEProblem, args...;

prob, success = hack_null_solution_init(prob)
retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure
build_solution(prob, nothing, ts, timeseries; dense=true, retcode)
build_solution(prob, nothing, ts, timeseries; dense = true, retcode)
end

function build_null_solution(
Expand Down Expand Up @@ -1139,7 +1142,6 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing,
sensealg = prob.kwargs[:sensealg]
end


if haskey(prob.kwargs, :alias_u0)
@warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
alias_spec = NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0])
Expand All @@ -1152,7 +1154,7 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing,
alias_spec = NonlinearAliasSpecifier(alias = prob.kwargs[:alias])
elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool
alias_spec = NonlinearAliasSpecifier(alias = kwargs[:alias])
end
end

if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier
alias_spec = prob.kwargs[:alias]
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/unitful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ intg = init(prob, Tsit5())
@test_nowarn step!(intg, 0.02u"s", true)

@test DiffEqBase.unitfulvalue(u"1/s") == u"1/s"
@test DiffEqBase.value(ForwardDiff.Dual(1)*u"1/s") == 1
@test DiffEqBase.unitfulvalue(ForwardDiff.Dual(1)*u"1/s") == u"1/s"
@test DiffEqBase.value(ForwardDiff.Dual(1) * u"1/s") == 1
@test DiffEqBase.unitfulvalue(ForwardDiff.Dual(1) * u"1/s") == u"1/s"
4 changes: 2 additions & 2 deletions test/static/static_checks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DiffEqBase, ComponentArrays, AllocCheck, Test

u = ComponentArray(x=1.0, y=0.0, z=0.0)
u = ComponentArray(x = 1.0, y = 0.0, z = 0.0)
t = 0.0
@test length(check_allocs(DiffEqBase.ODE_DEFAULT_NORM, (typeof(u), typeof(t)))) == 0
@test length(check_allocs(DiffEqBase.ODE_DEFAULT_NORM, (typeof(u), typeof(t)))) == 0
Loading