From ceed5e7e0ef5450389f55de687e2c6d9fc937af3 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Sun, 25 May 2025 21:39:50 +0800 Subject: [PATCH] Allow fit_parameters keyword --- ext/DiffEqBaseChainRulesCoreExt.jl | 9 ++++++--- ext/DiffEqBaseMooncakeExt.jl | 18 +++++++----------- src/solve.jl | 12 +++++++----- test/downstream/unitful.jl | 4 ++-- test/static/static_checks.jl | 4 ++-- 5 files changed, 24 insertions(+), 23 deletions(-) diff --git a/ext/DiffEqBaseChainRulesCoreExt.jl b/ext/DiffEqBaseChainRulesCoreExt.jl index 9d44a6bee..2ea7a4cf2 100644 --- a/ext/DiffEqBaseChainRulesCoreExt.jl +++ b/ext/DiffEqBaseChainRulesCoreExt.jl @@ -2,7 +2,8 @@ module DiffEqBaseChainRulesCoreExt using DiffEqBase using DiffEqBase.SciMLBase -import DiffEqBase: numargs, AbstractSensitivityAlgorithm, AbstractDEProblem, set_mooncakeoriginator_if_mooncake +import DiffEqBase: numargs, AbstractSensitivityAlgorithm, AbstractDEProblem, + set_mooncakeoriginator_if_mooncake import ChainRulesCore import ChainRulesCore: NoTangent @@ -15,7 +16,8 @@ function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob, u0, p, args...; kwargs...) DiffEqBase._solve_forward( - prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...; + prob, sensealg, u0, p, + set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...; kwargs...) end @@ -24,7 +26,8 @@ function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEPro u0, p, args...; kwargs...) DiffEqBase._solve_adjoint( - prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...; + prob, sensealg, u0, p, + set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...; kwargs...) end diff --git a/ext/DiffEqBaseMooncakeExt.jl b/ext/DiffEqBaseMooncakeExt.jl index 862078be6..ad000e62e 100644 --- a/ext/DiffEqBaseMooncakeExt.jl +++ b/ext/DiffEqBaseMooncakeExt.jl @@ -3,33 +3,29 @@ module DiffEqBaseMooncakeExt using DiffEqBase, Mooncake using DiffEqBase: SciMLBase using SciMLBase: ADOriginator, MooncakeOriginator -Mooncake.@from_rrule( - Mooncake.MinimalCtx, +Mooncake.@from_rrule(Mooncake.MinimalCtx, Tuple{ typeof(DiffEqBase.solve_up), DiffEqBase.AbstractDEProblem, Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, Any, Any, - Any, + Any }, - true, - ) + true,) # Dispatch for auto-alg -Mooncake.@from_rrule( - Mooncake.MinimalCtx, +Mooncake.@from_rrule(Mooncake.MinimalCtx, Tuple{ typeof(DiffEqBase.solve_up), DiffEqBase.AbstractDEProblem, Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, Any, - Any, + Any }, - true, - ) + true,) Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any} Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ADOriginator) = MooncakeOriginator -end \ No newline at end of file +end diff --git a/src/solve.jl b/src/solve.jl index 829618556..23b2ddb1e 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -99,7 +99,9 @@ const allowedkeywords = (:dense, # Termination condition for solvers :termination_condition, # For AbstractAliasSpecifier - :alias) + :alias, + # Parameter estimation with BVP + :fit_parameters) const KWARGWARN_MESSAGE = """ Unrecognized keyword arguments found. @@ -541,7 +543,8 @@ end Get the innermost index provider using `SII.symbolic_container`. """ function _get_root_indp(indp) - if hasmethod(SII.symbolic_container, Tuple{typeof(indp)}) && (sc = SII.symbolic_container(indp)) !== indp + if hasmethod(SII.symbolic_container, Tuple{typeof(indp)}) && + (sc = SII.symbolic_container(indp)) !== indp return _get_root_indp(sc) end return indp @@ -748,7 +751,7 @@ function build_null_solution(prob::AbstractDEProblem, args...; prob, success = hack_null_solution_init(prob) retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure - build_solution(prob, nothing, ts, timeseries; dense=true, retcode) + build_solution(prob, nothing, ts, timeseries; dense = true, retcode) end function build_null_solution( @@ -1139,7 +1142,6 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing, sensealg = prob.kwargs[:sensealg] end - if haskey(prob.kwargs, :alias_u0) @warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`." alias_spec = NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0]) @@ -1152,7 +1154,7 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing, alias_spec = NonlinearAliasSpecifier(alias = prob.kwargs[:alias]) elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool alias_spec = NonlinearAliasSpecifier(alias = kwargs[:alias]) - end + end if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier alias_spec = prob.kwargs[:alias] diff --git a/test/downstream/unitful.jl b/test/downstream/unitful.jl index 6ee0b309e..89bf38836 100644 --- a/test/downstream/unitful.jl +++ b/test/downstream/unitful.jl @@ -5,5 +5,5 @@ intg = init(prob, Tsit5()) @test_nowarn step!(intg, 0.02u"s", true) @test DiffEqBase.unitfulvalue(u"1/s") == u"1/s" -@test DiffEqBase.value(ForwardDiff.Dual(1)*u"1/s") == 1 -@test DiffEqBase.unitfulvalue(ForwardDiff.Dual(1)*u"1/s") == u"1/s" +@test DiffEqBase.value(ForwardDiff.Dual(1) * u"1/s") == 1 +@test DiffEqBase.unitfulvalue(ForwardDiff.Dual(1) * u"1/s") == u"1/s" diff --git a/test/static/static_checks.jl b/test/static/static_checks.jl index 264cbff22..5c36aaa16 100644 --- a/test/static/static_checks.jl +++ b/test/static/static_checks.jl @@ -1,5 +1,5 @@ using DiffEqBase, ComponentArrays, AllocCheck, Test -u = ComponentArray(x=1.0, y=0.0, z=0.0) +u = ComponentArray(x = 1.0, y = 0.0, z = 0.0) t = 0.0 -@test length(check_allocs(DiffEqBase.ODE_DEFAULT_NORM, (typeof(u), typeof(t)))) == 0 \ No newline at end of file +@test length(check_allocs(DiffEqBase.ODE_DEFAULT_NORM, (typeof(u), typeof(t)))) == 0