diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index b75d60463..e2170a5ec 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -22,12 +22,14 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" @@ -43,6 +45,7 @@ NonlinearSolveBaseLineSearchExt = "LineSearch" NonlinearSolveBaseLinearSolveExt = "LinearSolve" NonlinearSolveBaseSparseArraysExt = "SparseArrays" NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings" +NonlinearSolveBaseChainRulesCoreExt = "ChainRulesCore" [compat] ADTypes = "1.9" @@ -50,6 +53,7 @@ Adapt = "4.1.0" Aqua = "0.8.7" ArrayInterface = "7.9" BandedMatrices = "1.5" +ChainRulesCore = "1" CommonSolve = "0.2.4" Compat = "4.15" ConcreteStructs = "0.2.3" @@ -71,6 +75,7 @@ RecursiveArrayTools = "3" SciMLBase = "2.92" SciMLJacobianOperators = "0.1.1" SciMLOperators = "0.4, 1.0" +SciMLStructures = "1.5" SparseArrays = "1.10" SparseMatrixColorings = "0.4.5" StaticArraysCore = "1.4" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl new file mode 100644 index 000000000..b15f139be --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl @@ -0,0 +1,33 @@ +module NonlinearSolveBaseChainRulesCoreExt + +using NonlinearSolveBase +using NonlinearSolveBase: AbstractNonlinearProblem +using SciMLBase +using SciMLBase: AbstractSensitivityAlgorithm + +import ChainRulesCore +import ChainRulesCore: NoTangent + +ChainRulesCore.@non_differentiable NonlinearSolveBase.checkkwargs(kwargshandle) + +function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob, + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + NonlinearSolveBase._solve_forward( + prob, sensealg, u0, p, + originator, args...; + kwargs...) +end + +function ChainRulesCore.rrule(::typeof(NonlinearSolveBase.solve_up), prob::AbstractNonlinearProblem, + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + NonlinearSolveBase._solve_adjoint( + prob, sensealg, u0, p, + originator, args...; + kwargs...) +end + +end \ No newline at end of file diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 7cfc792e8..73d0d9075 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -195,4 +195,13 @@ NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.valu @inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x)) @inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) +eltypedual(x) = eltype(x) <: ForwardDiff.Dual +isdualtype(::Type{<:ForwardDiff.Dual}) = true + +function anyeltypedual( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + ::Type{Val{counter}} = Val{0}) where {counter} + anyeltypedual((prob.u0, prob.p)) +end + end diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index e3efc9ef5..bddf9c674 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -15,15 +15,17 @@ using StaticArraysCore: StaticArray, SMatrix, SArray, MArray using CommonSolve: CommonSolve, init using EnzymeCore: EnzymeCore using MaybeInplace: @bb -using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition +using RecursiveArrayTools: RecursiveArrayTools, AbstractVectorOfArray, ArrayPartition using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem, AbstractNonlinearAlgorithm, NonlinearProblem, NonlinearLeastSquaresProblem, NonlinearFunction, NLStats, LinearProblem, - LinearAliasSpecifier, ImmutableNonlinearProblem + LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier +import SciMLBase: solve, init, solve!, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator using SymbolicIndexingInterface: SymbolicIndexingInterface +import SciMLStructures using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind, mul! using Markdown: @doc_str @@ -32,6 +34,11 @@ using Printf: @printf const DI = DifferentiationInterface const SII = SymbolicIndexingInterface +# Extension Functions +eltypedual(x) = false +promote_u0(::Nothing, p, t0) = nothing +isdualtype(::Type{T}) where {T} = false + include("public.jl") include("utils.jl") diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 91b7a6aa6..c2c5b4774 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -1,3 +1,449 @@ +const allowedkeywords = (:dense, + :saveat, + :save_idxs, + :tstops, + :tspan, + :d_discontinuities, + :save_everystep, + :save_on, + :save_start, + :save_end, + :initialize_save, + :adaptive, + :abstol, + :reltol, + :dt, + :dtmax, + :dtmin, + :force_dtmin, + :internalnorm, + :controller, + :gamma, + :beta1, + :beta2, + :qmax, + :qmin, + :qsteady_min, + :qsteady_max, + :qoldinit, + :failfactor, + :calck, + :alias_u0, + :maxiters, + :maxtime, + :callback, + :isoutofdomain, + :unstable_check, + :verbose, + :merge_callbacks, + :progress, + :progress_steps, + :progress_name, + :progress_message, + :progress_id, + :timeseries_errors, + :dense_errors, + :weak_timeseries_errors, + :weak_dense_errors, + :wrap, + :calculate_error, + :initializealg, + :alg, + :save_noise, + :delta, + :seed, + :alg_hints, + :kwargshandle, + :trajectories, + :batch_size, + :sensealg, + :advance_to_tstop, + :stop_at_next_tstop, + :u0, + :p, + # These two are from the default algorithm handling + :default_set, + :second_time, + # This is for DiffEqDevTools + :prob_choice, + # Jump problems + :alias_jump, + # This is for copying/deepcopying noise in StochasticDiffEq + :alias_noise, + # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves + :batch, + # Shooting method in BVP needs to differentiate between these two categories + :nlsolve_kwargs, + :odesolve_kwargs, + # If Solvers which internally use linsolve + :linsolve_kwargs, + # Solvers internally using EnsembleProblem + :ensemblealg, + # Fine Grained Control of Tracing (Storing and Logging) during Solve + :show_trace, + :trace_level, + :store_trace, + # Termination condition for solvers + :termination_condition, + # For AbstractAliasSpecifier + :alias, + # Parameter estimation with BVP + :fit_parameters) + +const KWARGWARN_MESSAGE = """ +Unrecognized keyword arguments found. +The only allowed keyword arguments to `solve` are: +$allowedkeywords + +See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. + +Set kwargshandle=KeywordArgError for an error message. +Set kwargshandle=KeywordArgSilent to ignore this message. +""" + +const KWARGERROR_MESSAGE = """ + Unrecognized keyword arguments found. + The only allowed keyword arguments to `solve` are: + $allowedkeywords + + See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. + """ + +struct CommonKwargError <: Exception + kwargs::Any +end + +function Base.showerror(io::IO, e::CommonKwargError) + println(io, KWARGERROR_MESSAGE) + notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) + unrecognized = collect(keys(e.kwargs))[notin] + print(io, "Unrecognized keyword arguments: ") + printstyled(io, unrecognized; bold = true, color = :red) + print(io, "\n\n") + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +@enum KeywordArgError KeywordArgWarn KeywordArgSilent + +const INCOMPATIBLE_U0_MESSAGE = """ + Initial condition incompatible with functional form. + Detected an in-place function with an initial condition of type Number or SArray. + This is incompatible because Numbers cannot be mutated, i.e. + `x = 2.0; y = 2.0; x .= y` will error. + + If using a immutable initial condition type, please use the out-of-place form. + I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. + + If your differential equation function was defined with multiple dispatches and one is + in-place, then the automatic detection will choose in-place. In this case, override the + choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. + + For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: + https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation + """ + +struct IncompatibleInitialConditionError <: Exception end + +function Base.showerror(io::IO, e::IncompatibleInitialConditionError) + print(io, INCOMPATIBLE_U0_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NO_DEFAULT_ALGORITHM_MESSAGE = """ + Default algorithm choices require NonlinearSolve.jl. + Please specify an algorithm (e.g., `solve(prob, NewtonRaphson())` or + init(prob, NewtonRaphson()) or + import NonlinearSolve.jl directly. + + You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ + and its associated pages. + """ + +struct NoDefaultAlgorithmError <: Exception end + +function Base.showerror(io::IO, e::NoDefaultAlgorithmError) + print(io, NO_DEFAULT_ALGORITHM_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NON_SOLVER_MESSAGE = """ + The arguments to solve are incorrect. + The second argument must be a solver choice, `solve(prob,alg)` + where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. + + Please double check the arguments being sent to the solver. + + You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ + and its associated pages. + """ + +struct NonSolverError <: Exception end + +function Base.showerror(io::IO, e::NonSolverError) + print(io, NON_SOLVER_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ + Incompatible solver + automatic differentiation pairing. + The chosen automatic differentiation algorithm requires the ability + for compiler transforms on the code which is only possible on pure-Julia + solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods + which require this ability include: + + - Direct use of ForwardDiff.jl on the solver + - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` + sensealg choices for adjoint differentiation. + + Either switch the choice of solver to a pure Julia method, or change the automatic + differentiation method to one that does not require such transformations. + + For more details on automatic differentiation, adjoint, and sensitivity analysis + of differential equations, see the documentation page: + + https://diffeq.sciml.ai/stable/analysis/sensitivity/ + """ + +struct DirectAutodiffError <: Exception end + +function Base.showerror(io::IO, e::DirectAutodiffError) + println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +struct EvalFunc{F} <: Function + f::F +end +(f::EvalFunc)(args...) = f.f(args...) + +""" +```julia +solve(prob::NonlinearProblem, alg::Union{AbstractNonlinearAlgorithm,Nothing}; kwargs...) +``` + +## Arguments + +The only positional argument is `alg` which is optional. By default, `alg = nothing`. +If `alg = nothing`, then `solve` dispatches to the NonlinearSolve.jl automated +algorithm selection (if `using NonlinearSolve` was done, otherwise it will +error with a `MethodError`). + +## Keyword Arguments + +The NonlinearSolve.jl universe has a large set of common arguments available +for the `solve` function. These arguments apply to `solve` on any problem type and +are only limited by limitations of the specific implementations. + +Many of the defaults depend on the algorithm or the package the algorithm derives +from. Not all of the interface is provided by every algorithm. +For more detailed information on the defaults and the available options +for specific algorithms / packages, see the manual pages for the solvers of specific +problems. + +#### Error Control + +* `abstol`: Absolute tolerance. +* `reltol`: Relative tolerance. + +### Miscellaneous + +* `maxiters`: Maximum number of iterations before stopping. Defaults to 1e5. +* `verbose`: Toggles whether warnings are thrown when the solver exits early. + Defaults to true. + +### Sensitivity Algorithms (`sensealg`) + +`sensealg` is used for choosing the way the automatic differentiation is performed. + For more information, see the documentation for SciMLSensitivity: + https://docs.sciml.ai/SciMLSensitivity/stable/ +""" +function solve(prob::AbstractNonlinearProblem, args...; sensealg = nothing, + u0 = nothing, p = nothing, wrap = Val(true), kwargs...) + if sensealg === nothing && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + + if haskey(prob.kwargs, :alias_u0) + @warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`." + alias_spec = NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0]) + elseif haskey(kwargs, :alias_u0) + @warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`." + alias_spec = NonlinearAliasSpecifier(alias_u0 = kwargs[:alias_u0]) + end + + if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa Bool + alias_spec = NonlinearAliasSpecifier(alias = prob.kwargs[:alias]) + elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool + alias_spec = NonlinearAliasSpecifier(alias = kwargs[:alias]) + end + + if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier + alias_spec = prob.kwargs[:alias] + elseif haskey(kwargs, :alias) && kwargs[:alias] isa NonlinearAliasSpecifier + alias_spec = kwargs[:alias] + else + alias_spec = NonlinearAliasSpecifier(alias_u0 = false) + end + + alias_u0 = alias_spec.alias_u0 + + u0 = u0 !== nothing ? u0 : prob.u0 + p = p !== nothing ? p : prob.p + + if wrap isa Val{true} + wrap_sol(solve_up(prob, + sensealg, + u0, + p, + args...; + alias_u0 = alias_u0, + originator = SciMLBase.ChainRulesOriginator(), + kwargs...)) + else + solve_up(prob, + sensealg, + u0, + p, + args...; + alias_u0 = alias_u0, + originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + end +end + +function solve_up(prob::AbstractNonlinearProblem, sensealg, u0, p, + args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) + if isnothing(alg) || !(alg isa AbstractNonlinearSolveAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, true; u0 = u0, + p = p, kwargs...) + solve_call(_prob, args...; kwargs...) + else + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) + #check_prob_alg_pairing(_prob, alg) # use alg for improved inference + if length(args) > 1 + solve_call(_prob, alg, Base.tail(args)...; kwargs...) + else + solve_call(_prob, alg; kwargs...) + end + end +end + +function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, + kwargs...) + kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? + _prob.kwargs[:kwargshandle] : kwargshandle + + if has_kwargs(_prob) + if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + kwargs_temp = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], + values(kwargs).callback),)) + kwargs = merge(kwargs_temp, callbacks) + end + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + checkkwargs(kwargshandle; kwargs...) + if isdefined(_prob, :u0) + if _prob.u0 isa Array + if !isconcretetype(RecursiveArrayTools.recursive_unitless_eltype(_prob.u0)) + throw(NonConcreteEltypeError(RecursiveArrayTools.recursive_unitless_eltype(_prob.u0))) + end + + if !(eltype(_prob.u0) <: Number) && !(eltype(_prob.u0) <: Enum) && + !(_prob.u0 isa AbstractVector{<:AbstractArray} && _prob isa BVProblem) + # Allow Enums for FunctionMaps, make into a trait in the future + # BVPs use Vector of Arrays for initial guesses + throw(NonNumberEltypeError(eltype(_prob.u0))) + end + end + + if _prob.u0 === nothing + return build_null_solution(_prob, args...; kwargs...) + end + end + + if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) && + _prob.f.f isa EvalFunc + Base.invokelatest(__solve, _prob, args...; kwargs...)#::T + else + __solve(_prob, args...; kwargs...)#::T + end +end + +function init( + prob::AbstractNonlinearProblem, args...; sensealg = nothing, + u0 = nothing, p = nothing, kwargs...) + if sensealg === nothing && has_kwargs(prob) && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + + u0 = u0 !== nothing ? u0 : prob.u0 + p = p !== nothing ? p : prob.p + + init_up(prob, sensealg, u0, p, args...; kwargs...) +end + +function init_up(prob::AbstractNonlinearProblem, sensealg, u0, p, args...; kwargs...) + alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) + if isnothing(alg) || !(alg isa AbstractNonlinearAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + p = p, kwargs...) + init_call(_prob, args...; kwargs...) + else + tstops = get(kwargs, :tstops, nothing) + if tstops === nothing && has_kwargs(prob) + tstops = get(prob.kwargs, :tstops, nothing) + end + if !(tstops isa Union{Nothing, AbstractArray, Tuple, Real}) && + !SciMLBase.allows_late_binding_tstops(alg) + throw(LateBindingTstopsNotSupportedError()) + end + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) + #check_prob_alg_pairing(_prob, alg) # alg for improved inference + if length(args) > 1 + init_call(_prob, alg, Base.tail(args)...; kwargs...) + else + init_call(_prob, alg; kwargs...) + end + end +end + +function init_call(_prob, args...; merge_callbacks=true, kwargshandle=nothing, + kwargs...) + kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? + _prob.kwargs[:kwargshandle] : kwargshandle + + if has_kwargs(_prob) + if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + kwargs_temp = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], + values(kwargs).callback),)) + kwargs = merge(kwargs_temp, callbacks) + end + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + checkkwargs(kwargshandle; kwargs...) + + if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) && + _prob.f.f isa EvalFunc + Base.invokelatest(__init, _prob, args...; kwargs...)#::T + else + __init(_prob, args...; kwargs...)#::T + end +end + function SciMLBase.__solve( prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...; kwargs... @@ -127,6 +573,18 @@ function SciMLBase.__solve( __generated_polysolve(prob, alg, args...; kwargs...) end +function SciMLBase.__solve( + prob::AbstractNonlinearProblem, args...; default_set = false, second_time = false, + kwargs...) + if second_time + throw(NoDefaultAlgorithmError()) + elseif length(args) > 0 && !(first(args) isa AbstractNonlinearAlgorithm) + throw(NonSolverError()) + else + __solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...) + end +end + @generated function __generated_polysolve( prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...; stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, @@ -297,6 +755,10 @@ SII.state_values(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) +has_kwargs(_prob::AbstractNonlinearProblem) = has_kwargs(typeof(_prob)) +Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) +has_kwargs(::Type{T}) where {T} = __has_kwargs(T) + function SciMLBase.reinit!( cache::NonlinearSolveNoInitCache, u0 = cache.prob.u0; p = cache.prob.p, kwargs... ) @@ -328,3 +790,270 @@ function CommonSolve.solve!(cache::NonlinearSolveNoInitCache) end return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...) end + +function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, + kwargs...) + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + p = p, kwargs...) + else + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) + end + + if has_kwargs(_prob) + if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + kwargs_temp = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], + values(kwargs).callback),)) + kwargs = merge(kwargs_temp, callbacks) + end + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + if length(args) > 1 + _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator, + Base.tail(args)...; kwargs...) + else + _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator; kwargs...) + end +end + +function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, + kwargs...) + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + p = p, kwargs...) + else + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) + end + + if has_kwargs(_prob) + if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + kwargs_temp = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], + values(kwargs).callback),)) + kwargs = merge(kwargs_temp, callbacks) + end + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + if length(args) > 1 + _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator, + Base.tail(args)...; kwargs...) + else + _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...) + end +end + + +function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end + +function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end + +""" +Given the index provider `indp` used to construct the problem `prob` being solved, return +an updated `prob` to be used for solving. All implementations should accept arbitrary +keyword arguments. + +Should be called before the problem is solved, after performing type-promotion on the +problem. If the returned problem is not `===` the provided `prob`, it is assumed to +contain the `u0` and `p` passed as keyword arguments. + +# Keyword Arguments + +- `u0`, `p`: Override values for `state_values(prob)` and `parameter_values(prob)` which + should be used instead of the ones in `prob`. +""" +function get_updated_symbolic_problem(indp, prob; kw...) + return prob +end + +function build_null_solution( + prob::NonlinearProblem, + args...; + saveat = (), + save_everystep = true, + save_on = true, + save_start = save_everystep || isempty(saveat) || + saveat isa Number || prob.tspan[1] in saveat, + save_end = true, + kwargs...) + prob, success = hack_null_solution_init(prob) + retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure + SciMLBase.build_solution(prob, nothing, Float64[], nothing; retcode) +end + +function build_null_solution( + prob::NonlinearLeastSquaresProblem, + args...; abstol = 1e-6, kwargs...) + prob, success = hack_null_solution_init(prob) + retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure + + if isinplace(prob) + resid = isnothing(prob.f.resid_prototype) ? Float64[] : copy(prob.f.resid_prototype) + prob.f(resid, prob.u0, prob.p) + else + resid = prob.f(prob.f.resid_prototype, prob.p) + end + + if success + retcode = norm(resid) < abstol ? ReturnCode.Success : ReturnCode.Failure + end + + SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) +end + +@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) + if isempty(solve_args) || isnothing(first(solve_args)) + if haskey(solve_kwargs, :alg) + solve_kwargs[:alg] + elseif haskey(prob_kwargs, :alg) + prob_kwargs[:alg] + else + nothing + end + elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && + !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) + first(solve_args) + else + nothing + end +end + +function get_concrete_u0(prob, isadapt, t0, kwargs) + if eval_u0(prob.u0) + u0 = prob.u0(prob.p, t0) + elseif haskey(kwargs, :u0) + u0 = kwargs[:u0] + else + u0 = prob.u0 + end + + isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) + + _u0 = handle_distribution_u0(u0) + + if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) + throw(IncompatibleInitialConditionError()) + end + + if _u0 isa Tuple + throw(TupleStateError()) + end + + _u0 +end + +function get_concrete_p(prob, kwargs) + if haskey(kwargs, :p) + p = kwargs[:p] + else + p = prob.p + end +end + +eval_u0(u0::Function) = true +eval_u0(u0) = false + +handle_distribution_u0(_u0) = _u0 + +anyeltypedual(x) = anyeltypedual(x, Val{0}) +anyeltypedual(x, counter) = Any + +function promote_u0(u0, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if !isequal(_p, p) + return promote_u0(u0, _p, t0) + end + end + Tu = eltype(u0) + if isdualtype(Tu) + return u0 + end + Tp = anyeltypedual(p, Val{0}) + if Tp == Any + Tp = Tu + end + Tt = anyeltypedual(t0, Val{0}) + if Tt == Any + Tt = Tu + end + Tcommon = promote_type(Tu, Tp, Tt) + return if isdualtype(Tcommon) + Tcommon.(u0) + else + u0 + end +end + +function promote_u0(u0::AbstractArray{<:Complex}, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if !isequal(_p, p) + return promote_u0(u0, _p, t0) + end + end + Tu = real(eltype(u0)) + if isdualtype(Tu) + return u0 + end + Tp = anyeltypedual(p, Val{0}) + if Tp == Any + Tp = Tu + end + Tt = anyeltypedual(t0, Val{0}) + if Tt == Any + Tt = Tu + end + Tcommon = promote_type(eltype(u0), Tp, Tt) + return if isdualtype(real(Tcommon)) + Tcommon.(u0) + else + u0 + end +end + +function checkkwargs(kwargshandle; kwargs...) + if any(x -> x ∉ allowedkeywords, keys(kwargs)) + if kwargshandle == KeywordArgError + throw(CommonKwargError(kwargs)) + elseif kwargshandle == KeywordArgWarn + @warn KWARGWARN_MESSAGE + unrecognized = setdiff(keys(kwargs), allowedkeywords) + print("Unrecognized keyword arguments: ") + printstyled(unrecognized; bold = true, color = :red) + print("\n\n") + else + @assert kwargshandle == KeywordArgSilent + end + end +end \ No newline at end of file diff --git a/lib/NonlinearSolveBase/src/utils.jl b/lib/NonlinearSolveBase/src/utils.jl index 05bc71158..18d78c451 100644 --- a/lib/NonlinearSolveBase/src/utils.jl +++ b/lib/NonlinearSolveBase/src/utils.jl @@ -320,4 +320,6 @@ function clean_sprint_struct(x, indent::Int) return "$(name)(\n$(spacing)$(join(modifiers, ",\n$(spacing)"))\n$(spacing_last))" end +set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = x + end diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 57a1f0105..79bc2faac 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -12,7 +12,7 @@ using LineSearch: BackTracking using StaticArraysCore: SArray using CommonSolve: CommonSolve -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches +#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase using MaybeInplace: @bb using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, diff --git a/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl b/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl index 167f1fa85..fd55ca034 100644 --- a/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl +++ b/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl @@ -8,7 +8,7 @@ using ArrayInterface: ArrayInterface using StaticArraysCore: StaticArray, Size, MArray using CommonSolve: CommonSolve -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches +#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearAlgebra: LinearAlgebra, Diagonal, dot, diag using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase using MaybeInplace: @bb diff --git a/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl b/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl index 93a620761..c0a6bf2e9 100644 --- a/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl +++ b/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl @@ -5,7 +5,7 @@ using Reexport: @reexport using PrecompileTools: @compile_workload, @setup_workload using CommonSolve: CommonSolve -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches +#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LineSearch: RobustNonMonotoneLineSearch using MaybeInplace: @bb using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl index 4954ffb26..5326b0a88 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl @@ -1,6 +1,6 @@ module SimpleNonlinearSolveDiffEqBaseExt -using DiffEqBase: DiffEqBase +#using DiffEqBase: DiffEqBase using SimpleNonlinearSolve: SimpleNonlinearSolve diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index e7bbf82b2..964ced53a 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,7 +8,7 @@ using FastClosures: @closure using ADTypes: ADTypes using ArrayInterface: ArrayInterface using CommonSolve: CommonSolve, init, solve, solve! -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches +#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearAlgebra: LinearAlgebra using LineSearch: BackTracking using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,